Repository URL to install this package:
Version:
1.3.1 ▾
|
using System;
using System.Collections.Generic;
using System.IO;
using Fluctio.FluctioSim.EditorCore.Training.PythonWrappers;
using Fluctio.FluctioSim.EditorCore.Training.TensorboardApi;
using Fluctio.FluctioSim.Utils.General;
using JetBrains.Annotations;
using Newtonsoft.Json;
using UnityEditor;
using UnityEngine;
namespace Fluctio.FluctioSim.EditorCore.Training.General
{
/// <summary>
/// Keeps a copy of the best-scoring model for each behavior while ML-Agents training runs.
/// </summary>
/// <remarks>
/// On every <c>MlAgentsLearn.CheckpointCreated</c> event, the cumulative reward logged closest to the
/// checkpoint's step is fetched via <c>Tensorboard.GetCumulativeReward</c>. If it is strictly greater than
/// the best reward recorded for that behavior in the current training session (or nothing has been recorded
/// yet), the checkpoint file is copied to <c>{ResultsFolder}/{runId}/{behaviorName}-best.onnx</c>,
/// overwriting any previous best copy.
/// </remarks>
internal static class BestModelRetainer
{
    private static readonly string BestRewardsSessionKey = $"{typeof(BestModelRetainer).FullName}.BestRewards";

    // Dedicated monitor for the read-modify-write of BestRewards in OnCheckpoint.
    // (Previously the session-key string itself was locked on, which is a discouraged lock target.)
    private static readonly object BestRewardsLock = new object();

    // Stateless comparer shared by every BinarySearchClosest call.
    private static readonly StepComparer ByStep = new StepComparer();

    /// <summary>
    /// Best cumulative reward per behavior name for the current training session, or <c>null</c> when no
    /// session is active (before <see cref="OnStart"/> or after <see cref="OnStop"/>).
    /// </summary>
    /// <remarks>
    /// Stored as JSON in <see cref="SessionState"/>. Every get returns a freshly deserialized dictionary, so
    /// changes must be written back through the setter to be persisted.
    /// </remarks>
    [CanBeNull]
    private static Dictionary<string, float> BestRewards
    {
        get
        {
            var serializedString = SessionState.GetString(BestRewardsSessionKey, null);
            return serializedString == null
                ? null
                : JsonConvert.DeserializeObject<Dictionary<string, float>>(serializedString);
        }
        set
        {
            if (value == null)
            {
                // Remove the key instead of persisting the JSON literal "null"; the getter then yields null.
                SessionState.EraseString(BestRewardsSessionKey);
                return;
            }
            SessionState.SetString(BestRewardsSessionKey, JsonConvert.SerializeObject(value));
        }
    }

    /// <summary>Subscribes to training lifecycle events when the editor loads this assembly.</summary>
    [InitializeOnLoadMethod]
    private static void RetainBestModel()
    {
        MlAgentsLearn.CheckpointCreated += OnCheckpoint;
        MlAgentsLearn.TrainingStarted += OnStart;
        MlAgentsLearn.TrainingStopped += OnStop;
    }

    /// <summary>Starts a new session with no best rewards recorded.</summary>
    private static void OnStart()
    {
        BestRewards = new Dictionary<string, float>();
    }

    /// <summary>Ends the session; checkpoints arriving afterwards are ignored.</summary>
    private static void OnStop()
    {
        BestRewards = null;
    }

    //TODO: move to ScalarData class?
    /// <summary>Orders scalar events by <c>Step</c>; <c>null</c> sorts before any event.</summary>
    private sealed class StepComparer : IComparer<ScalarEvent>
    {
        public int Compare(ScalarEvent x, ScalarEvent y)
        {
            if (ReferenceEquals(x, y)) return 0;
            if (y is null) return 1;
            if (x is null) return -1;
            return x.Step.CompareTo(y.Step);
        }
    }

    /// <summary>
    /// Returns the event whose <c>Step</c> is closest to <paramref name="step"/>.
    /// </summary>
    /// <param name="events">Events assumed to be sorted by ascending <c>Step</c> (required by the binary search).</param>
    /// <param name="step">Step to look up.</param>
    /// <returns>
    /// The exact match if one exists, otherwise the nearest neighbour (the later event on a tie, the first/last
    /// event when <paramref name="step"/> is outside the range), or <c>null</c> when
    /// <paramref name="events"/> is <c>null</c> or empty.
    /// </returns>
    [CanBeNull]
    private static ScalarEvent BinarySearchClosest([CanBeNull] this List<ScalarEvent> events, int step)
    {
        if (events == null || events.Count == 0)
        {
            // Nothing logged yet; indexing events[0] below would throw.
            return null;
        }

        // Probe event: only its Step is inspected by the comparer.
        var foundIndex = events.BinarySearch(new ScalarEvent(default, step, 0), ByStep);
        if (foundIndex >= 0)
        {
            return events[foundIndex];
        }

        // List<T>.BinarySearch returns the bitwise complement of the insertion point when not found.
        var insertionIndex = ~foundIndex;
        if (insertionIndex == 0)
        {
            return events[0];
        }
        if (insertionIndex == events.Count)
        {
            return events[^1];
        }

        var before = events[insertionIndex - 1];
        var after = events[insertionIndex];
        return Math.Abs(before.Step - step) < Math.Abs(after.Step - step) ? before : after;
    }

    /// <summary>
    /// Handles a new checkpoint: copies it to <c>{behaviorName}-best.onnx</c> if its reward beats the session best.
    /// </summary>
    /// <param name="newModelPath">Path of the checkpoint model file just written.</param>
    /// <param name="runId">Training run id; names the sub-folder of <c>MlAgentsLearn.ResultsFolder</c>.</param>
    /// <param name="behaviorName">Behavior whose checkpoint was created.</param>
    /// <param name="step">Training step of the checkpoint.</param>
    /// <remarks><c>async void</c> because it is an event handler.</remarks>
    private static async void OnCheckpoint(string newModelPath, string runId, string behaviorName, int step)
    {
        // NOTE(review): this scope stays open across the await below. If ProfilerScope wraps
        // Profiler.BeginSample/EndSample, Begin and End may land on different frames — confirm.
        using var profilerScope = new ProfilerScope("BestModelRetainer.OnCheckpoint");
        var rewards = await Tensorboard.GetCumulativeReward(runId, behaviorName);

        // Pure lookup; no shared state involved, so it is done before taking the lock.
        var rewardEvent = rewards.BinarySearchClosest(step);
        if (rewardEvent is null)
        {
            // No cumulative-reward data for this behavior yet, so there is nothing to compare against.
            return;
        }

        lock (BestRewardsLock)
        {
            var bestRewards = BestRewards;
            if (bestRewards == null)
            {
                // No active session, e.g. training stopped while the reward query above was pending.
                return;
            }

            var doesBestRewardExist = bestRewards.TryGetValue(behaviorName, out var bestReward);
            if (doesBestRewardExist && rewardEvent.Value <= bestReward)
            {
                return;
            }

            var bestModelPath = Path.Combine(MlAgentsLearn.ResultsFolder, runId, $"{behaviorName}-best.onnx");
            File.Copy(newModelPath, bestModelPath, true);
            bestRewards[behaviorName] = rewardEvent.Value;
            // Write back: the getter returned a deserialized copy, not the persisted state.
            BestRewards = bestRewards;
            Debug.Log($"New best model for \"{behaviorName}\" (step = {rewardEvent.Step}, reward = {rewardEvent.Value})\nSaved to {bestModelPath}");
        }
    }
}
}