Why Gemfury? Push, build, and install: RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, NuGet packages

Repository URL to install this package:

Details    
ai.fluctio.fluctio-sim / Core / MujocoExtensions / MujocoComponentsExtensions.cs
Size: Mime:
using System;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using Fluctio.FluctioSim.Core.MujocoExtensions.MujocoDataProxy;
using Fluctio.FluctioSim.Core.MujocoExtensions.States;
using Fluctio.FluctioSim.Utils.Extensions;
using Mujoco;

namespace Fluctio.FluctioSim.Core.MujocoExtensions
{
	public static unsafe class MujocoComponentsExtensions
	{

		#region Helpers
		
		/// <summary>
		/// Raw <c>mjData</c> pointer, read from <see cref="MjScene.Instance"/> on every access.
		/// </summary>
		internal static MujocoLib.mjData_* Data
		{
			get { return MjScene.Instance.Data; }
		}

		/// <summary>
		/// Raw <c>mjModel</c> pointer, read from <see cref="MjScene.Instance"/> on every access.
		/// </summary>
		internal static MujocoLib.mjModel_* Model
		{
			get { return MjScene.Instance.Model; }
		}

		/// <summary>
		/// Calls <c>mj_forward</c> on the current model/data pair and then asks the scene
		/// to copy the MuJoCo state back into Unity.
		/// </summary>
		/// <param name="doSyncState">
		/// When <c>false</c> the method does nothing; lets callers batch several state
		/// changes and synchronize only once at the end.
		/// </param>
		public static void SyncState(bool doSyncState = true)
		{
			if (doSyncState)
			{
				MujocoLib.mj_forward(Model, Data);
				MjScene.Instance.SyncUnityToMjState();
			}
		}

		/// <summary>
		/// Asserts that a MuJoCo scene instance currently exists.
		/// Because of <see cref="ConditionalAttribute"/>, calls to this method are compiled
		/// only where the <c>UNITY_EDITOR</c> symbol is defined.
		/// </summary>
		/// <exception cref="InvalidOperationException">
		/// <c>MjScene.InstanceExists</c> is <c>false</c>.
		/// </exception>
		[Conditional("UNITY_EDITOR")]
		public static void CheckMjScene()
		{
			if (MjScene.InstanceExists)
			{
				return;
			}

			throw new InvalidOperationException("This operation is allowed only when mujoco scene exists");
		}
		
		#endregion
		
		#region Data factory

		// Candidate data proxy types related to ComponentData<>, collected once from the
		// executing assembly when this class is initialized. GetData compares Arguments[0]
		// of each entry with the component's runtime type and instantiates entry.Type.
		// NOTE(review): the exact discovery rules live in ReflectionExtensions.FindGenericDescendants — confirm there.
		private static readonly ReflectionExtensions.GenericType[] DataClasses = typeof(ComponentData<>).FindGenericDescendants(new[] {Assembly.GetExecutingAssembly()});

		/// <summary>
		/// Creates the data proxy object registered for the runtime type of <paramref name="component"/>.
		/// </summary>
		/// <typeparam name="T">Static type of the component; only the runtime type is used for the lookup.</typeparam>
		/// <param name="component">Component to wrap. May be <c>null</c>.</param>
		/// <returns>
		/// A new <see cref="ComponentData"/> instance for the component,
		/// or <c>null</c> when <paramref name="component"/> is <c>null</c>.
		/// </returns>
		/// <exception cref="InvalidOperationException">
		/// No proxy type, or more than one proxy type, is registered for the component's runtime type,
		/// or the proxy type has no public constructor taking exactly its generic arguments.
		/// </exception>
		public static ComponentData GetData<T>(this T component) where T : MjComponent
		{
			if (component == null)
			{
				return null;
			}

			// Matching is done on the runtime type, so error messages must report it too:
			// typeof(T) may be just a base class (e.g. when called on an MjBaseJoint reference).
			var componentType = component.GetType();

			// Two matches are enough to tell "exactly one" from "ambiguous".
			var matchingTypes = DataClasses
				.Where(dataClass => dataClass.Arguments[0] == componentType)
				.Take(2)
				.ToList();
			var typeToCreate = matchingTypes.Count switch
			{
				0 => throw new InvalidOperationException($"Mujoco data proxy type not found for class {componentType}"),
				1 => matchingTypes[0],
				_ => throw new InvalidOperationException($"Multiple mujoco data proxy types for class {componentType}"),
			};

			// The proxy is expected to expose a constructor whose parameter types are its generic arguments.
			var constructor = typeToCreate.Type.GetConstructor(typeToCreate.Arguments);
			if (constructor == null)
			{
				throw new InvalidOperationException($"Wrong constructor in mujoco data proxy type {typeToCreate.Type}");
			}

			var dataObject = constructor.Invoke(new object[] {component});
			return (ComponentData)dataObject;
		}
		
		#endregion

		#region Whole scene state

		// State specification bitmask with every bit set; passed to mj_stateSize,
		// mj_getState and mj_setState so that they operate on all state components.
		// NOTE(review): some MuJoCo releases may validate the spec against the number of
		// defined state bits — confirm the bundled library accepts uint.MaxValue.
		private const uint AllPropertiesSpec = uint.MaxValue;
		
		/// <summary>
		/// Copies the MuJoCo state selected by <c>AllPropertiesSpec</c> into a new array.
		/// </summary>
		/// <param name="scene">
		/// Scene the method is invoked on. The state is read through the static
		/// <c>Model</c>/<c>Data</c> accessors, i.e. from <c>MjScene.Instance</c>.
		/// </param>
		/// <returns>A newly allocated array of <c>mj_stateSize</c> elements filled by <c>mj_getState</c>.</returns>
		public static double[] GetState(this MjScene scene)
		{
			var size = MujocoLib.mj_stateSize(Model, AllPropertiesSpec);
			var state = new double[size];
			// Pinning the array itself (rather than &state[0]) stays valid for a zero-length array.
			fixed (double* statePointer = state)
			{
				MujocoLib.mj_getState(Model, Data, statePointer, AllPropertiesSpec);
			}
			return state;
		}

		/// <summary>
		/// Writes a state vector into the MuJoCo data via <c>mj_setState</c>, using <c>AllPropertiesSpec</c>.
		/// </summary>
		/// <param name="scene">
		/// Scene the method is invoked on. The state is written through the static
		/// <c>Model</c>/<c>Data</c> accessors, i.e. to <c>MjScene.Instance</c>.
		/// </param>
		/// <param name="state">
		/// State vector to apply. Must contain at least <c>mj_stateSize</c> elements for the current model.
		/// </param>
		/// <exception cref="ArgumentNullException"><paramref name="state"/> is <c>null</c>.</exception>
		/// <exception cref="ArgumentException">
		/// <paramref name="state"/> is shorter than the state size required by the current model.
		/// </exception>
		public static void SetState(this MjScene scene, double[] state)
		{
			if (state == null)
			{
				throw new ArgumentNullException(nameof(state));
			}

			// The native call reads a fixed number of values from the pointer; a shorter
			// managed array would make it read past the end of the buffer.
			var requiredSize = MujocoLib.mj_stateSize(Model, AllPropertiesSpec);
			if (state.Length < requiredSize)
			{
				throw new ArgumentException(
					$"State array has {state.Length} elements but the current model requires {requiredSize}.",
					nameof(state));
			}

			// Pinning the array itself (rather than &state[0]) stays valid for a zero-length array.
			fixed (double* statePointer = state)
			{
				MujocoLib.mj_setState(Model, Data, statePointer, AllPropertiesSpec);
			}
		}
		
		#endregion

		#region Chain states

		/// <summary>
		/// Collects every <see cref="MjBaseJoint"/> found via <c>GetComponentsInChildren</c>
		/// on <paramref name="rootBody"/>, sorted by <c>MujocoName</c>.
		/// </summary>
		/// <param name="rootBody">Body whose hierarchy is searched for joints.</param>
		/// <returns>The joints in a stable, name-based order.</returns>
		private static MjBaseJoint[] GetChainJoints(this MjBody rootBody)
		{
			// Saved joint states are matched to joints purely by position in this array,
			// so the order must not depend on the current culture: use ordinal comparison.
			return rootBody
				.GetComponentsInChildren<MjBaseJoint>()
				.OrderBy(joint => joint.MujocoName, StringComparer.Ordinal)
				.ToArray();
		}

		/// <summary>
		/// Captures the state of every joint under <paramref name="rootBody"/> into a new <see cref="ChainState"/>.
		/// </summary>
		/// <param name="rootBody">Body whose joints (as returned by <c>GetChainJoints</c>) are captured.</param>
		/// <returns>
		/// A chain state whose <c>JointStates</c> entries follow the order of <c>GetChainJoints</c>.
		/// </returns>
		public static ChainState GetChainState(this MjBody rootBody)
		{
			var orderedJoints = rootBody.GetChainJoints();
			var snapshot = new ChainState();
			for (var index = 0; index < orderedJoints.Length; index++)
			{
				// One entry per joint, appended in joint order.
				snapshot.JointStates.Add(orderedJoints[index].GetData().GetState());
			}
			return snapshot;
		}

		/// <summary>
		/// Applies previously captured joint states to the joints under <paramref name="rootBody"/>.
		/// States are paired with joints by position, using the order of <c>GetChainJoints</c>.
		/// </summary>
		/// <param name="rootBody">Body whose joints receive the states.</param>
		/// <param name="chainState">States to restore; must hold exactly one entry per joint.</param>
		/// <param name="doSyncState">Forwarded to <c>SyncState</c> after all joints have been updated.</param>
		/// <exception cref="InvalidOperationException">
		/// The number of stored joint states differs from the number of joints found.
		/// </exception>
		public static void SetChainState(this MjBody rootBody, ChainState chainState, bool doSyncState = true)
		{
			var orderedJoints = rootBody.GetChainJoints();
			if (orderedJoints.Length != chainState.JointStates.Count)
			{
				throw new InvalidOperationException("Joints amount mismatch.\nEither MuJoCo structure was changed between saving and restoring the state, or you are restoring state to wrong body.");
			}

			// ReSharper disable once InvokeAsExtensionMethod
			foreach (var pair in Enumerable.Zip(orderedJoints, chainState.JointStates, Tuple.Create))
			{
				// Item1 is the joint, Item2 its saved state. The trailing false is presumably
				// the per-joint sync flag (confirm against ComponentData.SetState); a single
				// sync is issued below once every joint has been updated.
				pair.Item1.GetData().SetState(pair.Item2, false);
			}
			SyncState(doSyncState);
		}

		#endregion
		
	}
}