Created July 1, 2024 14:18
-
-
Save tylerdrewwork/9f3131edbdbafe501028876446ace90d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| using System.Collections; | |
| using System.Collections.Generic; | |
| using UnityEngine; | |
| using Unity.MLAgents; | |
| using Unity.MLAgents.Sensors; | |
| using Unity.MLAgents.Actuators; | |
/// <summary>
/// ML-Agents agent that moves itself in local space from three continuous actions and
/// receives a terminal reward on trigger contacts. One script serves two roles, selected by
/// <see cref="isRunner"/>: a Runner and a Chaser.
/// </summary>
public class RunFast : Agent
{
    // Transform whose local position is fed to the policy as an observation.
    [SerializeField] private Transform targetTransform;

    // Flag to identify if this agent is the Runner; when false it acts as the Chaser.
    [SerializeField] private bool isRunner = false;

    // Local position the Runner is reset to at the start of each episode
    // (the Chaser is always reset to the local origin).
    [SerializeField] private Vector3 runnerStartPosition = new Vector3(13f, 0.6f, 2f);

    // Scale applied to the action vector: local units moved per second at action magnitude 1.
    [SerializeField] private float moveSpeed = 2f;

    /// <summary>
    /// Called by ML-Agents at the start of each episode: puts the agent back at the
    /// starting local position for its role.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        transform.localPosition = isRunner ? runnerStartPosition : Vector3.zero;
    }

    /// <summary>
    /// Appends two Vector3 observations: this agent's local position, then the target's
    /// local position.
    /// </summary>
    /// <param name="sensor">Vector sensor supplied by ML-Agents.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(transform.localPosition);
        // NOTE(review): targetTransform must be assigned in the Inspector; an unassigned
        // reference makes this line throw on every decision step.
        sensor.AddObservation(targetTransform.localPosition);
    }

    /// <summary>
    /// Applies the three continuous actions as a local-space displacement scaled by
    /// <see cref="moveSpeed"/> and the frame delta time.
    /// Action layout: [0] = X, [1] = Z, [2] = Y (the same layout <see cref="Heuristic"/> writes).
    /// </summary>
    /// <param name="actions">Action buffers chosen by the policy or the heuristic.</param>
    public override void OnActionReceived(ActionBuffers actions)
    {
        ActionSegment<float> continuous = actions.ContinuousActions;
        float moveX = continuous[0];
        float moveZ = continuous[1];
        float moveY = continuous[2];

        transform.localPosition += new Vector3(moveX, moveY, moveZ) * Time.deltaTime * moveSpeed;

        // TODO: no per-step reward yet; rewards are only assigned in OnTriggerEnter.
    }

    /// <summary>
    /// Unity trigger callback. Sets a terminal reward and ends the episode when the other
    /// collider carries a scoring component for this agent's role:
    /// Runner — Treasure: +1, Goal: -1. Chaser — Goal: +1, Wall: -1.
    /// </summary>
    /// <param name="other">The other collider involved in the trigger contact.</param>
    private void OnTriggerEnter(Collider other)
    {
        if (isRunner)
        {
            if (other.TryGetComponent<Treasure>(out _))
            {
                SetReward(+1f);
                Debug.Log("Gained a Point");
                EndEpisode();
                // Stop after the first match so one contact cannot end the episode twice.
                return;
            }

            if (other.TryGetComponent<Goal>(out _))
            {
                SetReward(-1f);
                EndEpisode();
            }

            return;
        }

        // Chaser role.
        if (other.TryGetComponent<Goal>(out _))
        {
            SetReward(+1f);
            EndEpisode();
            return;
        }

        if (other.TryGetComponent<Wall>(out _))
        {
            SetReward(-1f);
            EndEpisode();
        }
    }

    /// <summary>
    /// Manual-control fallback: maps input axes onto the same action layout consumed by
    /// <see cref="OnActionReceived"/>.
    /// </summary>
    /// <param name="actionsOut">Buffers to fill with the heuristic actions.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
        continuousActions[0] = Input.GetAxisRaw("Horizontal");
        continuousActions[1] = Input.GetAxisRaw("Vertical");
        // NOTE(review): "Up" is not one of Unity's default Input Manager axes; presumably it is
        // defined in this project's Input Manager settings, otherwise GetAxisRaw throws.
        continuousActions[2] = Input.GetAxisRaw("Up");
    }
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.