Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
using System;
using System.ComponentModel;
using Fluctio.FluctioSim.Common.Configuration;
using Fluctio.FluctioSim.Core.Components.Base;
using Fluctio.FluctioSim.Core.Components.MachineLearning.TriggerActions;
using Fluctio.FluctioSim.Core.Components.MachineLearning.Triggers;
using Fluctio.FluctioSim.Utils.Extensions;
using Fluctio.FluctioSim.Utils.General;
using Fluctio.FluctioSim.Utils.SerializableClasses;
using Mujoco;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using Unity.InferenceEngine;
using UnityEngine;
using UnityBehaviorType = Unity.MLAgents.Policies.BehaviorType;

namespace Fluctio.FluctioSim.Core.Components.MachineLearning.Agents
{
	/// <summary>
	/// Sets up Unity ML-Agents components on this GameObject and keeps them in sync with the
	/// time-based settings exposed by this component.
	/// </summary>
	/// <remarks>
	/// <see cref="InitializeOnce"/> creates the internal <see cref="BehaviorParameters"/>,
	/// <see cref="UnityAgent"/>, <see cref="DecisionRequester"/> and <see cref="Resetter"/> components.
	/// <see cref="OnSelfChanged"/> pushes this component's settings into them, and
	/// <see cref="OnOtherChanged"/> reads edits made on them back into this component.
	/// </remarks>
	[AddComponentMenu(Config.PrefixedName+"/Machine Learning/Agent", Config.ComponentMenuOrder + 20)]
	[DisallowMultipleComponent]
	// TODO: support multiagent scenarios
	public class Agent : InternalsEditorComponent
	{
		/// <summary>ML-Agents behavior name; <see cref="InitializePrefab"/> generates one when it is empty.</summary>
		[field: SerializeField] public string BehaviorName { get; private set; } = "";

		/// <summary>Time limit, written to <c>UnityAgent.MaxStep</c> as a step count.</summary>
		[field: SerializeField] private SerializableTimeSpan TimeLimit { get; set; } = TimeSpan.FromSeconds(20).MakeSerializable();

		/// <summary>Interval between decisions, written to <c>DecisionRequester.DecisionPeriod</c> as a step count.</summary>
		[field: SerializeField] private SerializableTimeSpan DecisionsPeriod { get; set; } = TimeSpan.FromSeconds(0.02).MakeSerializable();

		/// <summary>Not read by this class.</summary>
		[field: SerializeField] public bool RecommendTimeStep { get; set; } = false;

		/// <summary>Not read by this class.</summary>
		[field: SerializeField] public float RecommendedTimeStep { get; set; } = 0.002f;

		/// <summary>Selects the ML-Agents behavior type (and whether <see cref="Model"/> is used).</summary>
		[field: SerializeField] public AgentBehaviorType BehaviorType { get; set; } = AgentBehaviorType.TrainingOrManual;

		/// <summary>Model handed to <see cref="BehaviorParameters"/> only when <see cref="BehaviorType"/> is <c>Model</c>.</summary>
		[field: SerializeField] public ModelAsset Model { get; set; } = null;

		// Time.fixedDeltaTime used for the last step conversion in OnSelfChanged. It starts at -1,
		// so the first OnOtherChanged always triggers a conversion.
		// Do not rename: Unity serializes fields by name.
		[SerializeField, HideInInspector] private float lastPhysicsTimeStep = -1;

		/// <summary>Internal ML-Agents behavior settings, created in <see cref="InitializeOnce"/>.</summary>
		[field: SerializeField, HideInInspector] public BehaviorParameters BehaviorParameters { get; private set; }

		/// <summary>Internal ML-Agents agent, created in <see cref="InitializeOnce"/>.</summary>
		[field: SerializeField, HideInInspector] public ProxyMlAgent UnityAgent { get; private set; }

		/// <summary>Internal ML-Agents decision requester, created in <see cref="InitializeOnce"/>.</summary>
		[field: SerializeField, HideInInspector] public DecisionRequester DecisionRequester { get; private set; }

		/// <summary>Internal resetter, created in <see cref="InitializeOnce"/> with an <c>EnvironmentResetTrigger</c>.</summary>
		[field: SerializeField, HideInInspector] public MujocoChainReseter Resetter { get; private set; }

		/// <summary>Assigns a generated <see cref="BehaviorName"/> when none is set (null or empty).</summary>
		public override void InitializePrefab()
		{
			if (string.IsNullOrEmpty(BehaviorName))
			{
				BehaviorName = NameGenerator.GenerateName();
			}
		}

		/// <summary>
		/// Creates the internal ML-Agents and reset components on this GameObject and applies their
		/// fixed configuration.
		/// </summary>
		protected override void InitializeOnce()
		{
			base.InitializeOnce();

			BehaviorParameters = CreateInternalComponent<BehaviorParameters>(gameObject);
			UnityAgent = CreateInternalComponent<ProxyMlAgent>(gameObject);
			DecisionRequester = CreateInternalComponent<DecisionRequester>(gameObject);

			Resetter = CreateInternalComponent<MujocoChainReseter>(gameObject);
			Resetter.trigger = CreateInternalComponent<EnvironmentResetTrigger>(gameObject);

			// No observations or actions are declared here; they come from sensors/actuators on
			// child objects (UseChildSensors / UseChildActuators).
			BehaviorParameters.BrainParameters.VectorObservationSize = 0;
			BehaviorParameters.BrainParameters.NumStackedVectorObservations = 0;
			BehaviorParameters.BrainParameters.ActionSpec = new ActionSpec(0, null);
			BehaviorParameters.UseChildSensors = true;
			BehaviorParameters.UseChildActuators = true;

			DecisionRequester.TakeActionsBetweenDecisions = false;
		}

		/// <summary>
		/// Pushes this component's settings into the internal ML-Agents components and records the
		/// physics time step used for the step conversion.
		/// </summary>
		/// <exception cref="InvalidEnumArgumentException"><see cref="BehaviorType"/> is not a known value.</exception>
		public override void OnSelfChanged()
		{
			base.OnSelfChanged();

			UnityAgent.MaxStep = TimeLimit.TimeSpan.ToSteps();
			// DecisionRequester uses DecisionPeriod as a modulus divisor, so a period that converts to
			// 0 steps would throw DivideByZeroException at runtime; one step is the shortest valid period.
			DecisionRequester.DecisionPeriod = Math.Max(1, DecisionsPeriod.TimeSpan.ToSteps());
			lastPhysicsTimeStep = Time.fixedDeltaTime;
			BehaviorParameters.BehaviorName = BehaviorName;

			switch (BehaviorType)
			{
				case AgentBehaviorType.TrainingOrManual:
					BehaviorParameters.BehaviorType = UnityBehaviorType.Default;
					BehaviorParameters.Model = null;
					break;
				case AgentBehaviorType.Manual:
					BehaviorParameters.BehaviorType = UnityBehaviorType.HeuristicOnly;
					BehaviorParameters.Model = null;
					break;
				case AgentBehaviorType.Model:
					BehaviorParameters.BehaviorType = UnityBehaviorType.InferenceOnly;
					BehaviorParameters.Model = Model;
					break;
				default:
					throw new InvalidEnumArgumentException(
						nameof(BehaviorType),
						(int)BehaviorType,
						typeof(AgentBehaviorType));
			}
		}

		/// <summary>
		/// Reads edits made on the internal components back into this component.
		/// </summary>
		/// <remarks>
		/// If <see cref="Time.fixedDeltaTime"/> changed since the last conversion, the step counts are
		/// re-derived from this component's time spans (via <see cref="OnSelfChanged"/>) instead.
		/// </remarks>
		public override void OnOtherChanged()
		{
			base.OnOtherChanged();

			// Discover the body before the early return below, so it is not skipped on notifications
			// that coincide with a physics time step change (always the case for the first one).
			// Keep `== null` (not `??=`) so Unity's overloaded null check for destroyed objects applies.
			if (Resetter.body == null)
			{
				Resetter.body = GetComponentInChildren<MjBody>();
			}

			// Exact comparison on purpose: any change at all means the stored step counts were
			// computed for a different time step.
			// ReSharper disable once CompareOfFloatsByEqualityOperator
			if (Time.fixedDeltaTime != lastPhysicsTimeStep)
			{
				OnSelfChanged();
				return;
			}

			TimeLimit = TimeSpanExtensions.FromSteps(UnityAgent.MaxStep);
			DecisionsPeriod = TimeSpanExtensions.FromSteps(DecisionRequester.DecisionPeriod);
			BehaviorName = BehaviorParameters.BehaviorName;
		}
	}
}