Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
using System;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using Fluctio.FluctioSim.Common.Configuration;
using Fluctio.FluctioSim.EditorUtils.EditorGeneral;
using Fluctio.FluctioSim.EditorUtils.OperatingSystem;
using JetBrains.Annotations;
using MoreLinq;
using UnityEditor;
using UnityEditor.Build.Reporting;
using UnityEngine;
using UnityEngine.SceneManagement;
using Debug = UnityEngine.Debug;

namespace Fluctio.FluctioSim.EditorCore.Training.PythonWrappers
{
	/// <summary>
	/// Editor-side wrapper around the <c>mlagents-learn</c> command line tool: starts and stops the
	/// training process, forwards its output to the Unity console and watches the run's results folder
	/// for TensorBoard event files and ONNX checkpoints.
	/// </summary>
	public static class MlAgentsLearn
	{

		#region General

		/// <summary>Folder that <c>mlagents-learn</c> results are read from: <c>results</c> under <see cref="EditorUtil.ProjectFolder"/>.</summary>
		public static string ResultsFolder { get; } = Path.Combine(EditorUtil.ProjectFolder, "results");

		/// <summary>How built environments render while training.</summary>
		public enum GraphicsType
		{
			/// <summary>Passes <c>--no-graphics</c>.</summary>
			[Description("Hide all")]
			Disable,
			/// <summary>Passes no graphics flag.</summary>
			[Description("Show all")]
			Enable,
			/// <summary>Passes <c>--no-graphics-monitor</c>.</summary>
			[Description("Show one, hide the rest")]
			Monitor,
		}

		/// <summary>What to do when a run with the same id already exists.</summary>
		public enum OverwriteType
		{
			/// <summary>Passes no overwrite flag.</summary>
			[Description("Start training")]
			None,
			/// <summary>Passes <c>--resume</c>.</summary>
			[Description("Continue training")]
			Continue,
			/// <summary>Passes <c>--force</c>.</summary>
			[Description("Start from scratch")]
			Force,
		}

		#endregion
		
		#region Process and state
		
		// Holds the training process handle; presumably it survives domain reloads (the "after domain reload"
		// branches below rely on that) — see SessionProcess for the exact mechanism.
		private static readonly SessionProcess SessionProcess = new(typeof(MlAgentsLearn));

		/// <summary>Raised (queued via <c>EditorUpdateEvents.MainThread</c>) after the training process exits.</summary>
		public static event Action TrainingStopped;
		
		[CanBeNull]
		private static Process Process
		{
			get => SessionProcess.Process;
			set => SessionProcess.Process = value;
		}

		/// <summary>Run id of the current training, stored in <see cref="SessionState"/>; an empty stored string means <c>null</c>.</summary>
		[CanBeNull]
		private static string RunId
		{
			get
			{
				var savedRunId = SessionState.GetString($"{typeof(MlAgentsLearn).FullName}.RunId", "");
				return savedRunId == "" ? null : savedRunId;
			}
			set => SessionState.SetString($"{typeof(MlAgentsLearn).FullName}.RunId", value ?? string.Empty);
		}

		/// <summary>True while a training process exists and has not exited.</summary>
		public static bool IsProcessLaunched => Process is { HasExited: false };

		/// <exception cref="InvalidOperationException">A training process is already running.</exception>
		private static void CheckNotRunning()
		{
			if (IsProcessLaunched)
			{
				throw new InvalidOperationException("Training is already running");
			}
		}

		/// <summary>Asks for confirmation before quitting the editor during training, and stops training on quit.</summary>
		[InitializeOnLoadMethod]
		private static void StopOnExit()
		{
			EditorApplication.wantsToQuit += () =>
			{
				if (!IsProcessLaunched)
				{
					return true;
				}
				var userConfirmedQuit = EditorUtility.DisplayDialog(
					"Training in progress",
					$"There is a {Config.RawName} training in progress. Closing Unity Editor will stop the training.",
					"Stop training and quit",
					"Cancel");
				return userConfirmedQuit;
			};
			// NOTE(review): `quitting` is a plain Action, so this lambda is async void and Unity does not
			// wait for Stop() to finish before the editor shuts down.
			EditorApplication.quitting += async () =>
			{
				await Stop();
			};
		}
		
		#endregion

		#region Start/stop functions
		
		/// <summary>Raised (queued via <c>EditorUpdateEvents.MainThread</c>) when the trainer reports a connected Unity environment or brain.</summary>
		public static event Action TrainingStarted;
		// Completed when the trainer reports a connected environment/brain; canceled if the process exits first.
		private static TaskCompletionSource<object> _trainingStartedCompletionSource = null;
		// Completed on the trainer's "Listening on port ..." line; canceled if the process exits first.
		private static TaskCompletionSource<object> _listeningStartedCompletionSource = null;
		
		/// <summary>
		/// Starts <c>mlagents-learn</c> without an environment build so that it waits for the editor to connect,
		/// and completes once the trainer reports it is listening.
		/// </summary>
		/// <exception cref="InvalidOperationException">
		/// Training is already running, or the trainer exited before it started listening.
		/// </exception>
		public static async Task StartListening(string configPath, string runId, OverwriteType overwriteType)
		{
			Start(configPath, runId, overwriteType, "--num-envs 1");
			var listeningStarted = _listeningStartedCompletionSource;
			try
			{
				await listeningStarted.Task;
			}
			catch (OperationCanceledException)
			{
				// Canceled by the exit handler: without this the caller would wait forever.
				throw new InvalidOperationException("mlagents-learn exited before it started listening for the Unity Editor");
			}
		}

		/// <summary>Starts <c>mlagents-learn</c> against a built environment at <paramref name="envPath"/>.</summary>
		/// <exception cref="InvalidEnumArgumentException"><paramref name="graphicsType"/> is not a defined value.</exception>
		public static void StartProcesses(string configPath, string runId, OverwriteType overwriteType, string envPath, GraphicsType graphicsType, int numEnvs)
		{
			var graphicsFlag = graphicsType switch
			{
				GraphicsType.Disable => "--no-graphics",
				GraphicsType.Enable => "",
				GraphicsType.Monitor => "--no-graphics-monitor",
				_ => throw new InvalidEnumArgumentException(
					nameof(graphicsType),
					(int)graphicsType,
					graphicsType.GetType())
			};
			var otherFlags = $"{graphicsFlag} --env {ProcessUtil.InQuotes(envPath)} --num-envs {numEnvs}";
			Start(configPath, runId, overwriteType, otherFlags);
		}

		/// <summary>Builds the <c>mlagents-learn</c> command line, launches it and hooks up output/exit handling.</summary>
		private static void Start(string configPath, string runId, OverwriteType overwriteType, string otherFlags)
		{
			CheckNotRunning();
			var overwriteFlag = overwriteType switch
			{
				OverwriteType.None => "",
				OverwriteType.Continue => "--resume",
				OverwriteType.Force => "--force",
				_ => throw new InvalidEnumArgumentException(
					nameof(overwriteType),
					(int)overwriteType,
					overwriteType.GetType())
			};
			
			var command = $"mlagents-learn {ProcessUtil.InQuotes(configPath, false)} --run-id {ProcessUtil.InQuotes(runId)} {overwriteFlag} {otherFlags} --time-scale 1";
			Process = PythonVenv.StartShellProcess(command);
			// Only remember the run id once the process actually started.
			RunId = runId;
			Debug.Log($"Starting training:\n{command}");
			
			// Continuations must not run inline on the stdout reader thread that completes these sources.
			_trainingStartedCompletionSource = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
			_listeningStartedCompletionSource = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
			ConnectProcess();
		}

		/// <summary>
		/// Attaches exit handling, output forwarding and run-log watching to the current process (if any).
		/// Also runs on editor load so that a process started before a domain reload is reattached.
		/// </summary>
		[InitializeOnLoadMethod]
		private static void ConnectProcess()
		{
			if (Process == null)
			{
				return;
			}
			
			Process.Exited += delegate
			{
				EditorUpdateEvents.MainThread += OnProcessExited;
			};
			
			// Keep scripts from reloading while training runs; released in OnProcessExited.
			EditorApplication.LockReloadAssemblies();
			ConnectProcessOutput();
			WatchRunLogs();
		}

		/// <summary>Cleanup after the training process exited; queued onto the main thread by the Exited handler.</summary>
		private static void OnProcessExited()
		{
			// Release anyone still waiting for a start signal that can no longer arrive.
			_listeningStartedCompletionSource?.TrySetCanceled();
			_trainingStartedCompletionSource?.TrySetCanceled();
			try
			{
				TrainingStopped?.Invoke();
			}
			finally
			{
				// Runs even if a TrainingStopped subscriber throws, so the reload lock is never leaked.
				EditorApplication.UnlockReloadAssemblies();
				_ = OutputOnExit();
				StopRunLogsWatching();
				Process = null;
				RunId = null;
			}
		}
		
		/// <summary>
		/// Kills the training process tree, if there is one. Exit cleanup then happens in the process's Exited handler.
		/// </summary>
		public static async Task Stop()
		{
			if (Process == null)
			{
				return;
			}
			Debug.Log("Stopping training...");
			await Process.KillTree();
		}

		#endregion

		#region Start/stop functions with helpers
		
		/// <summary>Starts a listening trainer, then enters play mode so the editor connects to it.</summary>
		public static async Task StartInEditor(string configPath, string runId, OverwriteType overwriteType)
		{
			await StartListening(configPath, runId, overwriteType);
			EditorApplication.EnterPlaymode();
		}
		
		/// <summary>Stops the trainer and exits play mode.</summary>
		public static async Task StopInEditor()
		{
			var task = Stop();
			EditorApplication.ExitPlaymode();
			await task;
		}

		/// <summary>
		/// Builds the active scene as a standalone player for the current OS into <c>Build/</c> under the
		/// project folder, then starts training against that build.
		/// </summary>
		/// <exception cref="InvalidOperationException">Training is already running.</exception>
		/// <exception cref="AggregateException">The build did not succeed.</exception>
		public static void BuildAndStart(string configPath, string runId, OverwriteType overwriteType, GraphicsType graphicsType, int numEnvs)
		{
			CheckNotRunning();
			var (buildTarget, fileExtension) = Platform.Type switch
			{
				PlatformType.Windows => (BuildTarget.StandaloneWindows64, "exe"),
				PlatformType.Linux => (BuildTarget.StandaloneLinux64, "x86_64"),
				PlatformType.MacOS => (BuildTarget.StandaloneOSX, "app"),
				_ => throw Platform.OSNotSupportedException,
			};
			var report = BuildPipeline.BuildPlayer(new BuildPlayerOptions
			{
				locationPathName = Path.Combine(EditorUtil.ProjectFolder, "Build", $"{Application.productName}.{fileExtension}"),
				scenes = new [] { SceneManager.GetActiveScene().path },
				target = buildTarget,
			});
			if (report.summary.result != BuildResult.Succeeded)
			{
				// NOTE(review): AggregateException is kept for compatibility with existing callers;
				// a more specific type (e.g. BuildFailedException) would describe this better.
				throw new AggregateException($"Build to {report.summary.outputPath} was not finished (status = {report.summary.result})");
			}
			Debug.Log($"Current scene was built to {report.summary.outputPath}");
			StartProcesses(configPath, runId, overwriteType, report.summary.outputPath, graphicsType, numEnvs);
		}

		#endregion

		#region Output parsing
		
		// "[LEVEL] message" lines printed by the trainer.
		private static readonly Regex LogFormat = new(@"^\[(?<level>.*?)] (?<message>.*)$");
		private static readonly Regex ListeningStartedLine = new(@"^Listening on port (?<port>\d+)\. Start training by pressing the Play button in the Unity Editor\.$");
		private static readonly Regex UnityConnectedLine = new(@"^Connected to Unity environment with package version (?<packageVersion>\d+\.\d+\.\d+) and communication version (?<communicationVersion>\d+\.\d+\.\d+)$");
		// The behavior id is "<name>?team=<n>"; the '?' is a literal separator, not a regex quantifier.
		private static readonly Regex BrainConnectedLine = new(@"^Connected new brain: (?<behaviorName>.*)\?team=(?<team>\d+)$");

		// Reads stderr in the background for the whole lifetime of the process; awaited in OutputOnExit.
		private static Task<string> _errorOutputTask;
		
		/// <summary>Starts draining stderr and forwarding stdout lines to <see cref="OnOutput"/>.</summary>
		private static void ConnectProcessOutput()
		{
			if (Process == null)
			{
				return;
			}
			if (Process.StartInfo.RedirectStandardError)
			{
				// Drain stderr right away: a redirected pipe that nobody reads can fill up and block the trainer.
				_errorOutputTask = Process.StandardError.ReadToEndAsync();
			}
			if (!Process.StartInfo.RedirectStandardOutput)
			{
				// after domain reload
				return;
			}
			Process.OutputDataReceived += (_, args) =>
			{
				if (args.Data != null)
				{
					OnOutput(args.Data);
				}
			};
			Process.BeginOutputReadLine();
		}

		/// <summary>
		/// Logs how the process ended: "Training was stopped" for an empty stderr, otherwise an error containing
		/// the non-indented, non-traceback stderr lines followed by the full stderr text.
		/// </summary>
		private static async Task OutputOnExit()
		{
			var errorOutputTask = _errorOutputTask;
			_errorOutputTask = null;
			if (errorOutputTask == null)
			{
				// after domain reload, or stderr was not redirected
				return;
			}
			var errorMessage = await errorOutputTask;
			if (string.IsNullOrWhiteSpace(errorMessage))
			{
				Debug.Log("Training was stopped");
				return;
			}
			var importantMessage = errorMessage
				.Split('\n')
				.Select(line => line.TrimEnd('\r'))
				.Where(line => !line.StartsWith(" ", StringComparison.Ordinal))
				.Where(line => !line.Contains("Traceback"))
				.Where(line => !string.IsNullOrWhiteSpace(line))
				.ToDelimitedString("\n");
			Debug.LogError($"{importantMessage}\n\n{errorMessage}");
		}

		/// <summary>
		/// Handles one stdout line (called from the process output reader): completes the start signals on the
		/// known lines, and otherwise forwards the message to the Unity console at a matching log level.
		/// </summary>
		private static void OnOutput(string line)
		{
			var lineMatch = LogFormat.Match(line);
			if (!lineMatch.Success)
			{
				Debug.Log(line);
				return;
			}

			var message = lineMatch.Groups["message"].Value;
			
			if (ListeningStartedLine.IsMatch(message))
			{
				// Try*: a repeated line (or a source from another run) must not throw on the reader thread.
				_listeningStartedCompletionSource?.TrySetResult(true);
				return;
			}

			if (BrainConnectedLine.IsMatch(message) || UnityConnectedLine.IsMatch(message))
			{
				var isResultSet = _trainingStartedCompletionSource?.TrySetResult(true) ?? false;
				if (isResultSet && TrainingStarted != null)
				{
					EditorUpdateEvents.MainThread += TrainingStarted.Invoke;
				}
				return;
			}

			var logLevel = lineMatch.Groups["level"].Value switch
			{
				"DEBUG" or "INFO" => LogType.Log,
				"WARNING" => LogType.Warning,
				"ERROR" or "FATAL" or "CRITICAL" => LogType.Error,
				_ => LogType.Warning,
			};
			Debug.unityLogger.Log(logLevel, message);
		}
		
		#endregion

		#region Run logs watcher
		
		private static FileSystemWatcher _runLogsWatcher;
		private static FileSystemEventHandler _runLogsChangedHandler;
		/// <summary>Raised (queued via <c>EditorUpdateEvents.MainThread</c>) with (runId, behaviorName) when an <c>events.out.tfevents*</c> file changes.</summary>
		public static event Action<string, string> RunLogsChanged;
		/// <summary>Raised (queued via <c>EditorUpdateEvents.MainThread</c>) with (fullPath, runId, behaviorName, step) when <c>&lt;behavior&gt;-&lt;step&gt;.onnx</c> is created.</summary>
		public static event Action<string, string, string, int> CheckpointCreated;

		/// <summary>
		/// Once training has started, watches <c>ResultsFolder/&lt;runId&gt;</c> (recursively) for file writes.
		/// </summary>
		private static async void WatchRunLogs()
		{
			var runId = RunId;
			if (Process == null || runId == null)
			{
				return;
			}

			if (_trainingStartedCompletionSource != null)
			{
				try
				{
					await _trainingStartedCompletionSource.Task;
				}
				catch (OperationCanceledException)
				{
					// The process exited before training started; there is nothing to watch.
					return;
				}
			}

			// The process may have exited (and been cleaned up) or been replaced while we were waiting.
			if (Process == null || RunId != runId)
			{
				return;
			}

			StopRunLogsWatching();
			var runPath = Path.Combine(ResultsFolder, runId);
			_runLogsWatcher = new FileSystemWatcher(runPath)
			{
				NotifyFilter = NotifyFilters.LastWrite,
				IncludeSubdirectories = true,
			};

			_runLogsChangedHandler = (_, args) => OnRunLogsChanged(runId, args);
			_runLogsWatcher.Changed += _runLogsChangedHandler;
			_runLogsWatcher.Created += _runLogsChangedHandler;
			// Enable only after the handlers are attached so that no early event is dropped.
			_runLogsWatcher.EnableRaisingEvents = true;
		}

		/// <summary>Detaches and disposes the run-log watcher, if one exists.</summary>
		private static void StopRunLogsWatching()
		{
			if (_runLogsWatcher == null)
			{
				return;
			}

			_runLogsWatcher.Changed -= _runLogsChangedHandler;
			_runLogsWatcher.Created -= _runLogsChangedHandler;
			_runLogsChangedHandler = null;
			_runLogsWatcher.Dispose();
			_runLogsWatcher = null;
		}

		/// <summary>
		/// Watcher callback (runs on a watcher thread). The containing folder's name is taken as the behavior name.
		/// </summary>
		private static void OnRunLogsChanged(string runId, FileSystemEventArgs args)
		{
			if (string.IsNullOrEmpty(args.Name))
			{
				return;
			}
			var filename = Path.GetFileName(args.Name);
			var behaviorName = Path.GetFileName(Path.GetDirectoryName(args.Name));
			var extension = Path.GetExtension(args.Name);
			CheckRunLogsChanged(runId, behaviorName, filename);
			CheckCheckpointCreated(runId, behaviorName, filename, extension, args.FullPath, args.ChangeType);
		}

		/// <summary>Queues <see cref="RunLogsChanged"/> when a TensorBoard event file was touched.</summary>
		private static void CheckRunLogsChanged(string runId, string behaviorName, string filename)
		{
			if (!filename.StartsWith("events.out.tfevents", StringComparison.Ordinal))
			{
				return;
			}
			EditorUpdateEvents.MainThread += () =>
			{
				RunLogsChanged?.Invoke(runId, behaviorName);
			};
		}

		/// <summary>
		/// Queues <see cref="CheckpointCreated"/> when a file named <c>&lt;behaviorName&gt;-&lt;step&gt;.onnx</c> was created.
		/// </summary>
		private static void CheckCheckpointCreated(
			string runId,
			string behaviorName,
			string filename,
			string extension,
			string fullPath,
			WatcherChangeTypes changeType)
		{
			if (string.IsNullOrEmpty(behaviorName) || extension != ".onnx" || changeType != WatcherChangeTypes.Created)
			{
				return;
			}

			// Escape the behavior name: it is user-defined and may contain regex metacharacters.
			var match = Regex.Match(filename, $@"^{Regex.Escape(behaviorName)}-(?<step>[0-9]+)\.onnx$");
			// TryParse: a step beyond int range must not throw on the watcher thread.
			if (!match.Success || !int.TryParse(match.Groups["step"].Value, out var step))
			{
				return;
			}
			
			EditorUpdateEvents.MainThread += () =>
			{
				CheckpointCreated?.Invoke(fullPath, runId, behaviorName, step);
			};
		}

		#endregion
		
	}
}