Skip to content

Instantly share code, notes, and snippets.

@tylerdrewwork
Created July 1, 2024 14:18
Show Gist options
  • Select an option

  • Save tylerdrewwork/9f3131edbdbafe501028876446ace90d to your computer and use it in GitHub Desktop.

Select an option

Save tylerdrewwork/9f3131edbdbafe501028876446ace90d to your computer and use it in GitHub Desktop.
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/// <summary>
/// ML-Agents agent that translates itself along X/Y/Z from three continuous actions and
/// receives a terminal reward on trigger contacts. One script serves two roles, selected by
/// the <c>isRunner</c> flag:
/// Runner: +1 on touching a <c>Treasure</c>, -1 on touching a <c>Goal</c>.
/// Chaser: +1 on touching a <c>Goal</c>, -1 on touching a <c>Wall</c>.
/// </summary>
public class RunFast : Agent
{
    // Object whose local position is fed to the policy as an observation.
    [SerializeField] private Transform targetTransform;

    // Flag to identify if this agent is the Runner; when false the agent acts as the Chaser.
    [SerializeField] private bool isRunner = false;

    // Local-space spawn positions used at the start of each episode.
    // Defaults reproduce the previously hard-coded values.
    [SerializeField] private Vector3 runnerStartPosition = new Vector3(13f, 0.6f, 2f);
    [SerializeField] private Vector3 chaserStartPosition = Vector3.zero;

    // Translation speed (units per second per unit of action). Default reproduces the old constant.
    [SerializeField] private float moveSpeed = 2f;

    /// <summary>
    /// Resets the agent to its role-specific start position at the beginning of each episode.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        transform.localPosition = isRunner ? runnerStartPosition : chaserStartPosition;
    }

    /// <summary>
    /// Adds two Vector3 observations: this agent's local position, then the target's local position.
    /// NOTE(review): the Behavior Parameters vector observation size presumably needs to be 6 — confirm.
    /// </summary>
    /// <param name="sensor">Sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(transform.localPosition);
        // Throws if targetTransform was not assigned in the Inspector.
        sensor.AddObservation(targetTransform.localPosition);
    }

    /// <summary>
    /// Moves the agent by the received continuous actions, scaled by <c>moveSpeed</c> and the
    /// frame time. Action layout: [0] = X, [1] = Z, [2] = Y (the same layout Heuristic writes).
    /// NOTE(review): requires at least 3 continuous actions in Behavior Parameters — confirm.
    /// </summary>
    /// <param name="actions">Action buffers produced by the policy or the heuristic.</param>
    public override void OnActionReceived(ActionBuffers actions)
    {
        ActionSegment<float> continuous = actions.ContinuousActions;
        float moveX = continuous[0];
        float moveZ = continuous[1];
        float moveY = continuous[2];
        transform.localPosition += new Vector3(moveX, moveY, moveZ) * Time.deltaTime * moveSpeed;
        // TODO: per-step reward shaping is not implemented; rewards are currently assigned
        // only in OnTriggerEnter.
    }

    /// <summary>
    /// Assigns a terminal reward and ends the episode when this agent's trigger touches an
    /// object relevant to its role. Returns after the first match so that a collider carrying
    /// several of these components cannot overwrite the reward or end the episode twice.
    /// </summary>
    /// <param name="other">Collider that entered this agent's trigger.</param>
    private void OnTriggerEnter(Collider other)
    {
        if (isRunner)
        {
            // Runner: +1 for a Treasure, -1 for a Goal.
            if (other.TryGetComponent<Treasure>(out _))
            {
                SetReward(+1f);
                Debug.Log("Gained a Point");
                EndEpisode();
                return;
            }
            if (other.TryGetComponent<Goal>(out _))
            {
                SetReward(-1f);
                EndEpisode();
            }
        }
        else
        {
            // Chaser: +1 for a Goal, -1 for a Wall.
            if (other.TryGetComponent<Goal>(out _))
            {
                SetReward(+1f);
                EndEpisode();
                return;
            }
            if (other.TryGetComponent<Wall>(out _))
            {
                SetReward(-1f);
                EndEpisode();
            }
        }
    }

    /// <summary>
    /// Manual control for testing: writes raw input axes into the same action layout that
    /// OnActionReceived reads ([0] = X, [1] = Z, [2] = Y).
    /// </summary>
    /// <param name="actionsOut">Action buffers to fill.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
        continuousActions[0] = Input.GetAxisRaw("Horizontal");
        continuousActions[1] = Input.GetAxisRaw("Vertical");
        // NOTE(review): "Up" is not one of Unity's default Input Manager axes; GetAxisRaw throws
        // an ArgumentException if it has not been added under Project Settings > Input Manager.
        continuousActions[2] = Input.GetAxisRaw("Up");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment