본문 바로가기
TIL

[ML-Agent를 활용한 머신러닝] Soccer

by imagineer_jinny 2021. 8. 17.

 

 

만들고자 하는 것

ML-Agent로 플레이어를 학습시켜서 스스로 축구를 하게 하는 게임을 만든다.

 

기본 세팅

SoccerField, Soccer Ball, Agent Blue팀 만들기 

 

 

배운 것

1. Branch를 이용해서 키보드 W, A, S, D, Q, E 키로 플레이어 움직임 조작하기 (W/S: 전진/후진, Q/E: 왼쪽/오른쪽 이동, A/D: 왼쪽/오른쪽 회전)

- 꼭 있어야 할 스크립트와 조건들을 한 번 더 체크하기

 

Decision Requester 추가

Discrete Branch 개수(여기서는 3개)를 정하고, 각 Branch의 사이즈(선택지 개수, 여기서는 모두 3: 0/1/2)를 입력해주기

 

스크립트

 /// <summary>
    /// Applies the three discrete action branches to the player:
    /// branch 0 = forward/back (0 none, 1 forward, 2 back),
    /// branch 1 = strafe (0 none, 1 left, 2 right),
    /// branch 2 = turn (0 none, 1 left, 2 right).
    /// Forward/back and strafe are combined, so a diagonal choice keeps both components.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;

        int forward = action[0]; //(0,1,2)
        int right = action[1];   //(0,1,2)
        int rotate = action[2];  //(0,1,2)

        Vector3 dir = Vector3.zero;

        switch (forward)
        {
            case 1: dir += tr.forward; break;
            case 2: dir -= tr.forward; break;
        }

        // Accumulate instead of overwrite: assigning here used to discard the
        // forward/back direction chosen on branch 0 whenever a strafe was chosen.
        switch (right)
        {
            case 1: dir -= tr.right; break;
            case 2: dir += tr.right; break;
        }

        // -1 = turn left, +1 = turn right, 0 = no turn.
        float turn = 0.0f;
        switch (rotate)
        {
            case 1: turn = -1.0f; break;
            case 2: turn = 1.0f; break;
        }

        // Transform.Rotate(axis, angle) interprets the axis in local space, so rotate
        // about the local up axis rather than passing the world-space tr.up vector.
        tr.Rotate(Vector3.up, turn * Time.fixedDeltaTime * 100.0f);

        // Normalize so a diagonal (forward + strafe) push is not sqrt(2) times stronger
        // than a straight one; a zero vector stays zero.
        rb.AddForce(dir.normalized * moveSpeed, ForceMode.VelocityChange);
    }
    
    /// <summary>
    /// Keyboard control for manual testing. Fills the three discrete branches:
    /// W/S -> branch 0 (1 forward, 2 back),
    /// Q/E -> branch 1 (1 move left, 2 move right),
    /// A/D -> branch 2 (1 turn left, 2 turn right).
    /// 0 means "no input". If both keys of a pair are held, the second key (S, E, D) wins.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        // Maps a key pair onto a branch value: 2 if the second key is held,
        // otherwise 1 if the first key is held, otherwise 0.
        int ReadAxis(KeyCode first, KeyCode second)
        {
            return Input.GetKey(second) ? 2 : (Input.GetKey(first) ? 1 : 0);
        }

        var discrete = actionsOut.DiscreteActions;
        discrete.Clear();

        // Three separate branches are used, so each one gets its own index.
        discrete[0] = ReadAxis(KeyCode.W, KeyCode.S);
        discrete[1] = ReadAxis(KeyCode.Q, KeyCode.E);
        discrete[2] = ReadAxis(KeyCode.A, KeyCode.D);
    }

 

2. 골대에 콜라이더 설치하고 볼 들어가면 스코어 조절하기

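- 레드팀 골대에 공이 들어가면 블루팀 +1, 블루팀 골대에 공이 들어가면 레드팀 +1 이런 식으로. (골 판정과 점수 처리는 아래 BallControl 스크립트에서 하고, 바로 아래 코드는 플레이어가 공에 닿았을 때 보상을 주고 공을 차는 PlayerAgent 쪽 코드다.)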

private void OnCollisionEnter(Collision collision)
    {
        if (!collision.collider.CompareTag("ball"))
        {
            return;
        }

        // Small shaping reward each time the player touches the ball.
        AddReward(0.2f);

        // The Rigidbody of what we hit; null if the collider has no Rigidbody.
        Rigidbody ballBody = collision.rigidbody;
        if (ballBody == null)
        {
            return;
        }

        // Kick the ball away from the player's pivot through the first contact point.
        // Normalized so the applied force is always kickForce, independent of how far
        // the contact point happens to be from the pivot.
        Vector3 kickDir = (collision.GetContact(0).point - tr.position).normalized;
        ballBody.AddForce(kickDir * kickForce);
    }

 

3. Behavior Parameters란?

Class BehaviorParameters

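A component for setting an Agent instance's behavior and brain properties. (이 프로젝트에서는 PlayerAgent.Initialize()에서 BehaviorParameters.TeamId를 팀 값(BLUE=0, RED=1)으로 설정하는 데 사용한다.)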

 

 

코드

PlayerAgent

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Policies; //Behavior Parameter 접근 위해 선언
using Unity.MLAgents.Actuators;

public class PlayerAgent : Agent
{
    // Team of this player. The enum value is also written to BehaviorParameters.TeamId
    // and used to index `materials`.
    public enum TEAM
    {
        BLUE, RED
    }

    public TEAM team = TEAM.BLUE;
    // Team materials, indexed by TeamId (BLUE = 0, RED = 1).
    public Material[] materials;

    private BehaviorParameters bps;
    private Transform tr;
    private Rigidbody rb;

    // Kick-off pose per team (local space): BLUE starts on the -X side rotated to face +X,
    // RED starts on the +X side rotated to face -X.
    private Vector3 initBluePos = new Vector3(-5.5f, 0.5f, 0.0f);
    private Vector3 initRedPos = new Vector3(5.5f, 0.5f, 0.0f);
    private Quaternion initBlueRot = Quaternion.Euler(Vector3.up * 90.0f);
    private Quaternion initRedRot = Quaternion.Euler(-Vector3.up * 90.0f);

    public float moveSpeed = 1.0f;
    public float kickForce = 800.0f;
    // Turn amount per action step is rotateSpeed * Time.fixedDeltaTime degrees
    // (this was previously hard-coded as 100).
    public float rotateSpeed = 100.0f;

    /// <summary>Moves the player back to its team's kick-off position and rotation.</summary>
    void InitPlayer()
    {
        tr.localPosition = (team == TEAM.BLUE) ? initBluePos : initRedPos;
        tr.localRotation = (team == TEAM.BLUE) ? initBlueRot : initRedRot;
    }

    /// <summary>
    /// Caches components, registers the team with ML-Agents, applies the team colour
    /// and sets the episode length.
    /// </summary>
    public override void Initialize()
    {
        bps = GetComponent<BehaviorParameters>();
        tr = GetComponent<Transform>();
        rb = GetComponent<Rigidbody>();

        // Team ID seen by ML-Agents (BLUE = 0, RED = 1).
        bps.TeamId = (int)team;
        // Team colour; assumes `materials` has an entry for every TeamId.
        GetComponent<Renderer>().material = materials[bps.TeamId];

        // Hard-coded episode length; this overwrites any Max Step set in the Inspector.
        MaxStep = 10000;
    }

    /// <summary>Resets pose and clears any remaining motion at the start of every episode.</summary>
    public override void OnEpisodeBegin()
    {
        InitPlayer();
        rb.velocity = rb.angularVelocity = Vector3.zero;
    }

    /// <summary>
    /// Applies the three discrete action branches:
    /// branch 0 = forward/back (0 none, 1 forward, 2 back),
    /// branch 1 = strafe (0 none, 1 left, 2 right),
    /// branch 2 = turn (0 none, 1 left, 2 right).
    /// Forward/back and strafe are combined, so a diagonal choice keeps both components.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.DiscreteActions;

        int forward = action[0]; //(0,1,2)
        int right = action[1];   //(0,1,2)
        int rotate = action[2];  //(0,1,2)

        Vector3 dir = Vector3.zero;

        switch (forward)
        {
            case 1: dir += tr.forward; break;
            case 2: dir -= tr.forward; break;
        }

        // Accumulate instead of overwrite: assigning here used to discard the
        // forward/back direction chosen on branch 0 whenever a strafe was chosen.
        switch (right)
        {
            case 1: dir -= tr.right; break;
            case 2: dir += tr.right; break;
        }

        // -1 = turn left, +1 = turn right, 0 = no turn.
        float turn = 0.0f;
        switch (rotate)
        {
            case 1: turn = -1.0f; break;
            case 2: turn = 1.0f; break;
        }

        // Transform.Rotate(axis, angle) interprets the axis in local space, so rotate
        // about the local up axis rather than passing the world-space tr.up vector.
        tr.Rotate(Vector3.up, turn * rotateSpeed * Time.fixedDeltaTime);

        // Normalize so a diagonal (forward + strafe) push is not sqrt(2) times stronger
        // than a straight one; a zero vector stays zero.
        rb.AddForce(dir.normalized * moveSpeed, ForceMode.VelocityChange);
    }

    /// <summary>
    /// Keyboard control for manual testing. Fills the three discrete branches:
    /// W/S -> branch 0 (1 forward, 2 back), Q/E -> branch 1 (1 left, 2 right),
    /// A/D -> branch 2 (1 turn left, 2 turn right). 0 means "no input".
    /// If both keys of a pair are held, the second key (S, E, D) wins.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var actions = actionsOut.DiscreteActions;
        actions.Clear();

        // Three separate branches are used, so each one has its own index.

        // Branch 0 forward/back: none/forward/back = 0,1,2
        if (Input.GetKey(KeyCode.W)) actions[0] = 1;
        if (Input.GetKey(KeyCode.S)) actions[0] = 2;

        // Branch 1 strafe: none/left/right = 0,1,2
        if (Input.GetKey(KeyCode.Q)) actions[1] = 1;
        if (Input.GetKey(KeyCode.E)) actions[1] = 2;

        // Branch 2 turn: none/turn left/turn right = 0,1,2
        if (Input.GetKey(KeyCode.A)) actions[2] = 1;
        if (Input.GetKey(KeyCode.D)) actions[2] = 2;
    }

    /// <summary>
    /// On touching the ball: give a small shaping reward and kick the ball away from
    /// the player through the first contact point with a force of kickForce.
    /// </summary>
    private void OnCollisionEnter(Collision collision)
    {
        if (!collision.collider.CompareTag("ball"))
        {
            return;
        }

        AddReward(0.2f);

        // The Rigidbody of what we hit; null if the collider has no Rigidbody.
        Rigidbody ballBody = collision.rigidbody;
        if (ballBody == null)
        {
            return;
        }

        // Normalized so the applied force is always kickForce, independent of how far
        // the contact point happens to be from the player's pivot.
        Vector3 kickDir = (collision.GetContact(0).point - tr.position).normalized;
        ballBody.AddForce(kickDir * kickForce);
    }
}

 

BallControl

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using TMPro;

public class BallControl : MonoBehaviour
{
    // Agents rewarded by this ball: index 0 is treated as BLUE, index 1 as RED.
    public Agent[] players;
    private Rigidbody rb;

    // Scoreboard labels and the running goal counts shown on them.
    public TMP_Text blueScoreText, redScoreText;
    private int blueScore, redScore;

    private void Start()
    {
        rb = GetComponent<Rigidbody>();
    }

    /// <summary>Stops the ball and puts it back at local (0, 1.5, 0).</summary>
    void InitBall()
    {
        rb.angularVelocity = Vector3.zero;
        rb.velocity = Vector3.zero;
        transform.localPosition = new Vector3(0.0f, 1.5f, 0.0f);
    }

    /// <summary>
    /// Shared goal handling: rewards both players, resets the ball and ends both
    /// players' episodes.
    /// </summary>
    /// <param name="blueReward">Reward for players[0] (BLUE); players[1] (RED) gets the negation.</param>
    private void ResolveGoal(float blueReward)
    {
        players[1].AddReward(-blueReward);
        players[0].AddReward(blueReward);

        InitBall();

        players[0].EndEpisode();
        players[1].EndEpisode();
    }

    private void OnCollisionEnter(Collision collision)
    {
        var hit = collision.collider;

        // Ball reached BLUE's goal: RED scores.
        if (hit.CompareTag("BLUE_GOAL"))
        {
            ResolveGoal(-1.0f);
            redScore += 1;
            redScoreText.text = redScore.ToString();
        }

        // Ball reached RED's goal: BLUE scores.
        if (hit.CompareTag("RED_GOAL"))
        {
            ResolveGoal(+1.0f);
            blueScore += 1;
            blueScoreText.text = blueScore.ToString();
        }
    }
}

댓글