Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
ai.fluctio.fluctio-sim / EditorCore / Training / General / BestModelRetainer.cs
Size: Mime:
using System;
using System.Collections.Generic;
using System.IO;
using Fluctio.FluctioSim.EditorCore.Training.PythonWrappers;
using Fluctio.FluctioSim.EditorCore.Training.TensorboardApi;
using Fluctio.FluctioSim.Utils.General;
using JetBrains.Annotations;
using Newtonsoft.Json;
using UnityEditor;
using UnityEngine;

namespace Fluctio.FluctioSim.EditorCore.Training.General
{
	/// <summary>
	/// Keeps a copy of the best-scoring checkpoint for every behavior of the current training run.
	/// On each checkpoint the cumulative reward closest to the checkpoint step is compared with
	/// the best reward recorded so far for that behavior; when it is better (or no reward has
	/// been recorded yet) the checkpoint is copied to <c>{behaviorName}-best.onnx</c> inside the
	/// run's results folder.
	/// </summary>
	internal static class BestModelRetainer
	{

		private static readonly string BestRewardsSessionKey = $"{typeof(BestModelRetainer).FullName}.BestRewards";

		// Dedicated monitor object. Locking on a string is discouraged because string instances
		// can be shared with unrelated code, which could then contend on the same monitor.
		private static readonly object BestRewardsLock = new object();

		// The comparer is stateless, so a single shared instance is enough.
		private static readonly StepComparer StepComparerInstance = new StepComparer();
		
		/// <summary>
		/// Best reward recorded so far per behavior name, persisted as JSON in
		/// <see cref="SessionState"/>. <c>null</c> means no training is being tracked:
		/// <see cref="OnStart"/> stores an empty dictionary and <see cref="OnStop"/> stores
		/// <c>null</c>.
		/// </summary>
		/// <remarks>
		/// Every read deserializes a fresh dictionary, so mutating the returned instance has no
		/// effect until it is assigned back through the setter.
		/// </remarks>
		[CanBeNull]
		private static Dictionary<string, float> BestRewards
		{
			get
			{
				var serializedString = SessionState.GetString(BestRewardsSessionKey, null);
				// Treat an empty string like a missing value instead of handing it to the deserializer.
				return string.IsNullOrEmpty(serializedString)
					? null
					: JsonConvert.DeserializeObject<Dictionary<string, float>>(serializedString);
			}
			set
			{
				// A null dictionary is serialized as the JSON literal "null", which the getter
				// deserializes back to null.
				var serializedString = JsonConvert.SerializeObject(value);
				SessionState.SetString(BestRewardsSessionKey, serializedString);
			}
		}

		/// <summary>
		/// Subscribes the retainer to the training events when the editor loads the assembly.
		/// </summary>
		[InitializeOnLoadMethod]
		private static void RetainBestModel()
		{
			MlAgentsLearn.CheckpointCreated += OnCheckpoint;
			MlAgentsLearn.TrainingStarted += OnStart;
			MlAgentsLearn.TrainingStopped += OnStop;
		}

		/// <summary>Starts tracking with an empty set of best rewards.</summary>
		private static void OnStart()
		{
			BestRewards = new Dictionary<string, float>();
		}

		/// <summary>Stops tracking; checkpoints handled after this point are ignored.</summary>
		private static void OnStop()
		{
			BestRewards = null;
		}

		//TODO: move to ScalarData class?
		/// <summary>
		/// Orders <see cref="ScalarEvent"/> instances by <c>Step</c>; <c>null</c> sorts before
		/// any non-null event.
		/// </summary>
		private sealed class StepComparer : IComparer<ScalarEvent>
		{
			public int Compare(ScalarEvent x, ScalarEvent y)
			{
				if (ReferenceEquals(x, y)) return 0;
				if (y is null) return 1;
				if (x is null) return -1;
				return x.Step.CompareTo(y.Step);
			}
		}

		/// <summary>
		/// Returns the event whose step is closest to <paramref name="step"/>.
		/// </summary>
		/// <param name="events">
		/// Non-empty list of events; the binary search assumes it is sorted by ascending step.
		/// </param>
		/// <param name="step">Step to look for.</param>
		/// <returns>
		/// The event with exactly that step when the search finds one; otherwise the nearer of
		/// the two neighbours around the insertion point (the later one on a tie), or the
		/// first/last event when <paramref name="step"/> lies outside the covered range.
		/// </returns>
		private static ScalarEvent BinarySearchClosest(this List<ScalarEvent> events, int step)
		{
			// Only the step of the probe event matters to the comparer; the other values are placeholders.
			var foundIndex = events.BinarySearch(new ScalarEvent(default, step, 0), StepComparerInstance);
			if (foundIndex >= 0)
			{
				return events[foundIndex];
			}
			
			// A negative result is the bitwise complement of the index of the first larger element.
			var insertionIndex = ~foundIndex;
			if (insertionIndex == 0)
			{
				return events[0];
			}
			if (insertionIndex == events.Count)
			{
				return events[^1];
			}

			var event1 = events[insertionIndex - 1];
			var event2 = events[insertionIndex];
			return Math.Abs(event1.Step - step) < Math.Abs(event2.Step - step) ? event1 : event2;
		}

		/// <summary>
		/// Checkpoint handler: copies the new model over the retained best model when its reward
		/// beats the best reward recorded so far for the behavior.
		/// </summary>
		/// <param name="newModelPath">Path of the checkpoint file that was just created.</param>
		/// <param name="runId">Run identifier; also the sub-folder of the results folder.</param>
		/// <param name="behaviorName">Behavior the checkpoint belongs to.</param>
		/// <param name="step">Training step of the checkpoint.</param>
		private static async void OnCheckpoint(string newModelPath, string runId, string behaviorName, int step)
		{
			// NOTE(review): the scope stays open across the await below, so it also covers the
			// Tensorboard query - confirm ProfilerScope is meant to be used that way.
			using var profilerScope = new ProfilerScope("BestModelRetainer.OnCheckpoint");
			var rewards = await Tensorboard.GetCumulativeReward(runId, behaviorName);
			lock (BestRewardsLock)
			{
				var bestRewards = BestRewards;
				if (bestRewards == null)
				{
					// Training is not being tracked (stopped, possibly while the query was in flight).
					return;
				}

				if (rewards == null || rewards.Count == 0)
				{
					// Without reward data there is nothing to compare; indexing into an empty
					// list would throw inside this async void handler.
					Debug.LogWarning($"No cumulative reward data for \"{behaviorName}\" at step {step}, skipping best model check");
					return;
				}

				var rewardEvent = rewards.BinarySearchClosest(step);
				var doesBestRewardExist = bestRewards.TryGetValue(behaviorName, out var bestReward);
				if (doesBestRewardExist && rewardEvent.Value <= bestReward)
				{
					return;
				}

				var bestModelPath = Path.Combine(MlAgentsLearn.ResultsFolder, runId, $"{behaviorName}-best.onnx");
				File.Copy(newModelPath, bestModelPath, true);
				bestRewards[behaviorName] = rewardEvent.Value;
				// Write the updated copy back; the getter returned a deserialized snapshot.
				BestRewards = bestRewards;
				Debug.Log($"New best model for \"{behaviorName}\" (step = {rewardEvent.Step}, reward = {rewardEvent.Value})\nSaved to {bestModelPath}");
			}
		}
	}
}