diff --git a/client/src/configSchema.js b/client/src/configSchema.js
new file mode 100644
index 0000000..06c847b
--- /dev/null
+++ b/client/src/configSchema.js
@@ -0,0 +1,262 @@
+const LEGACY_ROOT_KEYS = ["DATASET", "SOLVER", "SYSTEM", "INFERENCE", "MODEL"];
+const V2_ROOT_KEYS = [
+ "model",
+ "data",
+ "train",
+ "test",
+ "monitor",
+ "inference",
+ "optimization",
+ "default",
+];
+
+const SLIDER_PATHS = {
+ training: {
+ batch_size: [
+ ["SOLVER", "SAMPLES_PER_BATCH"],
+ ["data", "dataloader", "batch_size"],
+ ["default", "data", "dataloader", "batch_size"],
+ ],
+ gpus: [
+ ["SYSTEM", "NUM_GPUS"],
+ ["train", "system", "num_gpus"],
+ ["system", "num_gpus"],
+ ["default", "system", "num_gpus"],
+ ],
+ cpus: [
+ ["SYSTEM", "NUM_CPUS"],
+ ["train", "system", "num_workers"],
+ ["system", "num_workers"],
+ ["default", "system", "num_workers"],
+ ],
+ },
+ inference: {
+ batch_size: [
+ ["INFERENCE", "SAMPLES_PER_BATCH"],
+ ["inference", "batch_size"],
+ ["default", "inference", "batch_size"],
+ ],
+ augmentations: [["INFERENCE", "AUG_NUM"]],
+ },
+};
+
+function isObject(value) {
+ return value && typeof value === "object" && !Array.isArray(value);
+}
+
+function joinPath(basePath, leaf) {
+ if (!basePath) return leaf;
+ return basePath.endsWith("/") ? `${basePath}${leaf}` : `${basePath}/${leaf}`;
+}
+
+export function getPathValue(data, path) {
+ if (!isObject(data) || !Array.isArray(path)) return undefined;
+ return path.reduce((cursor, key) => {
+ if (!isObject(cursor) || !(key in cursor)) return undefined;
+ return cursor[key];
+ }, data);
+}
+
+export function hasPath(data, path) {
+ return getPathValue(data, path) !== undefined;
+}
+
+export function setPathValue(data, path, value) {
+ if (!isObject(data) || !Array.isArray(path) || path.length === 0) return;
+ let cursor = data;
+ path.forEach((key, index) => {
+ if (index === path.length - 1) {
+ cursor[key] = value;
+ return;
+ }
+ if (!isObject(cursor[key])) {
+ cursor[key] = {};
+ }
+ cursor = cursor[key];
+ });
+}
+
+export function pickFirstExistingPath(data, candidates) {
+ if (!Array.isArray(candidates)) return null;
+ for (const candidate of candidates) {
+ if (hasPath(data, candidate)) return candidate;
+ }
+ return null;
+}
+
+export function detectConfigSchema(configObj) {
+ if (!isObject(configObj)) return "unknown";
+ const legacyScore = LEGACY_ROOT_KEYS.filter((k) => k in configObj).length;
+ const v2Score = V2_ROOT_KEYS.filter((k) => k in configObj).length;
+ if (legacyScore === 0 && v2Score === 0) return "unknown";
+ return legacyScore >= v2Score ? "legacy" : "v2";
+}
+
+function resolveSliderPath(configObj, type, key) {
+ const candidates = SLIDER_PATHS[type]?.[key] || [];
+ return pickFirstExistingPath(configObj, candidates);
+}
+
+export function getSliderValue(configObj, type, key) {
+ const path = resolveSliderPath(configObj, type, key);
+ return path ? getPathValue(configObj, path) : undefined;
+}
+
+export function isSliderSupported(configObj, type, key) {
+ return Boolean(resolveSliderPath(configObj, type, key));
+}
+
+export function setSliderValue(configObj, type, key, value) {
+ const path = resolveSliderPath(configObj, type, key);
+ if (!path) return false;
+ setPathValue(configObj, path, value);
+ return true;
+}
+
+export function setTrainingOutputPath(configObj, outputPath) {
+ if (!outputPath || !isObject(configObj)) return;
+ const schema = detectConfigSchema(configObj);
+ if (schema === "legacy") {
+ setPathValue(configObj, ["DATASET", "OUTPUT_PATH"], outputPath);
+ return;
+ }
+ const checkpointsPath = joinPath(outputPath, "checkpoints");
+ if (hasPath(configObj, ["train", "monitor", "checkpoint", "dirpath"])) {
+ setPathValue(configObj, ["train", "monitor", "checkpoint", "dirpath"], checkpointsPath);
+ return;
+ }
+ if (hasPath(configObj, ["monitor", "checkpoint", "dirpath"])) {
+ setPathValue(configObj, ["monitor", "checkpoint", "dirpath"], checkpointsPath);
+ return;
+ }
+ setPathValue(configObj, ["monitor", "checkpoint", "dirpath"], checkpointsPath);
+}
+
+export function setInferenceOutputPath(configObj, outputPath) {
+ if (!outputPath || !isObject(configObj)) return;
+ const schema = detectConfigSchema(configObj);
+ if (schema === "legacy") {
+ setPathValue(configObj, ["INFERENCE", "OUTPUT_PATH"], outputPath);
+ return;
+ }
+ setPathValue(configObj, ["inference", "save_prediction", "output_path"], outputPath);
+}
+
+export function setInferenceExecutionDefaults(configObj) {
+ if (!isObject(configObj)) return;
+ const schema = detectConfigSchema(configObj);
+ if (schema === "legacy") {
+ setPathValue(configObj, ["SYSTEM", "NUM_GPUS"], 1);
+ return;
+ }
+ const existingPath = pickFirstExistingPath(configObj, [
+ ["test", "system", "num_gpus"],
+ ["system", "num_gpus"],
+ ["default", "system", "num_gpus"],
+ ]);
+ if (existingPath) {
+ setPathValue(configObj, existingPath, 1);
+ return;
+ }
+ setPathValue(configObj, ["system", "num_gpus"], 1);
+}
+
+export function applyInputPaths(
+ configObj,
+ { mode, inputImagePath, inputLabelPath, inputPath, outputPath },
+) {
+ if (!isObject(configObj) || !inputImagePath || !inputLabelPath) return;
+ const schema = detectConfigSchema(configObj);
+
+ if (schema === "legacy") {
+ setPathValue(configObj, ["DATASET", "INPUT_PATH"], inputPath);
+ setPathValue(
+ configObj,
+ ["DATASET", "IMAGE_NAME"],
+ inputImagePath.replace(inputPath, ""),
+ );
+ setPathValue(
+ configObj,
+ ["DATASET", "LABEL_NAME"],
+ inputLabelPath.replace(inputPath, ""),
+ );
+ if (outputPath) {
+ if (mode === "training") {
+ setTrainingOutputPath(configObj, outputPath);
+ } else {
+ setInferenceOutputPath(configObj, outputPath);
+ }
+ }
+ return;
+ }
+
+ if (mode === "training") {
+ const imagePath =
+ pickFirstExistingPath(configObj, [
+ ["train", "data", "train", "image"],
+ ["data", "train", "image"],
+ ]) || ["train", "data", "train", "image"];
+ const labelPath =
+ pickFirstExistingPath(configObj, [
+ ["train", "data", "train", "label"],
+ ["data", "train", "label"],
+ ]) || ["train", "data", "train", "label"];
+ setPathValue(configObj, imagePath, inputImagePath);
+ setPathValue(configObj, labelPath, inputLabelPath);
+ if (outputPath) {
+ setTrainingOutputPath(configObj, outputPath);
+ }
+ return;
+ }
+
+ const imagePath =
+ pickFirstExistingPath(configObj, [
+ ["test", "data", "test", "image"],
+ ["data", "test", "image"],
+ ]) || ["test", "data", "test", "image"];
+ const labelPath =
+ pickFirstExistingPath(configObj, [
+ ["test", "data", "test", "label"],
+ ["data", "test", "label"],
+ ]) || ["test", "data", "test", "label"];
+ setPathValue(configObj, imagePath, inputImagePath);
+ setPathValue(configObj, labelPath, inputLabelPath);
+ if (outputPath) {
+ setInferenceOutputPath(configObj, outputPath);
+ }
+}
+
+export function getArchitectureValue(configObj) {
+ if (!isObject(configObj)) return undefined;
+ const path = pickFirstExistingPath(configObj, [
+ ["MODEL", "ARCHITECTURE"],
+ ["model", "arch", "type"],
+ ["default", "model", "arch", "profile"],
+ ]);
+ return path ? getPathValue(configObj, path) : undefined;
+}
+
+export function isArchitectureSupported(configObj) {
+ if (!isObject(configObj)) return false;
+ return Boolean(
+ pickFirstExistingPath(configObj, [
+ ["MODEL", "ARCHITECTURE"],
+ ["model", "arch", "type"],
+ ["default", "model", "arch", "profile"],
+ ]),
+ );
+}
+
+export function setArchitectureValue(configObj, value) {
+ if (!isObject(configObj)) return false;
+ const existingPath = pickFirstExistingPath(configObj, [
+ ["MODEL", "ARCHITECTURE"],
+ ["model", "arch", "type"],
+ ["default", "model", "arch", "profile"],
+ ]);
+ if (existingPath) {
+ setPathValue(configObj, existingPath, value);
+ return true;
+ }
+ return false;
+}
diff --git a/client/src/contexts/GlobalContext.js b/client/src/contexts/GlobalContext.js
index 63169f0..023e9fd 100644
--- a/client/src/contexts/GlobalContext.js
+++ b/client/src/contexts/GlobalContext.js
@@ -121,6 +121,10 @@ export const ContextWrapper = (props) => {
"inferenceConfig",
null,
);
+ const [trainingConfigOriginPath, setTrainingConfigOriginPath] =
+ usePersistedState("trainingConfigOriginPath", "");
+ const [inferenceConfigOriginPath, setInferenceConfigOriginPath] =
+ usePersistedState("inferenceConfigOriginPath", "");
const [uploadedYamlFile, setUploadedYamlFile] = usePersistedState(
"uploadedYamlFile",
"",
@@ -203,12 +207,16 @@ export const ContextWrapper = (props) => {
setViewer,
trainingConfig,
setTrainingConfig,
+ trainingConfigOriginPath,
+ setTrainingConfigOriginPath,
imageFileList,
setImageFileList,
labelFileList,
setLabelFileList,
inferenceConfig,
setInferenceConfig,
+ inferenceConfigOriginPath,
+ setInferenceConfigOriginPath,
uploadedYamlFile,
setUploadedYamlFile,
selectedYamlPreset,
diff --git a/client/src/views/ModelInference.js b/client/src/views/ModelInference.js
index d632923..c8ca506 100644
--- a/client/src/views/ModelInference.js
+++ b/client/src/views/ModelInference.js
@@ -1,44 +1,168 @@
-import React, { useContext, useState } from "react";
+import React, { useContext, useEffect, useRef, useState } from "react";
import { Button, Space } from "antd";
-import { startModelInference, stopModelInference } from "../api";
+import yaml from "js-yaml";
+import {
+ getInferenceLogs,
+ getInferenceStatus,
+ startModelInference,
+ stopModelInference,
+} from "../api";
import Configurator from "../components/Configurator";
+import { applyInputPaths } from "../configSchema";
+import RuntimeLogPanel from "../components/RuntimeLogPanel";
import { AppContext } from "../contexts/GlobalContext";
function ModelInference({ isInferring, setIsInferring }) {
const context = useContext(AppContext);
- // const [isInference, setIsInference] = useState(false)
+ const [inferenceStatus, setInferenceStatus] = useState("");
+ const [inferenceRuntime, setInferenceRuntime] = useState(null);
+ const pollingIntervalRef = useRef(null);
+
+ const getPath = (val) => {
+ if (!val) return "";
+ if (typeof val === "string") return val;
+ return val.path || val.originFileObj?.path || "";
+ };
+
+ const getConfigOriginPath = () => {
+ return (
+ context.inferenceConfigOriginPath ||
+ context.selectedYamlPreset ||
+ getPath(context.uploadedYamlFile)
+ );
+ };
+
+ const refreshInferenceLogs = async () => {
+ try {
+ const runtime = await getInferenceLogs();
+ setInferenceRuntime(runtime);
+ return runtime;
+ } catch (error) {
+ console.error("Error loading inference logs:", error);
+ return null;
+ }
+ };
+
+ const getPreparedInferenceConfig = (inferenceConfig) => {
+ try {
+ const yamlData = yaml.load(inferenceConfig);
+ if (!yamlData || typeof yamlData !== "object") {
+ return inferenceConfig;
+ }
+
+ applyInputPaths(yamlData, {
+ mode: "inference",
+ inputImagePath: getPath(context.inputImage),
+ inputLabelPath: getPath(context.inputLabel),
+ inputPath: "",
+ outputPath: getPath(context.outputPath),
+ });
+ return yaml.dump(yamlData, { indent: 2 }).replace(/^\s*\n/gm, "");
+ } catch (error) {
+ console.warn("Failed to prepare inference config from current inputs:", error);
+ return inferenceConfig;
+ }
+ };
+
+ useEffect(() => {
+ refreshInferenceLogs();
+ }, []);
+
+ useEffect(() => {
+ if (isInferring) {
+ pollingIntervalRef.current = setInterval(async () => {
+ try {
+ const [status, runtime] = await Promise.all([
+ getInferenceStatus(),
+ getInferenceLogs(),
+ ]);
+ setInferenceRuntime(runtime);
+
+ if (!status.isRunning) {
+ setIsInferring(false);
+ if (status.exitCode === 0) {
+ setInferenceStatus("Inference completed successfully! ✓");
+ } else if (status.exitCode !== null && status.exitCode !== undefined) {
+ setInferenceStatus(
+ `Inference finished with exit code: ${status.exitCode}`,
+ );
+ } else if (status.phase === "failed" && status.lastError) {
+ setInferenceStatus(`Inference failed: ${status.lastError}`);
+ } else {
+ setInferenceStatus("Inference stopped.");
+ }
+ }
+ } catch (error) {
+ console.error("Error polling inference status:", error);
+ setIsInferring(false);
+ setInferenceStatus(
+ `Inference status polling failed: ${error.message || "unknown error"}`,
+ );
+ }
+ }, 2000);
+ }
+
+ return () => {
+ if (pollingIntervalRef.current) {
+ clearInterval(pollingIntervalRef.current);
+ pollingIntervalRef.current = null;
+ }
+ };
+ }, [isInferring, setIsInferring]);
+
const handleStartButton = async () => {
try {
+ const inferenceConfig =
+ localStorage.getItem("inferenceConfig") || context.inferenceConfig;
+ if (!inferenceConfig) {
+ setInferenceStatus(
+ "Error: Please load or upload an inference configuration first.",
+ );
+ return;
+ }
+
+ const checkpointPath = getPath(context.checkpointPath);
+ if (!checkpointPath) {
+ setInferenceStatus("Error: Please set checkpoint path first.");
+ return;
+ }
+
setIsInferring(true);
- const inferenceConfig = localStorage.getItem("inferenceConfig");
+ setInferenceStatus("Starting inference...");
- const getPath = (val) => {
- if (!val) return "";
- if (typeof val === "string") return val;
- return val.path || "";
- };
+ const preparedInferenceConfig = getPreparedInferenceConfig(inferenceConfig);
- // const res = startModelInference(
const res = await startModelInference(
- context.uploadedYamlFile.name,
- inferenceConfig,
+ preparedInferenceConfig,
getPath(context.outputPath),
- getPath(context.checkpointPath),
+ checkpointPath,
+ getConfigOriginPath(),
);
console.log(res);
+ await refreshInferenceLogs();
+ setInferenceStatus("Inference started. Monitoring process...");
} catch (e) {
console.log(e);
setIsInferring(false);
+ await refreshInferenceLogs();
+ setInferenceStatus(
+ `Inference error: ${e.message || "Please check console for details."}`,
+ );
}
};
const handleStopButton = async () => {
try {
+ setInferenceStatus("Stopping inference...");
await stopModelInference();
} catch (e) {
console.log(e);
+ setInferenceStatus(
+ `Error stopping inference: ${e.message || "Please check console for details."}`,
+ );
} finally {
setIsInferring(false);
+ await refreshInferenceLogs();
}
};
@@ -63,6 +187,12 @@ function ModelInference({ isInferring, setIsInferring }) {
Stop Inference
+
{inferenceStatus}
+
>
);
diff --git a/client/src/views/ModelTraining.js b/client/src/views/ModelTraining.js
index 14907c2..7af6d57 100644
--- a/client/src/views/ModelTraining.js
+++ b/client/src/views/ModelTraining.js
@@ -1,52 +1,117 @@
// global localStorage
import React, { useContext, useState, useEffect, useRef } from "react";
import { Button, Space } from "antd";
+import yaml from "js-yaml";
import {
+ getTrainingLogs,
startModelTraining,
stopModelTraining,
getTrainingStatus,
} from "../api";
import Configurator from "../components/Configurator";
+import { applyInputPaths } from "../configSchema";
+import RuntimeLogPanel from "../components/RuntimeLogPanel";
import { AppContext } from "../contexts/GlobalContext";
function ModelTraining() {
const context = useContext(AppContext);
const [isTraining, setIsTraining] = useState(false);
const [trainingStatus, setTrainingStatus] = useState("");
+ const [trainingRuntime, setTrainingRuntime] = useState(null);
const pollingIntervalRef = useRef(null);
+ const getPath = (val) => {
+ if (!val) return "";
+ if (typeof val === "string") return val;
+ return val.path || val.originFileObj?.path || "";
+ };
+
+ const getConfigOriginPath = () => {
+ return (
+ context.trainingConfigOriginPath ||
+ context.selectedYamlPreset ||
+ getPath(context.uploadedYamlFile)
+ );
+ };
+
+ const refreshTrainingLogs = async () => {
+ try {
+ const runtime = await getTrainingLogs();
+ setTrainingRuntime(runtime);
+ return runtime;
+ } catch (error) {
+ console.error("Error loading training logs:", error);
+ return null;
+ }
+ };
+
+ const refreshTrainingRuntime = async () => {
+ const [status, runtime] = await Promise.all([
+ getTrainingStatus(),
+ getTrainingLogs(),
+ ]);
+ setTrainingRuntime(runtime);
+ console.log("Training status:", status);
+
+ if (!status.isRunning) {
+ console.log("Training completed!");
+ setIsTraining(false);
+
+ if (status.exitCode === 0) {
+ setTrainingStatus("Training completed successfully! ✓");
+ } else if (status.exitCode !== null) {
+ setTrainingStatus(`Training finished with exit code: ${status.exitCode}`);
+ } else if (status.phase === "failed" && status.lastError) {
+ setTrainingStatus(`Training failed: ${status.lastError}`);
+ } else {
+ setTrainingStatus("Training stopped.");
+ }
+
+ if (pollingIntervalRef.current) {
+ clearInterval(pollingIntervalRef.current);
+ pollingIntervalRef.current = null;
+ }
+ }
+ };
+
+ const getPreparedTrainingConfig = (trainingConfig) => {
+ try {
+ const yamlData = yaml.load(trainingConfig);
+ if (!yamlData || typeof yamlData !== "object") {
+ return trainingConfig;
+ }
+
+ applyInputPaths(yamlData, {
+ mode: "training",
+ inputImagePath: getPath(context.inputImage),
+ inputLabelPath: getPath(context.inputLabel),
+ inputPath: "",
+ outputPath: getPath(context.outputPath),
+ });
+ return yaml.dump(yamlData, { indent: 2 }).replace(/^\s*\n/gm, "");
+ } catch (error) {
+ console.warn("Failed to prepare training config from current inputs:", error);
+ return trainingConfig;
+ }
+ };
+
+ useEffect(() => {
+ refreshTrainingLogs();
+ }, []);
+
// Poll training status when training is active
useEffect(() => {
if (isTraining) {
console.log("Starting training status polling...");
pollingIntervalRef.current = setInterval(async () => {
try {
- const status = await getTrainingStatus();
- console.log("Training status:", status);
-
- if (!status.isRunning) {
- // Training has finished
- console.log("Training completed!");
- setIsTraining(false);
-
- if (status.exitCode === 0) {
- setTrainingStatus("Training completed successfully! ✓");
- } else if (status.exitCode !== null) {
- setTrainingStatus(
- `Training finished with exit code: ${status.exitCode}`,
- );
- } else {
- setTrainingStatus("Training stopped.");
- }
-
- // Clear the polling interval
- if (pollingIntervalRef.current) {
- clearInterval(pollingIntervalRef.current);
- pollingIntervalRef.current = null;
- }
- }
+ await refreshTrainingRuntime();
} catch (error) {
console.error("Error polling training status:", error);
+ setIsTraining(false);
+ setTrainingStatus(
+ `Training status polling failed: ${error.message || "unknown error"}`,
+ );
}
}, 2000); // Poll every 2 seconds
}
@@ -61,13 +126,15 @@ function ModelTraining() {
};
}, [isTraining]);
- // const [tensorboardURL, setTensorboardURL] = useState(null);
const handleStartButton = async () => {
try {
- // TODO: Validate required context values before starting
- if (!context.uploadedYamlFile) {
+ const trainingConfig =
+ localStorage.getItem("trainingConfig") || context.trainingConfig;
+
+ // Accept either uploaded YAML or preset-backed config text.
+ if (!trainingConfig) {
setTrainingStatus(
- "Error: Please upload a YAML configuration file first.",
+ "Error: Please load a preset or upload a YAML configuration first.",
);
return;
}
@@ -77,14 +144,7 @@ function ModelTraining() {
return;
}
- if (!context.logPath) {
- setTrainingStatus("Error: Please set log path first in Step 1.");
- return;
- }
-
console.log(context.uploadedYamlFile);
- const trainingConfig =
- localStorage.getItem("trainingConfig") || context.trainingConfig;
console.log(trainingConfig);
setIsTraining(true);
@@ -92,20 +152,18 @@ function ModelTraining() {
"Starting training... Please wait, this may take a while.",
);
- const getPath = (val) => {
- if (!val) return "";
- if (typeof val === "string") return val;
- return val.path || "";
- };
+ const preparedTrainingConfig = getPreparedTrainingConfig(trainingConfig);
// TODO: The API call should be non-blocking and return immediately
// Real training status should be polled separately
const res = await startModelTraining(
- trainingConfig,
- getPath(context.logPath),
+ preparedTrainingConfig,
+ getPath(context.logPath) || getPath(context.outputPath),
getPath(context.outputPath),
+ getConfigOriginPath(),
);
console.log(res);
+ await refreshTrainingLogs();
// TODO: Don't set training complete here - implement proper status polling
setTrainingStatus(
@@ -113,6 +171,7 @@ function ModelTraining() {
);
} catch (e) {
console.error("Training start error:", e);
+ await refreshTrainingLogs();
setTrainingStatus(
`Training error: ${e.message || "Please check console for details."}`,
);
@@ -131,6 +190,8 @@ function ModelTraining() {
setTrainingStatus(
`Error stopping training: ${e.message || "Please check console for details."}`,
);
+ } finally {
+ await refreshTrainingLogs();
}
};
@@ -161,8 +222,12 @@ function ModelTraining() {
Stop Training
- {/*