Repository URL to install this package:
|
Version:
1.3.0 ▾
|
using System;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using Fluctio.FluctioSim.Common.Configuration;
using Fluctio.FluctioSim.EditorUtils.EditorGeneral;
using Fluctio.FluctioSim.EditorUtils.OperatingSystem;
using JetBrains.Annotations;
using MoreLinq;
using UnityEditor;
using UnityEditor.Build.Reporting;
using UnityEngine;
using UnityEngine.SceneManagement;
using Debug = UnityEngine.Debug;
namespace Fluctio.FluctioSim.EditorCore.Training.PythonWrappers
{
    /// <summary>
    /// Editor-side wrapper around the <c>mlagents-learn</c> command.
    /// Starts and stops the training process and forwards its console output to the Unity log.
    /// Also watches the run's results folder for TensorBoard event files and new ONNX checkpoints.
    /// </summary>
    public static class MlAgentsLearn
    {
        #region General

        /// <summary>Folder under the project root in which run folders (<c>results/&lt;runId&gt;</c>) are watched.</summary>
        public static string ResultsFolder { get; } = Path.Combine(EditorUtil.ProjectFolder, "results");

        /// <summary>
        /// Rendering mode of built environments.
        /// Maps to <c>--no-graphics</c>, no flag, or <c>--no-graphics-monitor</c> respectively.
        /// </summary>
        public enum GraphicsType
        {
            [Description("Hide all")]
            Disable,
            [Description("Show all")]
            Enable,
            [Description("Show one, hide the rest")]
            Monitor,
        }

        /// <summary>
        /// How to treat the run id. Maps to no flag, <c>--resume</c>, or <c>--force</c> of mlagents-learn respectively.
        /// </summary>
        public enum OverwriteType
        {
            [Description("Start training")]
            None,
            [Description("Continue training")]
            Continue,
            [Description("Start from scratch")]
            Force,
        }

        #endregion

        #region Process and state

        private static readonly SessionProcess SessionProcess = new(typeof(MlAgentsLearn));

        /// <summary>Raised (dispatched through <c>EditorUpdateEvents.MainThread</c>) after the training process exits.</summary>
        public static event Action TrainingStopped;

        /// <summary>The running mlagents-learn process, stored through <see cref="SessionProcess"/>.</summary>
        [CanBeNull]
        private static Process Process
        {
            get => SessionProcess.Process;
            set => SessionProcess.Process = value;
        }

        /// <summary>
        /// Run id of the current training.
        /// It is kept in Unity's <c>SessionState</c> so it is still known after a domain reload.
        /// An empty stored string means "no run".
        /// </summary>
        [CanBeNull]
        private static string RunId
        {
            get
            {
                var savedRunId = SessionState.GetString($"{typeof(MlAgentsLearn).FullName}.RunId", "");
                return savedRunId == "" ? null : savedRunId;
            }
            set => SessionState.SetString($"{typeof(MlAgentsLearn).FullName}.RunId", value ?? string.Empty);
        }

        /// <summary>True while a training process exists and has not exited.</summary>
        public static bool IsProcessLaunched => Process is { HasExited: false };

        /// <summary>Throws <see cref="InvalidOperationException"/> if a training process is already running.</summary>
        private static void CheckNotRunning()
        {
            if (IsProcessLaunched)
            {
                throw new InvalidOperationException("Training is already running");
            }
        }

        /// <summary>
        /// Asks for confirmation before quitting the editor while training runs, and stops the training on quit.
        /// </summary>
        [InitializeOnLoadMethod]
        private static void StopOnExit()
        {
            EditorApplication.wantsToQuit += () =>
            {
                if (!IsProcessLaunched)
                {
                    return true;
                }
                var userConfirmedQuit = EditorUtility.DisplayDialog(
                    "Training in progress",
                    $"There is a {Config.RawName} training in progress. Closing Unity Editor will stop the training.",
                    "Stop training and quit",
                    "Cancel");
                return userConfirmedQuit;
            };
            EditorApplication.quitting += async () =>
            {
                await Stop();
            };
        }

        #endregion

        #region Start/stop functions

        /// <summary>Raised (dispatched through <c>EditorUpdateEvents.MainThread</c>) when the first environment/brain connects.</summary>
        public static event Action TrainingStarted;

        // Completed when mlagents-learn reports a connected Unity environment or brain.
        // Canceled if the process exits first. Null after a domain reload.
        private static TaskCompletionSource<object> _trainingStartedCompletionSource = null;

        // Completed when mlagents-learn reports it is listening for the Unity Editor.
        // Faulted if the process exits first. Null after a domain reload.
        private static TaskCompletionSource<object> _listeningStartedCompletionSource = null;

        /// <summary>
        /// Starts mlagents-learn with a single environment (<c>--num-envs 1</c>) and no <c>--env</c>.
        /// The returned task completes once mlagents-learn reports that it is listening for the Unity Editor.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// Training is already running, or the process exited before it started listening (surfaced through the task).
        /// </exception>
        public static async Task StartListening(string configPath, string runId, OverwriteType overwriteType)
        {
            Start(configPath, runId, overwriteType, "--num-envs 1");
            await _listeningStartedCompletionSource.Task;
        }

        /// <summary>
        /// Starts mlagents-learn against the built environment at <paramref name="envPath"/>
        /// with <paramref name="numEnvs"/> instances, rendered according to <paramref name="graphicsType"/>.
        /// </summary>
        public static void StartProcesses(string configPath, string runId, OverwriteType overwriteType, string envPath, GraphicsType graphicsType, int numEnvs)
        {
            var graphicsFlag = graphicsType switch
            {
                GraphicsType.Disable => "--no-graphics",
                GraphicsType.Enable => "",
                GraphicsType.Monitor => "--no-graphics-monitor",
                _ => throw new InvalidEnumArgumentException(
                    nameof(graphicsType),
                    (int)graphicsType,
                    graphicsType.GetType())
            };
            var otherFlags = $"{graphicsFlag} --env {ProcessUtil.InQuotes(envPath)} --num-envs {numEnvs}";
            Start(configPath, runId, overwriteType, otherFlags);
        }

        /// <summary>
        /// Builds the mlagents-learn command line, launches it in the Python venv shell and hooks up output,
        /// exit handling and the run-logs watcher.
        /// </summary>
        private static void Start(string configPath, string runId, OverwriteType overwriteType, string otherFlags)
        {
            CheckNotRunning();
            var overwriteFlag = overwriteType switch
            {
                OverwriteType.None => "",
                OverwriteType.Continue => "--resume",
                OverwriteType.Force => "--force",
                _ => throw new InvalidEnumArgumentException(
                    nameof(overwriteType),
                    (int)overwriteType,
                    overwriteType.GetType())
            };
            var command = $"mlagents-learn {ProcessUtil.InQuotes(configPath, false)} --run-id {ProcessUtil.InQuotes(runId)} {overwriteFlag} {otherFlags} --time-scale 1";
            RunId = runId;

            // These are completed from the process-output thread. Run continuations asynchronously
            // so awaiting code never executes inline on that thread.
            _trainingStartedCompletionSource = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            _listeningStartedCompletionSource = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

            Process = PythonVenv.StartShellProcess(command);
            Debug.Log($"Starting training:\n{command}");
            ConnectProcess();
        }

        /// <summary>
        /// Attaches exit handling, output forwarding and the run-logs watcher to the current process.
        /// It is also run on editor load, so a process that survived a domain reload is reconnected.
        /// </summary>
        [InitializeOnLoadMethod]
        private static void ConnectProcess()
        {
            var process = Process;
            if (process == null)
            {
                return;
            }

            // Capture this run's completion sources, so the exit handler only touches the run it belongs to.
            // Both are null after a domain reload.
            var listeningStarted = _listeningStartedCompletionSource;
            var trainingStarted = _trainingStartedCompletionSource;

            process.Exited += delegate
            {
                // Release anyone still waiting for start-up (StartListening, WatchRunLogs); otherwise they wait forever.
                listeningStarted?.TrySetException(new InvalidOperationException("Training process exited before it started listening"));
                trainingStarted?.TrySetCanceled();

                EditorUpdateEvents.MainThread += () =>
                {
                    try
                    {
                        TrainingStopped?.Invoke();
                    }
                    finally
                    {
                        // Cleanup runs even if a TrainingStopped subscriber throws, so assemblies are never left locked.
                        EditorApplication.UnlockReloadAssemblies();
                        _ = OutputOnExit(process);
                        StopRunLogsWatching();
                        Process = null;
                        RunId = null;
                    }
                };
            };

            // Prevent script reloads while training runs; unlocked in the exit handler above.
            EditorApplication.LockReloadAssemblies();
            ConnectProcessOutput();
            WatchRunLogs();
        }

        /// <summary>Kills the training process tree, if a process is set. Does nothing otherwise.</summary>
        public static async Task Stop()
        {
            var process = Process;
            if (process == null)
            {
                return;
            }
            Debug.Log("Stopping training...");
            await process.KillTree();
        }

        #endregion

        #region Start/stop functions with helpers

        /// <summary>
        /// Starts mlagents-learn in listening mode and, once it is listening, enters play mode.
        /// This mirrors mlagents' "Start training by pressing the Play button" prompt.
        /// </summary>
        public static async Task StartInEditor(string configPath, string runId, OverwriteType overwriteType)
        {
            await StartListening(configPath, runId, overwriteType);
            EditorApplication.EnterPlaymode();
        }

        /// <summary>Starts stopping the training, exits play mode, then waits for the stop to finish.</summary>
        public static async Task StopInEditor()
        {
            var task = Stop();
            EditorApplication.ExitPlaymode();
            await task;
        }

        /// <summary>
        /// Builds the active scene as a standalone player for the current OS into <c>Build/</c> under the project folder.
        /// Then starts training against that build.
        /// </summary>
        /// <exception cref="InvalidOperationException">Training is already running.</exception>
        /// <exception cref="AggregateException">The build did not succeed.</exception>
        public static void BuildAndStart(string configPath, string runId, OverwriteType overwriteType, GraphicsType graphicsType, int numEnvs)
        {
            CheckNotRunning();
            var buildTarget = Platform.Type switch
            {
                PlatformType.Windows => BuildTarget.StandaloneWindows64,
                PlatformType.Linux => BuildTarget.StandaloneLinux64,
                PlatformType.MacOS => BuildTarget.StandaloneOSX,
                PlatformType.Other or _ => throw Platform.OSNotSupportedException,
            };
            var fileExtension = Platform.Type switch
            {
                PlatformType.Windows => "exe",
                PlatformType.Linux => "x86_64",
                PlatformType.MacOS => "app",
                PlatformType.Other or _ => throw Platform.OSNotSupportedException,
            };
            var report = BuildPipeline.BuildPlayer(new BuildPlayerOptions
            {
                locationPathName = Path.Combine(EditorUtil.ProjectFolder, "Build", $"{Application.productName}.{fileExtension}"),
                scenes = new [] { SceneManager.GetActiveScene().path },
                target = buildTarget,
            });
            if (report.summary.result != BuildResult.Succeeded)
            {
                throw new AggregateException($"Build to {report.summary.outputPath} was not finished (status = {report.summary.result})");
            }
            Debug.Log($"Current scene was built to {report.summary.outputPath}");
            StartProcesses(configPath, runId, overwriteType, report.summary.outputPath, graphicsType, numEnvs);
        }

        #endregion

        #region Output parsing

        // "[LEVEL] message" lines printed by mlagents-learn.
        private static readonly Regex LogFormat = new(@"^\[(?<level>.*?)] (?<message>.*)$");
        private static readonly Regex ListeningStartedLine = new(@"^Listening on port (?<port>\d+)\. Start training by pressing the Play button in the Unity Editor\.$");
        private static readonly Regex UnityConnectedLine = new(@"^Connected to Unity environment with package version (?<packageVersion>\d+\.\d+\.\d+) and communication version (?<communicationVersion>\d+\.\d+\.\d+)$");
        // Brain names look like "Name?team=0". The '?' must be escaped; unescaped it made the group optional.
        private static readonly Regex BrainConnectedLine = new(@"^Connected new brain: (?<behaviorName>.*)\?team=(?<team>\d+)$");

        /// <summary>Starts forwarding stdout lines to <see cref="OnOutput"/> when stdout is redirected.</summary>
        private static void ConnectProcessOutput()
        {
            if (Process == null || !Process.StartInfo.RedirectStandardOutput)
            {
                // after domain reload
                return;
            }
            Process.OutputDataReceived += (_, args) =>
            {
                if (args.Data != null)
                {
                    OnOutput(args.Data);
                }
            };
            Process.BeginOutputReadLine();
        }

        /// <summary>
        /// Reads the rest of <paramref name="process"/>'s stderr once it has exited.
        /// Logs "Training was stopped" when stderr is empty.
        /// Otherwise logs an error: the non-indented, non-traceback lines first, then the full stderr.
        /// </summary>
        private static async Task OutputOnExit([CanBeNull] Process process)
        {
            if (process == null || !process.StartInfo.RedirectStandardError)
            {
                // after domain reload
                return;
            }
            var errorMessage = await process.StandardError.ReadToEndAsync();
            if (errorMessage == "")
            {
                Debug.Log("Training was stopped");
                return;
            }
            var importantMessage = errorMessage
                .Split("\n")
                .Where(line => !line.StartsWith(" ", StringComparison.Ordinal))
                .Where(line => !line.Contains("Traceback"))
                .Where(line => !string.IsNullOrWhiteSpace(line))
                .ToDelimitedString("\n");
            Debug.LogError($"{importantMessage}\n\n{errorMessage}");
        }

        /// <summary>
        /// Handles one stdout line. Start-up markers complete the matching tasks.
        /// Other "[LEVEL] message" lines go to the Unity log at a matching level.
        /// Unformatted lines are logged as-is.
        /// </summary>
        private static void OnOutput(string line)
        {
            var lineMatch = LogFormat.Match(line);
            if (!lineMatch.Success)
            {
                Debug.Log(line);
                return;
            }
            var message = lineMatch.Groups["message"].Value;
            if (ListeningStartedLine.IsMatch(message))
            {
                // Try*: tolerate repeated lines and a missing source (after a domain reload).
                _listeningStartedCompletionSource?.TrySetResult(true);
                return;
            }
            if (BrainConnectedLine.IsMatch(message) || UnityConnectedLine.IsMatch(message))
            {
                // Only the first connection raises TrainingStarted.
                var isFirstConnection = _trainingStartedCompletionSource?.TrySetResult(true) ?? false;
                var trainingStarted = TrainingStarted;
                if (isFirstConnection && trainingStarted != null)
                {
                    EditorUpdateEvents.MainThread += trainingStarted.Invoke;
                }
                return;
            }
            var logLevel = lineMatch.Groups["level"].Value switch
            {
                "INFO" => LogType.Log,
                "WARNING" => LogType.Warning,
                "ERROR" or "FATAL" or "CRITICAL" => LogType.Error,
                _ => LogType.Warning,
            };
            Debug.unityLogger.Log(logLevel, message);
        }

        #endregion

        #region Run logs watcher

        private static FileSystemWatcher _runLogsWatcher;
        private static FileSystemEventHandler _runLogsChangedHandler;

        /// <summary>Raised with (runId, behaviorName) when a TensorBoard event file in the run folder is written or created.</summary>
        public static event Action<string, string> RunLogsChanged;

        /// <summary>Raised with (fullPath, runId, behaviorName, step) when a "&lt;behaviorName&gt;-&lt;step&gt;.onnx" file is created.</summary>
        public static event Action<string, string, string, int> CheckpointCreated;

        /// <summary>
        /// Waits until training has started (if this session started it).
        /// Then watches <c>results/&lt;runId&gt;</c> recursively for file writes and creations.
        /// </summary>
        private static async void WatchRunLogs()
        {
            var runId = RunId;
            if (Process == null || runId == null)
            {
                return;
            }

            var trainingStarted = _trainingStartedCompletionSource;
            if (trainingStarted != null)
            {
                try
                {
                    await trainingStarted.Task;
                }
                catch (OperationCanceledException)
                {
                    // The process exited before training started; there is nothing to watch.
                    return;
                }
            }

            // The process may have exited while we waited. The exit cleanup (StopRunLogsWatching) may already
            // have run, so a watcher created now would never be disposed.
            if (!IsProcessLaunched)
            {
                return;
            }

            var runPath = Path.Combine(ResultsFolder, runId);
            _runLogsWatcher = new FileSystemWatcher(runPath)
            {
                // LastWrite reports writes. FileName is needed for Created notifications
                // (new checkpoint files) on platforms such as Windows.
                NotifyFilter = NotifyFilters.LastWrite | NotifyFilters.FileName,
                IncludeSubdirectories = true,
            };
            _runLogsChangedHandler = (_, args) => OnRunLogsChanged(runId, args);
            _runLogsWatcher.Changed += _runLogsChangedHandler;
            _runLogsWatcher.Created += _runLogsChangedHandler;
            // Enable only after handlers are attached, so early notifications are not dropped.
            _runLogsWatcher.EnableRaisingEvents = true;
        }

        /// <summary>Detaches and disposes the run-logs watcher, if any.</summary>
        private static void StopRunLogsWatching()
        {
            if (_runLogsWatcher == null)
            {
                return;
            }
            _runLogsWatcher.Changed -= _runLogsChangedHandler;
            _runLogsWatcher.Created -= _runLogsChangedHandler;
            _runLogsChangedHandler = null;
            _runLogsWatcher.Dispose();
            _runLogsWatcher = null;
        }

        /// <summary>
        /// Splits a watcher event's path (relative to the run folder) into parts.
        /// The parent folder name is treated as the behavior name.
        /// </summary>
        private static void OnRunLogsChanged(string runId, FileSystemEventArgs args)
        {
            var filename = Path.GetFileName(args.Name);
            var behaviorName = Path.GetFileName(Path.GetDirectoryName(args.Name));
            var extension = Path.GetExtension(args.Name);
            CheckRunLogsChanged(runId, behaviorName, filename);
            CheckCheckpointCreated(runId, behaviorName, filename, extension, args.FullPath, args.ChangeType);
        }

        /// <summary>Raises <see cref="RunLogsChanged"/> for TensorBoard event files ("events.out.tfevents*").</summary>
        private static void CheckRunLogsChanged(string runId, string behaviorName, string filename)
        {
            if (!filename.StartsWith("events.out.tfevents", StringComparison.Ordinal))
            {
                return;
            }
            EditorUpdateEvents.MainThread += () =>
            {
                RunLogsChanged?.Invoke(runId, behaviorName);
            };
        }

        /// <summary>
        /// Raises <see cref="CheckpointCreated"/> when a newly created file is named "&lt;behaviorName&gt;-&lt;step&gt;.onnx".
        /// </summary>
        private static void CheckCheckpointCreated(
            string runId,
            string behaviorName,
            string filename,
            string extension,
            string fullPath,
            WatcherChangeTypes changeType)
        {
            if (behaviorName == null || extension != ".onnx" || changeType != WatcherChangeTypes.Created)
            {
                return;
            }
            // Behavior names are user-defined. Escape them so characters like '.', '(' or '+' match literally.
            var match = Regex.Match(filename, $@"^{Regex.Escape(behaviorName)}-(?<step>\d+)\.onnx$");
            // TryParse: this runs on the watcher's thread, so an overflowing step must not throw.
            if (!match.Success || !int.TryParse(match.Groups["step"].Value, out var step))
            {
                return;
            }
            EditorUpdateEvents.MainThread += () =>
            {
                CheckpointCreated?.Invoke(fullPath, runId, behaviorName, step);
            };
        }

        #endregion
    }
}