> 文章列表 > Unity-ML-Agents-代码解读-Making a New Learning Environment-RollerBall

Unity-ML-Agents-代码解读-Making a New Learning Environment-RollerBall

Unity-ML-Agents-代码解读-Making a New Learning Environment-RollerBall

使用版本:https://github.com/Unity-Technologies/ml-agents/releases/tag/release_19

文件路径:ml-agents-release_19/docs/Learning-Environment-Create-New.md

本文只涉及代码解析,对于unity的对应操作请看相应文档

目录

1. Initialization and Resetting the Agent

1.1 代码总括

1.2 代码分解

1.2.1 导入包和命名空间

1.2.2 this.transform.localPosition.y < 0

1.2.3 this.transform.localPosition = new Vector3( 0, 0.5f, 0)

1.2.4 Target.localPosition = new Vector3(Random.value * 8 - 4, 0.5f, Random.value * 8 - 4); }

2. Observing the Environment

2.1 代码总括

2.2 代码分解

2.2.1 public override void CollectObservations(VectorSensor sensor)

2.2.2 sensor.AddObservation(Target.localPosition)

2.2.3 sensor.AddObservation(this.transform.localPosition)这一行代码将 

2.2.4 sensor.AddObservation(rBody.velocity.x)

2.2.5 sensor.AddObservation(rBody.velocity.z)

3. Taking Actions and Assigning Rewards

3.1 代码总括

3.2 代码分解

3.2.1 public float forceMultiplier = 10

3.2.2 public override void OnActionReceived(ActionBuffers actionBuffers)

3.2.3 controlSignal.x = actionBuffers.ContinuousActions[0]

3.2.4 float distanceToTarget = Vector3.Distance(this.transform.localPosition, Target.localPosition)


1. Initialization and Resetting the Agent

1.1 代码总括

using System.Collections.Generic; // 导入命名空间 System.Collections.Generic
using UnityEngine; // 导入 Unity 引擎的命名空间
using Unity.MLAgents; // 导入 Unity 引擎的 Machine Learning Agents(ML-Agents)的命名空间
using Unity.MLAgents.Sensors; // 导入 Unity 引擎的 ML-Agents 的 Sensors 的命名空间public class RollerAgent : Agent // 声明一个 RollerAgent 类,继承自 Agent 类
{Rigidbody rBody; // 声明一个 Rigidbody 类型的变量 rBodyvoid Start () { // 声明一个 Start 方法,当脚本启动时会执行该方法rBody = GetComponent<Rigidbody>(); // 获取当前 GameObject 的 Rigidbody 组件}public Transform Target; // 声明一个 Transform 类型的公共变量 Target,用于储存目标物体的位置信息public override void OnEpisodeBegin() // 重写 OnEpisodeBegin 方法{// 如果智能体掉落,则将其动量归零if (this.transform.localPosition.y < 0){this.rBody.angularVelocity = Vector3.zero;this.rBody.velocity = Vector3.zero;this.transform.localPosition = new Vector3( 0, 0.5f, 0); // 将智能体放在起点}// 将目标物体移动到新的位置Target.localPosition = new Vector3(Random.value * 8 - 4, 0.5f, Random.value * 8 - 4);}
}

这段代码是一个简单的机器学习智能体(RollerAgent),用于学习如何控制一个球体移动到场景中随机位置的目标点。其中,using 语句用于导入所需的命名空间,public 关键字用于声明公共变量,Start 方法是 Unity 引擎中的一个特殊方法,当脚本启动时会自动执行,OnEpisodeBegin 方法则是重写了 Agent 类的该方法,用于在每个新的 episode(一次训练)开始时执行。

1.2 代码分解

1.2.1 导入包和命名空间

using System.Collections.Generic; // 导入命名空间 System.Collections.Generic
using UnityEngine; // 导入 Unity 引擎的命名空间
using Unity.MLAgents; // 导入 Unity 引擎的 Machine Learning Agents(ML-Agents)的命名空间
using Unity.MLAgents.Sensors // 导入 Unity 引擎的 ML-Agents 的 Sensors 的命名空间

using System.Collections.Generic 是一个 C# 中的命名空间,用于提供实现各种通用集合数据结构的类型。在使用 C# 中的 List、Dictionary 等数据结构时需要使用该命名空间。例如,当需要使用 List<T> 类型时,需要在代码文件的头部使用 using System.Collections.Generic; 声明该命名空间。

using UnityEngine 是 C# 编程语言中的一个关键字,用于引入 Unity 引擎的命名空间,使得在代码中可以直接使用 Unity 引擎提供的类和方法。这些类和方法包括用于创建游戏对象、控制游戏对象的位置、旋转、缩放、碰撞检测等等,以及访问 Unity 引擎中的其他功能。

Unity.MLAgents 是 Unity 官方提供的一个工具包,用于支持在 Unity 中训练机器学习模型,尤其是用于支持强化学习算法的训练。其中包含了许多用于强化学习的组件和类。通过使用这个工具包,开发者可以快速地在 Unity 中构建强化学习场景,训练自己的智能体,以及评估和部署训练好的模型。

using Unity.MLAgents.Sensors 引入了用于创建智能体感知信息的 ML-Agents 的 Sensor 类和相关组件。感知信息由智能体的观测(observation)和其他一些信息(如reward等)组成,用于训练智能体进行决策。该命名空间包含了许多不同类型的Sensor,例如VisualSensor和RayPerceptionSensor等,这些Sensor可用于感知不同类型的环境。

1.2.2 this.transform.localPosition.y < 0

this.transform.localPosition.y < 0

这段代码用于检测智能体在环境中是否跌落。具体地,它检查智能体的本地坐标(local position)的y值是否小于0,如果是,说明智能体已经跌落到地面以下。

1.2.3 this.transform.localPosition = new Vector3( 0, 0.5f, 0)

this.transform.localPosition = new Vector3( 0, 0.5f, 0)

this.transform.localPosition = new Vector3( 0, 0.5f, 0) 是将当前游戏对象的本地坐标设置为 (0, 0.5, 0)。其中 localPosition 是相对于父物体坐标系的位置,因此将其设置为 (0, 0.5, 0) 会将游戏对象相对于其父物体向上移动 0.5 个单位。

1.2.4 Target.localPosition = new Vector3(Random.value * 8 - 4, 0.5f, Random.value * 8 - 4); }

Target.localPosition = new Vector3(Random.value * 8 - 4, 0.5f, Random.value * 8 - 4);}

这行代码的作用是在每个Episode(一次训练周期)开始时将Target的位置设置为一个新的随机位置。具体来说,它使用Random.value生成一个介于0和1之间的随机值,并将其乘以8。然后,从这个结果中减去4,以使得结果介于-4和4之间。这个结果被用作新位置的X和Z坐标,而Y坐标设置为0.5f。因此,这行代码将随机移动目标的位置,从而使智能体必须适应不同的目标位置来完成任务。

2. Observing the Environment

2.1 代码总括

public override void CollectObservations(VectorSensor sensor)
{// 目标点和智能体位置sensor.AddObservation(Target.localPosition);sensor.AddObservation(this.transform.localPosition);// 智能体速度sensor.AddObservation(rBody.velocity.x);sensor.AddObservation(rBody.velocity.z);
}
  • public override void CollectObservations(VectorSensor sensor): 重写了父类 Agent 中的 CollectObservations 方法,该方法用于收集智能体的观测数据VectorSensor 类型的参数 sensor 用于存储观测数据。

  • sensor.AddObservation(Target.localPosition): 将目标点的位置作为一个观测数据添加到 sensor 中。

  • sensor.AddObservation(this.transform.localPosition): 将智能体的位置作为一个观测数据添加到 sensor 中。

  • sensor.AddObservation(rBody.velocity.x): 将智能体在 x 轴方向的速度作为一个观测数据添加到 sensor 中。

  • sensor.AddObservation(rBody.velocity.z): 将智能体在 z 轴方向的速度作为一个观测数据添加到 sensor 中。

2.2 代码分解

2.2.1 public override void CollectObservations(VectorSensor sensor)

public override void CollectObservations(VectorSensor sensor)

public override void CollectObservations(VectorSensor sensor) 是一个重写了父类 Agent 中的 CollectObservations 方法的公共方法。在这个方法中,我们可以定义 Agent 用于学习的所有观察。这些观察将成为 Agent 的状态信息,并用于在训练过程中决策其行动。

VectorSensor sensor 参数是一个用于添加观察的 VectorSensor 实例。

这个方法的主要目的是向传感器添加观察。这个方法的实现可以根据具体的场景和需求而有所不同。

2.2.2 sensor.AddObservation(Target.localPosition)

sensor.AddObservation(Target.localPosition)

这行代码的作用是将 RollerAgent 脚本中定义的 Target(Unity Transform 组件)的本地位置添加到传感器观察中。在使用机器学习代理训练时,传感器负责向代理提供状态信息,这些信息可以是代理需要了解的任何东西,例如环境中其他物体的位置、速度、颜色等等。通过将目标的位置添加到传感器观察中,代理就可以了解到目标的位置。 

2.2.3 sensor.AddObservation(this.transform.localPosition)这一行代码将 

sensor.AddObservation(this.transform.localPosition);

这一行代码将 RollerAgent 自身的位置信息添加到 Sensor 中,具体来说,会将该 Agent 在 x、y、z 轴上的本地坐标分别作为 3 个观测值加入 Sensor 中,以供智能体在后续的训练中学习和使用。 

2.2.4 sensor.AddObservation(rBody.velocity.x)

sensor.AddObservation(rBody.velocity.x)

这一行代码将 RollerAgent 的刚体在 x 轴上的速度信息添加到 Sensor 中,用于作为智能体的观测值。具体来说,它将 RollerAgent 的刚体在 x 轴上的速度值作为一个观测值加入到 Sensor 中,以供智能体在后续的训练中学习和使用。

2.2.5 sensor.AddObservation(rBody.velocity.z)

sensor.AddObservation(rBody.velocity.z)

 这一行代码将 RollerAgent 的刚体在 z 轴上的速度信息添加到 Sensor 中,用于作为智能体的观测值。具体来说,它将 RollerAgent 的刚体在 z 轴上的速度值作为一个观测值加入到 Sensor 中,以供智能体在后续的训练中学习和使用。

3. Taking Actions and Assigning Rewards

3.1 代码总括

public float forceMultiplier = 10; // 添加一个 public 变量 forceMultiplier,用于调节施加的力的倍数
public override void OnActionReceived(ActionBuffers actionBuffers) // Override 基类 Agent 的 OnActionReceived 方法
{// Actions, size = 2Vector3 controlSignal = Vector3.zero; // 创建一个大小为 0 的 Vector3 变量 controlSignalcontrolSignal.x = actionBuffers.ContinuousActions[0]; // 将连续控制输入的第一个元素赋值给 controlSignal 的 x 分量controlSignal.z = actionBuffers.ContinuousActions[1]; // 将连续控制输入的第二个元素赋值给 controlSignal 的 z 分量rBody.AddForce(controlSignal * forceMultiplier); // 施加力,乘以 forceMultiplier 调节施加力的大小// Rewardsfloat distanceToTarget = Vector3.Distance(this.transform.localPosition, Target.localPosition); // 计算 Agent 与 Target 之间的距离// Reached targetif (distanceToTarget < 1.42f) // 如果 Agent 到达了 Target 的范围内(1.42f 是指一个单位立方体的对角线长度){SetReward(1.0f); // 奖励 AgentEndEpisode(); // 结束 episode}// Fell off platformelse if (this.transform.localPosition.y < 0) // 如果 Agent 掉下了平台{EndEpisode(); // 结束 episode}
}

这段代码是 RollerAgent 类的 OnActionReceived 方法,主要是处理智能体的行动和奖励等逻辑。在该方法中,首先解析智能体的行动,将其转换成控制信号,然后通过给定的力量系数将其施加在 RollerAgent 的刚体上。然后,根据智能体距离目标的距离计算奖励。如果智能体到达目标位置,则奖励为 1,并结束该 episode;如果 RollerAgent 掉落出平台,则直接结束该 episode。

为了解决向目标移动的任务,Agent (Sphere)需要能够在x和z方向上移动。因此,代理需要两个动作:第一个决定沿x轴施加的力;第二个决定了作用在z轴上的力。(如果我们允许Agent在三维空间中移动,那么我们就需要第三个动作。)

3.2 代码分解

3.2.1 public float forceMultiplier = 10

public float forceMultiplier = 10

这段代码定义了一个名为forceMultiplier的公共变量,并将其初始化为10。该变量的作用是控制代理施加的力的强度。

3.2.2 public override void OnActionReceived(ActionBuffers actionBuffers)

public override void OnActionReceived(ActionBuffers actionBuffers)

OnActionReceived 是 Agent 类的方法,表示 Agent 接收到一个 action 后需要执行的操作。在这个 RollerAgent 脚本中,它用于接收动作向量,并对其进行处理。

参数 actionBuffers 是一个 ActionBuffers 对象,代表着一个动作向量。此处的动作向量是一个连续的向量,包含两个浮点数,分别表示在 x 轴和 z 轴上的控制力。

因此,该函数的主要功能是对 actionBuffers 进行解析,并将其转化为 controlSignal 控制力,再将该力传递给物理引擎,以使得 RollerAgent 在场景中移动。

3.2.3 controlSignal.x = actionBuffers.ContinuousActions[0]

    Vector3 controlSignal = Vector3.zero; // 创建一个大小为 0 的 Vector3 变量 controlSignalcontrolSignal.x = actionBuffers.ContinuousActions[0]; // 将连续控制输入的第一个元素赋值给 controlSignal 的 x 分量controlSignal.z = actionBuffers.ContinuousActions[1]; // 将连续控制输入的第二个元素赋值给 controlSignal 的 z 分量rBody.AddForce(controlSignal * forceMultiplier); // 施加力,乘以 forceMultiplier 调节施加力的大小

RollerAgent使用Rigidbody. addforce()将action[]数组中的值应用到它的Rigidbody组件rBody上:

这行代码定义了一个名为controlSignal的Vector3类型变量,初始值为(0,0,0),用于存储代理(Agent)的操作信号

这行代码将第一个连续动作的值设置为控制信号向量controlSignal的x分量。actionBuffers.ContinuousActions[0]获取代理的第一个连续动作的值。在这个RollerBall例子中,代理有两个连续动作,每个动作的值范围在-1到1之间,代表x和z轴上的力量。

这行代码设置了控制信号的z轴分量,它的值从ActionBuffers中获取,index为1,代表模型的前后运动方向。具体来说,模型根据这个分量来决定是向前还是向后运动。

这行代码是将计算出来的控制向量controlSignal乘上一个forceMultiplier系数,然后通过调用AddForce函数对物体施加一个力的作用,从而实现对RollerAgent物体的控制。在这个环境下,RollerAgent物体被控制的目标是移动到一个随机移动的目标位置。

注意,forcMultiplier 类变量是在方法定义之前定义的。因为 forcMultiplier 是公共的,所以您可以设置检查器窗口中的值。

3.2.4 float distanceToTarget = Vector3.Distance(this.transform.localPosition, Target.localPosition)

float distanceToTarget = Vector3.Distance(this.transform.localPosition, Target.localPosition)

这段代码计算当前Agent和目标物体之间的距离,使用的是Vector3.Distance()方法,它返回两个向量之间的欧几里得距离(也称为长度)

具体地说,它使用当前Agent位置this.transform.localPosition和目标物体位置Target.localPosition作为两个向量的参数,计算它们之间的距离,并将结果赋值给变量distanceToTarget