본문 바로가기
TIL

[ML-Agent를 활용한 머신러닝] MummyBasic

by imagineer_jinny 2021. 8. 10.

 

 

개발환경

Python 3.7.9 , ML-Agents, Git , Unity 2020.3.12f

 

만들고자 하는 것

ML-Agent를 활용하여 Mummy가 Target에 닿으면 보상을 주고 DeadZone에 닿으면 벌을 준다.

학습을 시켜서 Target에 닿는 빈도수가 높아지게 한다.

 

배운 것

1. Mummy가 Target, DeadZone에 충돌했을 때 처리: Tag, Collider, RigidBody

-충돌 처리 되려면 타겟이나 충돌시키는대상이나 둘중 하나에 RigidBody가 있어야 한다.

여기서는 Mummy에 Rigidbody와 Capsule Collider을 주고 Target, DeadZone에 각각 Box Collider(큐브에 원래있던것)와 태그를 달아줌(TARGET, DEADZONE) 

(이제 이런 비슷한거 구현 못하면 아주아주 문제있는거다!!!)

 

2. Target에 닿으면 Floor 초록색, DeadZone에 닿으면 빨간색 (코루틴으로 구현해보기)

- 뭐가 필요한지도 헷갈리고 변수를 잘 못쓰겠음...

- Floor 색 바꾸려면 Floor의 material에 접근해야함.

- goodMt, badMt 말고도 원래 색을 담을 변수와 바뀔 색 변수 필요함.

changeMt 안에 goodMt, badMt가 왔다갔다 하겠지.

 

3. Episode : 머신러닝에서 학습의 단위, '즉 학습이 될때마다 호출되는 함수' 에서는 뭘 해야할지 생각해봐.

학습이 될때마다 라는 말은 곧 학습이 처음부터 끝까지 이루어지고 계속 반복되는거니까

리셋 개념인 것 같다. 

리셋될때마다 위치(Mummy, Target) 랜덤으로 다시 받아와야 하고 물리력, 회전값도 초기화해야한다.

 

4. 주변 환경을 관측하는 것은 Brain(Pytorch, TF)사용

 

5. 키값으로 전후좌우 움직이기

 

 

몰랐던 것

에피소드, 브레인 개념

 

 

느낀 것

건방지게 2D, 3D 유니티 프로젝트 2주에 걸쳐서 수업 했던걸 다 아는거야~ 하면서 안들었던 것.

아주아주아주 많이 후회한다!!!!!

지나간 과거는 어쩔 수 없으니 앞으로 남은 수업들은 무조건 열심히 참여하자. 뭐든 남는다.

 

 

코드

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

public class BasicAgent : Agent
{
    /*
     1. 주변환경을 관측(Observations)  --> Brain(Pytorch, TF)
     2. 정책(Policy)에 의해서 행동(Actions)
     3. 보상(Reward)
    
       관측정보
    1. 자신의 위치
    2. 타겟의 위치
    3. Rigidbody Velocity
     */

    private Transform tr;
    private Transform targetTr;
    private Rigidbody rb;

    public Material goodMt;
    public Material badMt;
    private Material originMt;

    private Renderer floorRd;


    //초기화 작업을 위한 메소드
    public override void Initialize()
    {
        tr = GetComponent<Transform>();
        targetTr = tr.parent.Find("Target").GetComponent<Transform>();

        rb = GetComponent<Rigidbody>();

        floorRd = tr.parent.Find("Floor").GetComponent<Renderer>();
        originMt = floorRd.material;
    }

    IEnumerator RevertMaterial(Material changedMt)
    {
        floorRd.material = changedMt;
        yield return new WaitForSeconds(0.2f);

        floorRd.material = originMt;

    }
   

    //학습이 될때마다 호출되는 함수
    //머신러닝에서 학습의 단위는 에피소드라고 함.
    public override void OnEpisodeBegin()
    {
        //물리력 초기화
        rb.velocity = Vector3.zero;
        //회전값 초기화
        rb.angularVelocity = Vector3.zero;

        //에이전의 위치를 불규칙하게 변경
        tr.localPosition = new Vector3(Random.Range(-4.0f,+4.0f),0.05f,Random.Range(-4.0f,+4.0f));

        //타겟의 위치도 변경
        targetTr.localPosition= new Vector3(Random.Range(-4.0f, +4.0f), 0.55f, Random.Range(-4.0f, +4.0f));

    }

    
    //주변환경을 관측해서 수집 브레인에 전달하는 메소드
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(tr.localPosition); //Vector3 (x,y,3) => 3개 (데이터 수)
        sensor.AddObservation(targetTr.localPosition); //Vector3 (x,y,3) => 3개

        sensor.AddObservation(rb.velocity.x); // =>1개
        sensor.AddObservation(rb.velocity.z); // =>1개
    }

    //정책에 의해 결정된 명령을 실행하는 메소드
    public override void OnActionReceived(ActionBuffers actions)
    {
        var action = actions.ContinuousActions;
        //Debug.Log($"[0]={action[0]},[1] ={ action[1]}");

        //[0] UP, Down
        //[1] Left, Right
        Vector3 dir = (Vector3.forward * action[0]) + (Vector3.right * action[1]);
        rb.AddForce(dir.normalized * 50.0f);

        //마이너스 패널티
        SetReward(-0.001f);

    }

    //개발자가 테스트 용도로 활용하는 메소드, 모방학습을 위한 데모 파일을 생성할 때 사용하는 메소드
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        // 전진/후진 Input.GetAxis("Vertical")
        //좌/우 Input.GetAxis("Horizontal")

        var actions = actionsOut.ContinuousActions;
        actions[0] = Input.GetAxis("Vertical");
        actions[1] = Input.GetAxis("Horizontal");
    }

    private void OnCollisionEnter(Collision collision)
    {
        if(collision.collider.CompareTag("DEAD_ZONE"))
        {
            SetReward(-1.0f);
            EndEpisode();
            StartCoroutine(RevertMaterial(badMt));

        }

        if(collision.collider.CompareTag("TARGET"))
        {
            SetReward(+1.0f);
            EndEpisode();
            StartCoroutine(RevertMaterial(goodMt));
        }
    }

}

 

 

 

댓글