Repository URL to install this package:
|
Version:
1.1.0 ▾
|
using System;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using Fluctio.FluctioSim.Core.MujocoExtensions.MujocoDataProxy;
using Fluctio.FluctioSim.Core.MujocoExtensions.States;
using Fluctio.FluctioSim.Utils.Extensions;
using Mujoco;
namespace Fluctio.FluctioSim.Core.MujocoExtensions
{
public static unsafe class MujocoComponentsExtensions
{
    #region Helpers

    /// <summary>Raw pointer to the <c>mjData</c> of <c>MjScene.Instance</c>.</summary>
    internal static MujocoLib.mjData_* Data => MjScene.Instance.Data;

    /// <summary>Raw pointer to the <c>mjModel</c> of <c>MjScene.Instance</c>.</summary>
    internal static MujocoLib.mjModel_* Model => MjScene.Instance.Model;

    /// <summary>
    /// Runs <c>mj_forward</c> on the current scene and then calls
    /// <c>MjScene.Instance.SyncUnityToMjState()</c> (presumably updating the Unity side from the MuJoCo state).
    /// </summary>
    /// <param name="doSyncState">
    /// When <c>false</c> the call does nothing. This lets callers apply several writes
    /// and synchronize only once at the end (see <see cref="SetChainState"/>).
    /// </param>
    public static void SyncState(bool doSyncState = true)
    {
        if (!doSyncState)
        {
            return;
        }
        MujocoLib.mj_forward(Model, Data);
        MjScene.Instance.SyncUnityToMjState();
    }

    /// <summary>Throws when no <c>MjScene</c> instance exists.</summary>
    /// <remarks>
    /// Marked <c>[Conditional("UNITY_EDITOR")]</c>: call sites compiled without the
    /// <c>UNITY_EDITOR</c> symbol omit the call entirely, so this check is editor-only.
    /// </remarks>
    /// <exception cref="InvalidOperationException">No MuJoCo scene exists.</exception>
    [Conditional("UNITY_EDITOR")]
    public static void CheckMjScene()
    {
        if (!MjScene.InstanceExists)
        {
            throw new InvalidOperationException("This operation is allowed only when mujoco scene exists");
        }
    }

    #endregion

    #region Data factory

    /// <summary>
    /// Generic descendants of <c>ComponentData&lt;&gt;</c> discovered by <c>FindGenericDescendants</c>
    /// in this assembly. <c>Arguments[0]</c> of each entry is the component type that proxy wraps.
    /// </summary>
    private static readonly ReflectionExtensions.GenericType[] DataClasses = typeof(ComponentData<>).FindGenericDescendants(new[] {Assembly.GetExecutingAssembly()});

    /// <summary>
    /// Creates the data proxy whose generic argument is exactly the component's runtime type.
    /// </summary>
    /// <remarks>
    /// Matching uses exact type equality; proxies registered for a base class are not used for subclasses.
    /// </remarks>
    /// <param name="component">Component to wrap; may be statically typed as a base class.</param>
    /// <returns>A new proxy instance, or <c>null</c> when <paramref name="component"/> is null.</returns>
    /// <exception cref="InvalidOperationException">
    /// No proxy type, or more than one, is registered for the runtime type, or the proxy
    /// has no public constructor taking that type.
    /// </exception>
    public static ComponentData GetData<T>(this T component) where T : MjComponent
    {
        // Deliberately `==` rather than `is null`: if MjComponent derives from UnityEngine.Object,
        // Unity's overloaded equality also reports destroyed components as null.
        if (component == null)
        {
            return null;
        }

        // Match (and report) the runtime type: T may be a base class such as MjBaseJoint.
        var componentType = component.GetType();
        var matchingTypes = DataClasses
            .Where(dataClass => dataClass.Arguments[0] == componentType)
            .Take(2) // two hits are enough to detect ambiguity
            .ToList();
        var typeToCreate = matchingTypes.Count switch
        {
            <= 0 => throw new InvalidOperationException($"Mujoco data proxy type not found for class {componentType}"),
            >= 2 => throw new InvalidOperationException($"Multiple mujoco data proxy types for class {componentType}"),
            1 => matchingTypes[0],
        };

        var constructor = typeToCreate.Type.GetConstructor(typeToCreate.Arguments);
        if (constructor == null)
        {
            throw new InvalidOperationException($"Wrong constructor in mujoco data proxy type {typeToCreate.Type}");
        }
        var dataObject = constructor.Invoke(new object[] {component});
        return (ComponentData)dataObject;
    }

    #endregion

    #region Whole scene state

    /// <summary><c>mjtState</c> spec with every bit set, i.e. "all state components".</summary>
    // NOTE(review): this also sets bits beyond mjNSTATE. Some MuJoCo versions may validate the spec
    // and reject such values; confirm against the bundled MuJoCo, or switch to a named mjtState mask.
    private const uint AllPropertiesSpec = uint.MaxValue;

    /// <summary>Copies the complete MuJoCo state of the current scene into a new array.</summary>
    /// <param name="scene">
    /// Extension receiver only. The state is always read from <c>MjScene.Instance</c>.
    /// </param>
    /// <returns>A new array whose length is <c>mj_stateSize</c> for <see cref="AllPropertiesSpec"/>.</returns>
    public static double[] GetState(this MjScene scene)
    {
        var size = MujocoLib.mj_stateSize(Model, AllPropertiesSpec);
        var state = new double[size];
        // Pinning the array itself (instead of &state[0]) yields a null pointer for an empty
        // array instead of throwing IndexOutOfRangeException.
        fixed (double* statePointer = state)
        {
            MujocoLib.mj_getState(Model, Data, statePointer, AllPropertiesSpec);
        }
        return state;
    }

    /// <summary>Restores a state previously captured with <see cref="GetState"/>.</summary>
    /// <param name="scene">
    /// Extension receiver only. The state is always written to <c>MjScene.Instance</c>.
    /// </param>
    /// <param name="state">Buffer whose length must equal the current model's state size.</param>
    /// <exception cref="ArgumentNullException"><paramref name="state"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="state"/> has a different length than the current model's state.
    /// </exception>
    public static void SetState(this MjScene scene, double[] state)
    {
        if (state is null)
        {
            throw new ArgumentNullException(nameof(state));
        }

        // mj_setState reads mj_stateSize doubles through the raw pointer. A shorter array would be
        // read past its end; any length mismatch means the state belongs to a different model.
        var expectedSize = MujocoLib.mj_stateSize(Model, AllPropertiesSpec);
        if (state.Length != expectedSize)
        {
            throw new ArgumentException(
                $"State length {state.Length} does not match the MuJoCo state size {expectedSize}. The state was probably captured from a different model.",
                nameof(state));
        }

        fixed (double* statePointer = state)
        {
            MujocoLib.mj_setState(Model, Data, statePointer, AllPropertiesSpec);
        }
    }

    #endregion

    #region Chain states

    /// <summary>
    /// Collects every <c>MjBaseJoint</c> returned by <c>GetComponentsInChildren</c> on
    /// <paramref name="rootBody"/>, sorted by <c>MujocoName</c>. Capture and restore therefore
    /// pair joint states with joints in the same order.
    /// </summary>
    // NOTE(review): OrderBy over strings uses the current culture's comparer. If a ChainState is ever
    // restored under a different culture than it was captured in, the order could differ.
    // StringComparer.Ordinal would avoid this but reorders previously saved states; confirm before changing.
    private static MjBaseJoint[] GetChainJoints(this MjBody rootBody)
    {
        return rootBody
            .GetComponentsInChildren<MjBaseJoint>()
            .OrderBy(joint => joint.MujocoName)
            .ToArray();
    }

    /// <summary>Captures the state of every joint under <paramref name="rootBody"/>.</summary>
    /// <returns>A new <c>ChainState</c> whose <c>JointStates</c> follow the name-sorted joint order.</returns>
    public static ChainState GetChainState(this MjBody rootBody)
    {
        var joints = rootBody.GetChainJoints();
        var chainState = new ChainState();
        foreach (var joint in joints)
        {
            var jointState = joint.GetData().GetState();
            chainState.JointStates.Add(jointState);
        }
        return chainState;
    }

    /// <summary>
    /// Writes a state captured by <see cref="GetChainState"/> back into the joints under
    /// <paramref name="rootBody"/>. Optionally synchronizes the scene once at the end.
    /// </summary>
    /// <param name="rootBody">Root of the joint chain to restore.</param>
    /// <param name="chainState">Previously captured chain state.</param>
    /// <param name="doSyncState">Whether to call <see cref="SyncState"/> after all joints are written.</param>
    /// <exception cref="ArgumentNullException"><paramref name="chainState"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The number of stored joint states differs from the number of joints.</exception>
    public static void SetChainState(this MjBody rootBody, ChainState chainState, bool doSyncState = true)
    {
        if (chainState is null)
        {
            throw new ArgumentNullException(nameof(chainState));
        }

        var joints = rootBody.GetChainJoints();
        if (chainState.JointStates.Count != joints.Length)
        {
            throw new InvalidOperationException("Joints amount mismatch.\nEither MuJoCo structure was changed between saving and restoring the state, or you are restoring state to wrong body.");
        }

        // Per-joint writes skip synchronization; the whole chain is synced once below.
        foreach (var (joint, jointState) in joints.Zip(chainState.JointStates, (j, s) => (j, s)))
        {
            joint.GetData().SetState(jointState, false);
        }
        SyncState(doSyncState);
    }

    #endregion
}
}