From ebcab0569a1a1a0795d5cea1e872f960cbf673b2 Mon Sep 17 00:00:00 2001 From: Pierre Tenedero Date: Wed, 8 Apr 2026 17:44:33 +0800 Subject: [PATCH 1/4] Add face detection feature on video input --- README.md | 10 + services/ws-server/src/main.rs | 21 + services/ws-server/static/app.js | 1079 ++++++++++++++++++++++++++ services/ws-server/static/index.html | 9 + 4 files changed, 1119 insertions(+) diff --git a/README.md b/README.md index 57531f3..7f8eedd 100644 --- a/README.md +++ b/README.md @@ -19,9 +19,18 @@ mise run ws-e2e-chrome ## Run ws agent in browser +### HAR model setup + Download the onnx from https://modelnova.ai/models/details/human-activity-recognition , and save it as `services/ws-server/static/models/human_activity_recognition.onnx` +### Face detection setup + +Download the ONNX model from https://huggingface.co/amd/retinaface, save it in +`services/ws-server/static/models/`, and rename the file to `video_cv.onnx`. + +### Build and run the agent + ```bash mise run build-ws-wasm-agent mise run ws-server @@ -35,6 +44,7 @@ which will normally be something like 192.168.1.x. Then on your phone, open Chrome and type in https://192.168.1.x:8433/ Click "Load HAR model" and then "Start sensors". +For webcam inference, click "Load video CV model" and then "Start video". 
## Grant diff --git a/services/ws-server/src/main.rs b/services/ws-server/src/main.rs index 4e78025..3538a1d 100644 --- a/services/ws-server/src/main.rs +++ b/services/ws-server/src/main.rs @@ -177,6 +177,27 @@ impl StreamHandler> for WebSocketActor { action, details, } => { + if capability == "video_cv" && action == "inference" { + let detected_class = details + .get("detected_class") + .and_then(|value| value.as_str()) + .unwrap_or("unknown"); + let confidence = details + .get("confidence") + .and_then(|value| value.as_f64()) + .unwrap_or_default(); + let processed_at = details + .get("processed_at") + .and_then(|value| value.as_str()) + .unwrap_or("unknown"); + info!( + "Video inference received from {}: class={} confidence={:.4} processed_at={}", + self.current_agent_id(), + detected_class, + confidence, + processed_at + ); + } info!( "Client event from {}: capability={} action={} details={}", self.current_agent_id(), diff --git a/services/ws-server/static/app.js b/services/ws-server/static/app.js index 3caa152..43409ac 100644 --- a/services/ws-server/static/app.js +++ b/services/ws-server/static/app.js @@ -24,12 +24,16 @@ const speechButton = document.getElementById("speech-button"); const nfcButton = document.getElementById("nfc-button"); const sensorsButton = document.getElementById("sensors-button"); const harButton = document.getElementById("har-button"); +const videoModelButton = document.getElementById("video-model-button"); +const videoOutputButton = document.getElementById("video-output-button"); const harExportButton = document.getElementById("har-export-button"); const agentStatusEl = document.getElementById("agent-status"); const agentIdEl = document.getElementById("agent-id"); const sensorOutputEl = document.getElementById("sensor-output"); const harOutputEl = document.getElementById("har-output"); +const videoOutputEl = document.getElementById("video-output"); const videoPreview = document.getElementById("video-preview"); +const videoOutputCanvas 
= document.getElementById("video-output-canvas"); let microphone = null; let videoCapture = null; let bluetoothDevice = null; @@ -46,11 +50,36 @@ let harInferencePending = false; let lastInferenceAt = 0; let harSamplerId = null; let lastHarClassLabel = null; +let videoCvSession = null; +let videoCvInputName = null; +let videoCvOutputName = null; +let videoCvLoopId = null; +let videoCvInferencePending = false; +let lastVideoInferenceAt = 0; +let lastVideoCvLabel = null; +let videoCvCanvas = null; +let videoCvContext = null; +let videoOverlayContext = videoOutputCanvas.getContext("2d"); +let videoOutputVisible = false; +let videoRenderFrameId = null; +let lastVideoInferenceSummary = null; let gravityEstimate = { x: 0, y: 0, z: 0 }; let sendClientEvent = () => {}; const HAR_SEQUENCE_LENGTH = 512; const HAR_FEATURE_COUNT = 9; const HAR_SAMPLE_INTERVAL_MS = 20; +const VIDEO_INFERENCE_INTERVAL_MS = 750; +const VIDEO_RENDER_SCORE_THRESHOLD = 0.35; +const VIDEO_MODEL_PATH = "/static/models/video_cv.onnx"; +const VIDEO_FALLBACK_INPUT_SIZE = 224; +const RETINAFACE_INPUT_HEIGHT = 608; +const RETINAFACE_INPUT_WIDTH = 640; +const RETINAFACE_CONFIDENCE_THRESHOLD = 0.75; +const RETINAFACE_NMS_THRESHOLD = 0.4; +const RETINAFACE_VARIANCES = [0.1, 0.2]; +const RETINAFACE_MIN_SIZES = [[16, 32], [64, 128], [256, 512]]; +const RETINAFACE_STEPS = [8, 16, 32]; +const RETINAFACE_MEAN_BGR = [104, 117, 123]; const STANDARD_GRAVITY = 9.80665; const GRAVITY_FILTER_ALPHA = 0.8; const HAR_CLASS_LABELS = [ @@ -224,6 +253,10 @@ const setHarOutput = (lines) => { harOutputEl.value = Array.isArray(lines) ? lines.join("\n") : String(lines); }; +const setVideoOutput = (lines) => { + videoOutputEl.value = Array.isArray(lines) ? lines.join("\n") : String(lines); +}; + const updateHarStatus = (extraLines = []) => { const lines = [ `model: ${harSession ? 
"loaded" : "not loaded"}`, @@ -237,6 +270,27 @@ const updateHarStatus = (extraLines = []) => { setHarOutput(lines.concat("", extraLines)); }; +const updateVideoStatus = (extraLines = []) => { + const inputMetadata = videoCvInputName + ? videoCvSession?.inputMetadata?.[videoCvInputName] + : null; + const outputMetadata = videoCvOutputName + ? videoCvSession?.outputMetadata?.[videoCvOutputName] + : null; + const lines = [ + `model: ${videoCvSession ? "loaded" : "not loaded"}`, + `video: ${videoCapture ? "active" : "inactive"}`, + `input: ${videoCvInputName ?? "n/a"}`, + `output: ${videoCvOutputName ?? "n/a"}`, + `input dims: ${JSON.stringify(inputMetadata?.dimensions ?? [])}`, + `output dims: ${JSON.stringify(outputMetadata?.dimensions ?? [])}`, + `loop: ${videoCvLoopId === null ? "idle" : "running"}`, + `display: ${videoOutputVisible ? "visible" : "hidden"}`, + `mode: ${lastVideoInferenceSummary?.mode ?? "unknown"}`, + ]; + setVideoOutput(lines.concat("", extraLines)); +}; + const getFeatureVector = () => { const totalAcceleration = motionState?.accelerationIncludingGravity ?? 
{ x: 0, y: 0, z: 0 }; const bodyAcceleration = { @@ -448,11 +502,981 @@ const requestSensorPermission = async (permissionTarget) => { return permissionTarget.requestPermission(); }; +const getTopK = (values, limit = 3) => { + return values + .map((value, index) => ({ value, index })) + .sort((left, right) => right.value - left.value) + .slice(0, limit); +}; + +const ensureVideoCvCanvas = () => { + if (!videoCvCanvas) { + videoCvCanvas = document.createElement("canvas"); + videoCvContext = videoCvCanvas.getContext("2d", { willReadFrequently: true }); + } + + if (!videoCvContext) { + throw new Error("Unable to create 2D canvas context for video preprocessing."); + } + + return videoCvContext; +}; + +const ensureVideoOverlayContext = () => { + if (!videoOverlayContext) { + videoOverlayContext = videoOutputCanvas.getContext("2d"); + } + + if (!videoOverlayContext) { + throw new Error("Unable to create video output canvas context."); + } + + return videoOverlayContext; +}; + +const isRetinaFaceSession = (session = videoCvSession) => { + if (!session) { + return false; + } + + const inputNames = Array.isArray(session.inputNames) ? session.inputNames : []; + const outputNames = Array.isArray(session.outputNames) ? session.outputNames : []; + const allNames = inputNames.concat(outputNames).map((name) => String(name).toLowerCase()); + if (allNames.some((name) => name.includes("retinaface"))) { + return true; + } + + return outputNames.length === 3 && inputNames.length === 1; +}; + +const selectVideoModelInputName = (session) => { + const inputNames = Array.isArray(session?.inputNames) ? session.inputNames : []; + if (!inputNames.length) { + return null; + } + + const ranked = inputNames + .map((name) => { + const metadata = session?.inputMetadata?.[name]; + const dimensions = Array.isArray(metadata?.dimensions) ? 
metadata.dimensions : []; + const normalizedName = String(name).toLowerCase(); + let score = 0; + + if (dimensions.length === 4) { + score += 100; + } else if (dimensions.length === 3) { + score += 40; + } + + if ( + normalizedName.includes("pixel") + || normalizedName.includes("image") + || normalizedName.includes("images") + || normalizedName.includes("input") + ) { + score += 25; + } + + if (normalizedName.includes("mask") || normalizedName.includes("token")) { + score -= 50; + } + + return { name, score }; + }) + .sort((left, right) => right.score - left.score); + + return ranked[0]?.name ?? inputNames[0]; +}; + +const selectVideoModelOutputName = (session) => { + const outputNames = Array.isArray(session?.outputNames) ? session.outputNames : []; + if (!outputNames.length) { + return null; + } + + const ranked = outputNames + .map((name) => { + const normalizedName = String(name).toLowerCase(); + let score = 0; + if (normalizedName.includes("box")) { + score += 100; + } + if (normalizedName.includes("logit") || normalizedName.includes("score")) { + score += 40; + } + return { name, score }; + }) + .sort((left, right) => right.score - left.score); + + return ranked[0]?.name ?? outputNames[0]; +}; + +const resolveVideoModelLayout = () => { + if (!videoCvSession || !videoCvInputName) { + throw new Error("Video CV model is not loaded."); + } + + if (isRetinaFaceSession(videoCvSession)) { + return { + dataType: "float32", + channels: 3, + width: RETINAFACE_INPUT_WIDTH, + height: RETINAFACE_INPUT_HEIGHT, + tensorDimensions: [1, RETINAFACE_INPUT_HEIGHT, RETINAFACE_INPUT_WIDTH, 3], + layout: "nhwc", + profile: "retinaface", + }; + } + + const metadata = videoCvSession.inputMetadata?.[videoCvInputName]; + const dataType = metadata?.type ?? "float32"; + if (dataType !== "float32" && dataType !== "uint8") { + throw new Error(`Unsupported video model input type: ${dataType}`); + } + + const rawDimensions = Array.isArray(metadata?.dimensions) + ? 
metadata.dimensions + : []; + const dimensions = rawDimensions.length === 4 + ? rawDimensions + : rawDimensions.length === 3 + ? [1, ...rawDimensions] + : [1, 3, VIDEO_FALLBACK_INPUT_SIZE, VIDEO_FALLBACK_INPUT_SIZE]; + + const resolved = dimensions.map((dimension, index) => { + if (typeof dimension === "number" && Number.isFinite(dimension) && dimension > 0) { + return dimension; + } + + if (index === 0) { + return 1; + } + + if (index === 1 && dimensions.length === 4) { + const inputName = String(videoCvInputName).toLowerCase(); + if (!inputName.includes("nhwc")) { + return 3; + } + } + + return VIDEO_FALLBACK_INPUT_SIZE; + }); + + const secondDimension = resolved[1]; + const lastDimension = resolved[3]; + const inputName = String(videoCvInputName).toLowerCase(); + const channelsFirst = inputName.includes("nhwc") + ? false + : secondDimension === 1 + || secondDimension === 3 + || ((lastDimension !== 1 && lastDimension !== 3) && !inputName.includes("image_embeddings")); + if (channelsFirst) { + const [, channels, height, width] = resolved; + if (channels !== 1 && channels !== 3) { + throw new Error(`Unsupported channel count for NCHW image input: ${channels}`); + } + + return { + dataType, + channels, + width, + height, + tensorDimensions: [1, channels, height, width], + layout: "nchw", + profile: "generic", + }; + } + + const [, height, width, channels] = resolved; + if (channels !== 1 && channels !== 3) { + throw new Error(`Unsupported channel count for NHWC image input: ${channels}`); + } + + return { + dataType, + channels, + width, + height, + tensorDimensions: [1, height, width, channels], + layout: "nhwc", + profile: "generic", + }; +}; + +const buildVideoInputTensor = () => { + if (!videoCapture || !videoCvSession || !videoCvInputName) { + throw new Error("Video capture or model session is unavailable."); + } + + if (!videoPreview.videoWidth || !videoPreview.videoHeight) { + throw new Error("Video stream is not ready yet."); + } + + const { + dataType, + 
channels, + width, + height, + tensorDimensions, + layout, + profile, + } = resolveVideoModelLayout(); + const context = ensureVideoCvCanvas(); + videoCvCanvas.width = width; + videoCvCanvas.height = height; + let resizeRatio = 1; + if (profile === "retinaface") { + const sourceWidth = videoPreview.videoWidth; + const sourceHeight = videoPreview.videoHeight; + const targetRatio = height / width; + if (sourceHeight / sourceWidth <= targetRatio) { + resizeRatio = width / sourceWidth; + } else { + resizeRatio = height / sourceHeight; + } + + const resizedWidth = Math.max(1, Math.min(width, Math.round(sourceWidth * resizeRatio))); + const resizedHeight = Math.max(1, Math.min(height, Math.round(sourceHeight * resizeRatio))); + context.clearRect(0, 0, width, height); + context.drawImage(videoPreview, 0, 0, resizedWidth, resizedHeight); + } else { + context.drawImage(videoPreview, 0, 0, width, height); + } + + const rgba = context.getImageData(0, 0, width, height).data; + const elementCount = width * height * channels; + const tensorData = dataType === "uint8" + ? new Uint8Array(elementCount) + : new Float32Array(elementCount); + + for (let pixelIndex = 0; pixelIndex < width * height; pixelIndex += 1) { + const rgbaIndex = pixelIndex * 4; + const red = rgba[rgbaIndex]; + const green = rgba[rgbaIndex + 1]; + const blue = rgba[rgbaIndex + 2]; + + if (profile === "retinaface") { + const tensorIndex = pixelIndex * channels; + tensorData[tensorIndex] = blue - RETINAFACE_MEAN_BGR[0]; + tensorData[tensorIndex + 1] = green - RETINAFACE_MEAN_BGR[1]; + tensorData[tensorIndex + 2] = red - RETINAFACE_MEAN_BGR[2]; + continue; + } + + if (channels === 1) { + const grayscale = Math.round(0.299 * red + 0.587 * green + 0.114 * blue); + tensorData[pixelIndex] = dataType === "uint8" ? 
grayscale : grayscale / 255; + continue; + } + + if (layout === "nchw") { + const planeSize = width * height; + if (dataType === "uint8") { + tensorData[pixelIndex] = red; + tensorData[pixelIndex + planeSize] = green; + tensorData[pixelIndex + 2 * planeSize] = blue; + } else { + tensorData[pixelIndex] = red / 255; + tensorData[pixelIndex + planeSize] = green / 255; + tensorData[pixelIndex + 2 * planeSize] = blue / 255; + } + continue; + } + + const tensorIndex = pixelIndex * channels; + if (dataType === "uint8") { + tensorData[tensorIndex] = red; + tensorData[tensorIndex + 1] = green; + tensorData[tensorIndex + 2] = blue; + } else { + tensorData[tensorIndex] = red / 255; + tensorData[tensorIndex + 1] = green / 255; + tensorData[tensorIndex + 2] = blue / 255; + } + } + + return { + tensor: new window.ort.Tensor(dataType, tensorData, tensorDimensions), + preprocess: { + profile, + inputWidth: width, + inputHeight: height, + resizeRatio, + sourceWidth: videoPreview.videoWidth, + sourceHeight: videoPreview.videoHeight, + }, + }; +}; + +const looksLikeBoxes = (tensor) => { + if (!tensor?.dims || !tensor?.data) { + return false; + } + + const dims = tensor.dims.filter((dimension) => Number.isFinite(dimension)); + const values = Array.from(tensor.data ?? []); + const lastDimension = dims[dims.length - 1]; + return values.length >= 4 && (lastDimension === 4 || lastDimension === 6 || lastDimension === 7); +}; + +const flattenFinite = (tensor) => { + return Array.from(tensor?.data ?? 
[]).map(Number).filter((value) => Number.isFinite(value)); +}; + +const normalizeBox = (boxValues, format = "xyxy") => { + if (boxValues.length < 4) { + return null; + } + + let x1; + let y1; + let x2; + let y2; + if (format === "cxcywh") { + const [centerX, centerY, width, height] = boxValues; + x1 = centerX - width / 2; + y1 = centerY - height / 2; + x2 = centerX + width / 2; + y2 = centerY + height / 2; + } else { + [x1, y1, x2, y2] = boxValues; + } + + if (x2 < x1) { + [x1, x2] = [x2, x1]; + } + if (y2 < y1) { + [y1, y2] = [y2, y1]; + } + + const normalized = [x1, y1, x2, y2].map((value) => ( + value > 1.5 ? value : Math.max(0, Math.min(1, value)) + )); + + return normalized; +}; + +const clamp = (value, min, max) => Math.max(min, Math.min(max, value)); + +const buildRetinaFacePriors = (imageHeight, imageWidth) => { + const priors = []; + RETINAFACE_STEPS.forEach((step, index) => { + const featureMapHeight = Math.ceil(imageHeight / step); + const featureMapWidth = Math.ceil(imageWidth / step); + const minSizes = RETINAFACE_MIN_SIZES[index]; + + for (let row = 0; row < featureMapHeight; row += 1) { + for (let column = 0; column < featureMapWidth; column += 1) { + minSizes.forEach((minSize) => { + priors.push([ + ((column + 0.5) * step) / imageWidth, + ((row + 0.5) * step) / imageHeight, + minSize / imageWidth, + minSize / imageHeight, + ]); + }); + } + } + }); + return priors; +}; + +const decodeRetinaFaceBox = (loc, prior) => { + const centerX = prior[0] + loc[0] * RETINAFACE_VARIANCES[0] * prior[2]; + const centerY = prior[1] + loc[1] * RETINAFACE_VARIANCES[0] * prior[3]; + const width = prior[2] * Math.exp(loc[2] * RETINAFACE_VARIANCES[1]); + const height = prior[3] * Math.exp(loc[3] * RETINAFACE_VARIANCES[1]); + return [ + centerX - width / 2, + centerY - height / 2, + centerX + width / 2, + centerY + height / 2, + ]; +}; + +const computeIoU = (left, right) => { + const x1 = Math.max(left.box[0], right.box[0]); + const y1 = Math.max(left.box[1], 
right.box[1]); + const x2 = Math.min(left.box[2], right.box[2]); + const y2 = Math.min(left.box[3], right.box[3]); + const width = Math.max(0, x2 - x1 + 1); + const height = Math.max(0, y2 - y1 + 1); + const intersection = width * height; + const leftArea = Math.max(0, left.box[2] - left.box[0] + 1) * Math.max(0, left.box[3] - left.box[1] + 1); + const rightArea = Math.max(0, right.box[2] - right.box[0] + 1) * Math.max(0, right.box[3] - right.box[1] + 1); + return intersection / Math.max(1e-6, leftArea + rightArea - intersection); +}; + +const applyNms = (detections, threshold) => { + const sorted = [...detections].sort((left, right) => right.score - left.score); + const kept = []; + + sorted.forEach((candidate) => { + if (kept.every((accepted) => computeIoU(candidate, accepted) <= threshold)) { + kept.push(candidate); + } + }); + + return kept; +}; + +const decodeRetinaFaceOutputs = (outputs, preprocess) => { + if (!preprocess || preprocess.profile !== "retinaface") { + return null; + } + + const outputNames = Array.isArray(videoCvSession?.outputNames) ? videoCvSession.outputNames : []; + if (outputNames.length < 3) { + return null; + } + + const locTensor = outputs[outputNames[0]]; + const confTensor = outputs[outputNames[1]]; + const landmTensor = outputs[outputNames[2]]; + if (!locTensor || !confTensor || !landmTensor) { + return null; + } + + const locValues = flattenFinite(locTensor); + const confValues = flattenFinite(confTensor); + const landmValues = flattenFinite(landmTensor); + const priorCount = locValues.length / 4; + if (priorCount <= 0 || confValues.length / 2 !== priorCount || landmValues.length / 10 !== priorCount) { + return null; + } + + const priors = buildRetinaFacePriors(preprocess.inputHeight, preprocess.inputWidth); + if (priors.length !== priorCount) { + return null; + } + + const detections = []; + for (let index = 0; index < priorCount; index += 1) { + const score = softmax(confValues.slice(index * 2, index * 2 + 2))[1] ?? 
0; + if (score < RETINAFACE_CONFIDENCE_THRESHOLD) { + continue; + } + + const decoded = decodeRetinaFaceBox( + locValues.slice(index * 4, index * 4 + 4), + priors[index], + ); + const scaledBox = [ + clamp((decoded[0] * preprocess.inputWidth) / preprocess.resizeRatio, 0, preprocess.sourceWidth), + clamp((decoded[1] * preprocess.inputHeight) / preprocess.resizeRatio, 0, preprocess.sourceHeight), + clamp((decoded[2] * preprocess.inputWidth) / preprocess.resizeRatio, 0, preprocess.sourceWidth), + clamp((decoded[3] * preprocess.inputHeight) / preprocess.resizeRatio, 0, preprocess.sourceHeight), + ]; + + detections.push({ + label: "face", + class_index: 0, + score, + box: scaledBox, + }); + } + + const filtered = applyNms(detections, RETINAFACE_NMS_THRESHOLD); + if (!filtered.length) { + return { + mode: "detection", + detections: [], + detected_class: "no_detection", + class_index: -1, + confidence: 0, + probabilities: [], + top_classes: [], + }; + } + + const best = filtered[0]; + return { + mode: "detection", + detections: filtered, + detected_class: best.label, + class_index: best.class_index, + confidence: best.score, + probabilities: filtered.map((entry) => entry.score), + top_classes: filtered.slice(0, 3).map((entry) => ({ + label: entry.label, + index: entry.class_index, + probability: entry.score, + })), + }; +}; + +const findDetectionTensor = (entries, patterns, predicate = () => true) => { + return entries.find(([name, tensor]) => { + const normalizedName = String(name).toLowerCase(); + return patterns.some((pattern) => pattern.test(normalizedName)) && predicate(tensor); + }) ?? null; +}; + +const decodeHuggingFaceDetectionOutputs = (entries) => { + const boxesEntry = findDetectionTensor( + entries, + [/pred_boxes/, /boxes?/, /bbox/], + (tensor) => (Array.isArray(tensor?.dims) ? 
tensor.dims[tensor.dims.length - 1] : null) === 4, + ); + const logitsEntry = findDetectionTensor( + entries, + [/logits/, /scores?/, /class/], + (tensor) => (Array.isArray(tensor?.dims) ? tensor.dims[tensor.dims.length - 1] : 0) > 1, + ); + + if (!boxesEntry || !logitsEntry) { + return null; + } + + const [boxesName, boxesTensor] = boxesEntry; + const [, logitsTensor] = logitsEntry; + const rawBoxes = flattenFinite(boxesTensor); + const rawLogits = flattenFinite(logitsTensor); + const boxCount = Math.floor(rawBoxes.length / 4); + const classCount = boxCount > 0 ? Math.floor(rawLogits.length / boxCount) : 0; + if (boxCount <= 0 || classCount <= 1) { + return null; + } + + const usesCenterBoxes = /pred_boxes/.test(String(boxesName).toLowerCase()); + const detections = []; + for (let index = 0; index < boxCount; index += 1) { + const box = rawBoxes.slice(index * 4, index * 4 + 4); + const logits = rawLogits.slice(index * classCount, index * classCount + classCount); + const candidateLogits = logits.length > 1 ? logits.slice(0, -1) : logits; + const probabilities = softmax(candidateLogits); + const best = getTopK(probabilities, 1)[0]; + if (!best || best.value < VIDEO_RENDER_SCORE_THRESHOLD) { + continue; + } + + const normalizedBox = normalizeBox(box, usesCenterBoxes ? 
"cxcywh" : "xyxy"); + if (!normalizedBox) { + continue; + } + + detections.push({ + label: `class_${best.index}`, + class_index: best.index, + score: best.value, + box: normalizedBox, + }); + } + + if (!detections.length) { + return { + mode: "detection", + detections: [], + detected_class: "no_detection", + class_index: -1, + confidence: 0, + probabilities: [], + top_classes: [], + }; + } + + detections.sort((left, right) => right.score - left.score); + const best = detections[0]; + return { + mode: "detection", + detections, + detected_class: best.label, + class_index: best.class_index, + confidence: best.score, + probabilities: detections.map((entry) => entry.score), + top_classes: detections.slice(0, 3).map((entry) => ({ + label: entry.label, + index: entry.class_index, + probability: entry.score, + })), + }; +}; + +const decodeDetectionOutputs = (outputs) => { + const entries = Object.entries(outputs); + const huggingFaceSummary = decodeHuggingFaceDetectionOutputs(entries); + if (huggingFaceSummary) { + return huggingFaceSummary; + } + + const boxesEntry = entries.find(([, tensor]) => looksLikeBoxes(tensor)); + + if (!boxesEntry) { + return null; + } + + const [boxesName, boxesTensor] = boxesEntry; + const boxDims = Array.isArray(boxesTensor.dims) ? boxesTensor.dims : []; + const rawBoxes = flattenFinite(boxesTensor); + const boxWidth = boxDims[boxDims.length - 1] ?? 4; + const detectionCount = Math.floor(rawBoxes.length / boxWidth); + if (detectionCount <= 0) { + return null; + } + + const scoresEntry = entries.find(([name, tensor]) => + name !== boxesName && flattenFinite(tensor).length >= detectionCount + ); + const classEntry = entries.find(([name, tensor]) => + name !== boxesName && name !== scoresEntry?.[0] && flattenFinite(tensor).length >= detectionCount + ); + const detections = []; + const scoreValues = scoresEntry ? flattenFinite(scoresEntry[1]) : []; + const classValues = classEntry ? 
flattenFinite(classEntry[1]) : []; + + for (let index = 0; index < detectionCount; index += 1) { + const start = index * boxWidth; + const row = rawBoxes.slice(start, start + boxWidth); + const normalizedBox = normalizeBox(row); + if (!normalizedBox) { + continue; + } + + let score = Number(scoreValues[index] ?? row[4] ?? row[5] ?? 1); + if (!Number.isFinite(score)) { + score = 1; + } + + let classIndex = classValues[index]; + if (!Number.isFinite(classIndex)) { + classIndex = row.length >= 6 ? row[5] : row.length >= 7 ? row[6] : index; + } + + if (score < VIDEO_RENDER_SCORE_THRESHOLD) { + continue; + } + + detections.push({ + label: `class_${Math.round(classIndex)}`, + class_index: Math.round(classIndex), + score, + box: normalizedBox, + }); + } + + if (!detections.length) { + return { + mode: "detection", + detections: [], + detected_class: "no_detection", + class_index: -1, + confidence: 0, + probabilities: [], + top_classes: [], + }; + } + + detections.sort((left, right) => right.score - left.score); + const best = detections[0]; + return { + mode: "detection", + detections, + detected_class: best.label, + class_index: best.class_index, + confidence: best.score, + probabilities: detections.map((entry) => entry.score), + top_classes: detections.slice(0, 3).map((entry) => ({ + label: entry.label, + index: entry.class_index, + probability: entry.score, + })), + }; +}; + +const decodeClassificationOutputs = (output) => { + const values = Array.from(output?.data ?? 
[]); + if (values.length === 0) { + throw new Error("Video model returned an empty output tensor."); + } + + if (values.length === 1) { + return { + mode: "classification", + detections: [], + detected_class: "scalar_output", + class_index: 0, + confidence: Number(values[0]), + probabilities: values, + top_classes: [{ label: "scalar_output", index: 0, probability: Number(values[0]) }], + }; + } + + const probabilities = softmax(values); + const ranked = getTopK(probabilities, 3); + const best = ranked[0]; + + return { + mode: "classification", + detections: [], + detected_class: `class_${best.index}`, + class_index: best.index, + confidence: best.value, + probabilities, + top_classes: ranked.map(({ index, value }) => ({ + label: `class_${index}`, + index, + probability: value, + logit: values[index], + })), + }; +}; + +const summarizeVideoOutput = (outputMap, preprocess = null) => { + const retinaFaceSummary = decodeRetinaFaceOutputs(outputMap, preprocess); + if (retinaFaceSummary) { + return retinaFaceSummary; + } + + const detectionSummary = decodeDetectionOutputs(outputMap); + if (detectionSummary) { + return detectionSummary; + } + + const primaryOutput = outputMap[videoCvOutputName]; + const primaryValues = Array.from(primaryOutput?.data ?? 
[]); + if (primaryValues.length > 0 && primaryValues.length <= 4096) { + return decodeClassificationOutputs(primaryOutput); + } + + return { + mode: "passthrough", + detections: [], + detected_class: "unrecognized_output", + class_index: -1, + confidence: 0, + probabilities: [], + top_classes: [], + }; +}; + +const drawOverlayText = (context, lines) => { + if (!lines.length) { + return; + } + + context.font = "18px ui-monospace, monospace"; + const lineHeight = 24; + const width = Math.max(...lines.map((line) => context.measureText(line).width), 0) + 20; + const height = lines.length * lineHeight + 12; + context.fillStyle = "rgba(24, 32, 40, 0.72)"; + context.fillRect(12, 12, width, height); + context.fillStyle = "#fffdfa"; + lines.forEach((line, index) => { + context.fillText(line, 22, 36 + index * lineHeight); + }); +}; + +const renderVideoOutputFrame = () => { + videoRenderFrameId = null; + + if (!videoOutputVisible || !videoCapture || !videoPreview.videoWidth || !videoPreview.videoHeight) { + return; + } + + const context = ensureVideoOverlayContext(); + const width = videoPreview.videoWidth; + const height = videoPreview.videoHeight; + if (videoOutputCanvas.width !== width || videoOutputCanvas.height !== height) { + videoOutputCanvas.width = width; + videoOutputCanvas.height = height; + } + + context.drawImage(videoPreview, 0, 0, width, height); + + if (lastVideoInferenceSummary?.mode === "detection") { + context.lineWidth = 3; + context.font = "16px ui-monospace, monospace"; + lastVideoInferenceSummary.detections.forEach((entry) => { + const [x1, y1, x2, y2] = entry.box; + const left = x1 <= 1 ? x1 * width : x1; + const top = y1 <= 1 ? y1 * height : y1; + const right = x2 <= 1 ? x2 * width : x2; + const bottom = y2 <= 1 ? 
y2 * height : y2; + const boxWidth = Math.max(1, right - left); + const boxHeight = Math.max(1, bottom - top); + + context.strokeStyle = "#ef8f35"; + context.strokeRect(left, top, boxWidth, boxHeight); + + const label = `${entry.label} ${(entry.score * 100).toFixed(1)}%`; + const textWidth = context.measureText(label).width + 10; + context.fillStyle = "#182028"; + context.fillRect(left, Math.max(0, top - 24), textWidth, 22); + context.fillStyle = "#fffdfa"; + context.fillText(label, left + 5, Math.max(16, top - 8)); + }); + } else if (lastVideoInferenceSummary?.mode === "classification") { + drawOverlayText(context, [ + `classification: ${lastVideoInferenceSummary.detected_class}`, + `confidence: ${(lastVideoInferenceSummary.confidence * 100).toFixed(1)}%`, + ]); + } else if (lastVideoInferenceSummary?.mode === "passthrough") { + drawOverlayText(context, [ + "output mode: passthrough", + "model output not recognized as detection or classification", + ]); + } + + videoRenderFrameId = window.requestAnimationFrame(renderVideoOutputFrame); +}; + +const syncVideoOutputView = () => { + videoOutputCanvas.hidden = !videoOutputVisible || !videoCapture; + videoOutputButton.textContent = videoOutputVisible ? 
"Hide video output" : "Show video output"; + + if (!videoOutputVisible || !videoCapture) { + if (videoRenderFrameId !== null) { + window.cancelAnimationFrame(videoRenderFrameId); + videoRenderFrameId = null; + } + updateVideoStatus(); + return; + } + + if (videoRenderFrameId === null) { + videoRenderFrameId = window.requestAnimationFrame(renderVideoOutputFrame); + } + updateVideoStatus(); +}; + +const stopVideoCvLoop = () => { + if (videoCvLoopId !== null) { + window.clearInterval(videoCvLoopId); + videoCvLoopId = null; + } + lastVideoCvLabel = null; + updateVideoStatus(); +}; + +const inferVideoPrediction = async () => { + if ( + !videoCapture + || !videoCvSession + || !videoCvInputName + || !videoCvOutputName + || videoCvInferencePending + ) { + return; + } + + const now = Date.now(); + if (now - lastVideoInferenceAt < VIDEO_INFERENCE_INTERVAL_MS) { + return; + } + + videoCvInferencePending = true; + lastVideoInferenceAt = now; + + try { + const { tensor: input, preprocess } = buildVideoInputTensor(); + const outputMap = await videoCvSession.run({ [videoCvInputName]: input }); + const output = outputMap[videoCvOutputName]; + const summary = summarizeVideoOutput(outputMap, preprocess); + const labelChanged = summary.detected_class !== lastVideoCvLabel; + lastVideoCvLabel = summary.detected_class; + lastVideoInferenceSummary = summary; + + updateVideoStatus([ + `output mode: ${summary.mode}`, + `prediction: ${summary.detected_class}`, + `confidence: ${summary.confidence.toFixed(4)}`, + ...( + summary.mode === "detection" + ? [ + `detections: ${summary.detections.length}`, + ...summary.detections.slice(0, 3).map( + (entry) => + `${entry.label}: score=${entry.score.toFixed(4)} box=${ + entry.box.map((value) => value.toFixed(3)).join(",") + }`, + ), + ] + : [ + "top classes:", + ...summary.top_classes.map( + (entry) => + `${entry.label}: p=${entry.probability.toFixed(4)} logit=${ + Number(entry.logit ?? 
entry.probability).toFixed(4) + }`, + ), + ] + ), + `frame: ${videoPreview.videoWidth}x${videoPreview.videoHeight}`, + `processed at: ${new Date().toLocaleTimeString()}`, + ]); + syncVideoOutputView(); + + sendClientEvent("video_cv", "inference", { + mode: summary.mode, + detected_class: summary.detected_class, + class_index: summary.class_index, + confidence: summary.confidence, + probabilities: summary.probabilities, + top_classes: summary.top_classes, + detections: summary.detections, + changed: labelChanged, + processed_at: new Date().toISOString(), + model_path: VIDEO_MODEL_PATH, + input_name: videoCvInputName, + output_name: videoCvOutputName, + input_dimensions: videoCvSession.inputMetadata?.[videoCvInputName]?.dimensions ?? [], + output_dimensions: Array.isArray(output?.dims) ? output.dims : [], + source_resolution: { + width: videoPreview.videoWidth, + height: videoPreview.videoHeight, + }, + }); + } catch (error) { + lastVideoInferenceSummary = { + mode: "passthrough", + detections: [], + detected_class: "inference_error", + class_index: -1, + confidence: 0, + probabilities: [], + top_classes: [], + }; + updateVideoStatus([ + `inference error: ${error instanceof Error ? error.message : String(error)}`, + ]); + console.error(error); + } finally { + videoCvInferencePending = false; + } +}; + +const syncVideoCvLoop = () => { + if (videoCapture && videoCvSession) { + if (videoCvLoopId === null) { + videoCvLoopId = window.setInterval(() => { + void inferVideoPrediction(); + }, VIDEO_INFERENCE_INTERVAL_MS); + } + updateVideoStatus([ + "browser-side webcam inference active", + "results are sent to the backend over the websocket.", + ]); + return; + } + + stopVideoCvLoop(); + lastVideoInferenceSummary = null; + updateVideoStatus([ + videoCvSession + ? "model loaded; start video capture to begin inference." 
+ : `model file: ${VIDEO_MODEL_PATH}`, + ]); +}; + renderSensorOutput(); updateHarStatus([ "local-only inference path", "model file: /static/models/human_activity_recognition.onnx", ]); +updateVideoStatus([ + `model file: ${VIDEO_MODEL_PATH}`, + "load the model, then start video capture to process frames in-browser.", +]); harExportButton.addEventListener("click", () => { try { @@ -570,6 +1594,8 @@ try { videoPreview.hidden = true; videoButton.textContent = "Start video"; delete window.videoCapture; + syncVideoCvLoop(); + syncVideoOutputView(); append("video stopped"); sendClientEvent("video", "stopped", { track_count: 0 }); return; @@ -581,6 +1607,8 @@ try { videoButton.textContent = "Stop video"; append(`video granted: ${videoCapture.trackCount()} video track(s)`); window.videoCapture = videoCapture; + syncVideoCvLoop(); + syncVideoOutputView(); sendClientEvent("video", "started", { track_count: videoCapture.trackCount(), }); @@ -882,6 +1910,57 @@ try { } }); + videoModelButton.addEventListener("click", async () => { + try { + if (!window.ort) { + throw new Error("onnxruntime-web did not load."); + } + + configureOnnxRuntimeWasm(); + + videoModelButton.disabled = true; + videoModelButton.textContent = "Loading video model..."; + updateVideoStatus(["loading model..."]); + + videoCvSession = await window.ort.InferenceSession.create( + VIDEO_MODEL_PATH, + { + executionProviders: ["wasm"], + }, + ); + + videoCvInputName = selectVideoModelInputName(videoCvSession); + videoCvOutputName = selectVideoModelOutputName(videoCvSession); + lastVideoCvLabel = null; + lastVideoInferenceSummary = null; + append( + `video cv model loaded: input=${videoCvInputName} output=${videoCvOutputName} input_dims=${ + JSON.stringify(videoCvSession.inputMetadata?.[videoCvInputName]?.dimensions ?? 
[]) + }`, + ); + syncVideoCvLoop(); + } catch (error) { + videoCvSession = null; + videoCvInputName = null; + videoCvOutputName = null; + stopVideoCvLoop(); + lastVideoInferenceSummary = null; + updateVideoStatus([ + `model load error: ${error instanceof Error ? error.message : String(error)}`, + ]); + append(`video cv error: ${error instanceof Error ? error.message : String(error)}`); + console.error(error); + } finally { + videoModelButton.disabled = false; + videoModelButton.textContent = videoCvSession ? "Reload video CV model" : "Load video CV model"; + } + }); + + videoOutputButton.addEventListener("click", () => { + videoOutputVisible = !videoOutputVisible; + syncVideoOutputView(); + }); + window.client = client; window.sendAlive = () => client.send_alive(); } catch (error) { diff --git a/services/ws-server/static/index.html b/services/ws-server/static/index.html index 630db76..2ccf512 100644 --- a/services/ws-server/static/index.html +++ b/services/ws-server/static/index.html @@ -138,14 +138,23 @@

WASM web agent

+ +

+
+
Booting…
From e08c07163de1b8482d43b9e815903138de441dce Mon Sep 17 00:00:00 2001 From: Pierre Tenedero Date: Thu, 9 Apr 2026 17:08:55 +0800 Subject: [PATCH 2/4] Create loadable face detection module --- .mise.toml | 7 +- Cargo.toml | 8 +- README.md | 8 +- services/ws-modules/face-detection/Cargo.toml | 31 + services/ws-modules/face-detection/src/lib.rs | 849 ++++++++++++++++++ services/ws-server/static/app.js | 68 +- services/ws-server/static/index.html | 8 +- 7 files changed, 935 insertions(+), 44 deletions(-) create mode 100644 services/ws-modules/face-detection/Cargo.toml create mode 100644 services/ws-modules/face-detection/src/lib.rs diff --git a/.mise.toml b/.mise.toml index 586e2c3..6d21ef5 100644 --- a/.mise.toml +++ b/.mise.toml @@ -90,8 +90,13 @@ description = "Build the har1 workflow WASM module" dir = "services/ws-modules/har1" run = "wasm-pack build . --target web" +[tasks.build-ws-face-detection-module] +description = "Build the face detection workflow WASM module" +dir = "services/ws-modules/face-detection" +run = "wasm-pack build . 
--target web" + [tasks.build] -depends = ["build-ws-har1-module", "build-ws-wasm-agent"] +depends = ["build-ws-face-detection-module", "build-ws-har1-module", "build-ws-wasm-agent"] description = "Build all WebAssembly modules" [tasks.test-ws-wasm-agent-firefox] diff --git a/Cargo.toml b/Cargo.toml index 026f1b9..7eeb046 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,11 @@ [workspace] -members = ["libs/edge-toolkit", "services/ws-modules/har1", "services/ws-server", "services/ws-wasm-agent"] +members = [ + "libs/edge-toolkit", + "services/ws-modules/face-detection", + "services/ws-modules/har1", + "services/ws-server", + "services/ws-wasm-agent", +] resolver = "2" [workspace.dependencies] diff --git a/README.md b/README.md index e52d4fd..c3f5ce7 100644 --- a/README.md +++ b/README.md @@ -26,14 +26,16 @@ and save it as `services/ws-server/static/models/human_activity_recognition.onnx ### Face detection setup -Download the onnx from https://huggingface.co/amd/retinaface and save it in -`services/ws-server/static/models/` and rename the file to `video_cv.onnx`. +1. Download RetinaFace_int.onnx from https://huggingface.co/amd/retinaface/tree/main/weights +2. Save it in `services/ws-server/static/models/` +3. Rename the file to `video_cv.onnx`. ### Build and run the agent ```bash mise run build-ws-wasm-agent mise run build-ws-har1-module +mise run build-ws-face-detection-module mise run ws-server ``` @@ -46,7 +48,7 @@ Then on your phone, open Chrome and type in https://192.168.1.x:8433/ Click "har demo". -For webcam inference, click "Load video CV model" and then "Start video". +For webcam inference, click "face demo". 
## Grant diff --git a/services/ws-modules/face-detection/Cargo.toml b/services/ws-modules/face-detection/Cargo.toml new file mode 100644 index 0000000..59725b4 --- /dev/null +++ b/services/ws-modules/face-detection/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "et-ws-face-detection" +version = "0.1.0" +edition = "2024" + +[lib] +crate-type = ["cdylib", "rlib"] + +[dependencies] +et-ws-wasm-agent = { path = "../../ws-wasm-agent" } +js-sys = "0.3" +serde.workspace = true +serde_json.workspace = true +tracing.workspace = true +tracing-wasm = "0.2" +wasm-bindgen = "0.2" +wasm-bindgen-futures = "0.4" +web-sys = { version = "0.3", features = [ + "BinaryType", + "Document", + "Event", + "EventTarget", + "MessageEvent", + "Storage", + "WebSocket", + "Window", + "console", +] } + +[dev-dependencies] +wasm-bindgen-test = "0.3" diff --git a/services/ws-modules/face-detection/src/lib.rs b/services/ws-modules/face-detection/src/lib.rs new file mode 100644 index 0000000..980d784 --- /dev/null +++ b/services/ws-modules/face-detection/src/lib.rs @@ -0,0 +1,849 @@ +use std::cell::{Cell, RefCell}; +use std::rc::Rc; + +use et_ws_wasm_agent::{VideoCapture, WsClient, WsClientConfig}; +use js_sys::{Array, Float32Array, Function, Promise, Reflect}; +use serde_json::json; +use tracing::info; +use wasm_bindgen::JsCast; +use wasm_bindgen::prelude::*; +use wasm_bindgen_futures::{JsFuture, spawn_local}; + +const FACE_MODEL_PATH: &str = "/static/models/video_cv.onnx"; +const FACE_INPUT_WIDTH: usize = 640; +const FACE_INPUT_HEIGHT: usize = 608; +const FACE_INPUT_WIDTH_F64: f64 = FACE_INPUT_WIDTH as f64; +const FACE_INPUT_HEIGHT_F64: f64 = FACE_INPUT_HEIGHT as f64; +const FACE_INFERENCE_INTERVAL_MS: i32 = 750; +const FACE_RENDER_INTERVAL_MS: i32 = 60; +const RETINAFACE_CONFIDENCE_THRESHOLD: f64 = 0.75; +const RETINAFACE_NMS_THRESHOLD: f64 = 0.4; +const RETINAFACE_VARIANCES: [f64; 2] = [0.1, 0.2]; +const RETINAFACE_MIN_SIZES: [&[f64]; 3] = [&[16.0, 32.0], &[64.0, 128.0], &[256.0, 512.0]]; +const 
RETINAFACE_STEPS: [f64; 3] = [8.0, 16.0, 32.0]; + +#[wasm_bindgen(inline_js = r##" +export async function face_attach_stream(stream) { + const video = document.getElementById("face-video-preview"); + if (!video) { + throw new Error("Missing #face-video-preview element"); + } + + video.srcObject = stream; + video.hidden = false; + + if (!video.videoWidth || !video.videoHeight) { + await new Promise((resolve, reject) => { + const onLoaded = () => { + cleanup(); + resolve(); + }; + const onError = () => { + cleanup(); + reject(new Error("Video stream metadata did not load")); + }; + const cleanup = () => { + video.removeEventListener("loadedmetadata", onLoaded); + video.removeEventListener("error", onError); + }; + video.addEventListener("loadedmetadata", onLoaded, { once: true }); + video.addEventListener("error", onError, { once: true }); + }); + } + + const playResult = video.play?.(); + if (playResult?.catch) { + try { + await playResult; + } catch { + // Browsers may reject autoplay even after a gesture; metadata is enough for capture. + } + } +} + +export function face_detach_stream() { + const video = document.getElementById("face-video-preview"); + const canvas = document.getElementById("face-video-output-canvas"); + if (video) { + video.pause?.(); + video.srcObject = null; + video.hidden = true; + } + if (canvas) { + canvas.hidden = true; + const context = canvas.getContext("2d"); + context?.clearRect(0, 0, canvas.width, canvas.height); + } +} + +export function face_set_status(message) { + const output = document.getElementById("face-output"); + if (output) { + output.value = String(message); + } +} + +export function face_log(message) { + const line = `[face-detection] ${message}`; + console.log(line); + const logEl = document.getElementById("log"); + if (!logEl) { + return; + } + const current = logEl.textContent ?? ""; + logEl.textContent = current ? 
`${current}\n${line}` : line; +} + +export function face_capture_input_tensor() { + const video = document.getElementById("face-video-preview"); + if (!video?.videoWidth || !video?.videoHeight) { + throw new Error("Video stream is not ready yet."); + } + + const width = 640; + const height = 608; + const mean = [104, 117, 123]; + const canvas = globalThis.__etFacePreprocessCanvas ?? document.createElement("canvas"); + globalThis.__etFacePreprocessCanvas = canvas; + const context = canvas.getContext("2d", { willReadFrequently: true }); + if (!context) { + throw new Error("Unable to create face preprocessing canvas context."); + } + + canvas.width = width; + canvas.height = height; + + const sourceWidth = video.videoWidth; + const sourceHeight = video.videoHeight; + const targetRatio = height / width; + let resizeRatio; + if (sourceHeight / sourceWidth <= targetRatio) { + resizeRatio = width / sourceWidth; + } else { + resizeRatio = height / sourceHeight; + } + + const resizedWidth = Math.max(1, Math.min(width, Math.round(sourceWidth * resizeRatio))); + const resizedHeight = Math.max(1, Math.min(height, Math.round(sourceHeight * resizeRatio))); + context.clearRect(0, 0, width, height); + context.drawImage(video, 0, 0, resizedWidth, resizedHeight); + + const rgba = context.getImageData(0, 0, width, height).data; + const tensorData = new Float32Array(width * height * 3); + + for (let pixelIndex = 0; pixelIndex < width * height; pixelIndex += 1) { + const rgbaIndex = pixelIndex * 4; + const red = rgba[rgbaIndex]; + const green = rgba[rgbaIndex + 1]; + const blue = rgba[rgbaIndex + 2]; + const tensorIndex = pixelIndex * 3; + tensorData[tensorIndex] = blue - mean[0]; + tensorData[tensorIndex + 1] = green - mean[1]; + tensorData[tensorIndex + 2] = red - mean[2]; + } + + return { + data: tensorData, + resizeRatio, + sourceWidth, + sourceHeight, + }; +} + +export function face_render(detections) { + const video = document.getElementById("face-video-preview"); + const canvas 
= document.getElementById("face-video-output-canvas"); + if (!video?.videoWidth || !video?.videoHeight || !canvas) { + return; + } + + const context = canvas.getContext("2d"); + if (!context) { + throw new Error("Unable to create face output canvas context."); + } + + const width = video.videoWidth; + const height = video.videoHeight; + if (canvas.width !== width || canvas.height !== height) { + canvas.width = width; + canvas.height = height; + } + + canvas.hidden = false; + context.drawImage(video, 0, 0, width, height); + context.lineWidth = 3; + context.font = "16px ui-monospace, monospace"; + + for (const entry of detections ?? []) { + const [x1, y1, x2, y2] = entry.box ?? []; + const left = Number(x1 ?? 0); + const top = Number(y1 ?? 0); + const right = Number(x2 ?? 0); + const bottom = Number(y2 ?? 0); + const boxWidth = Math.max(1, right - left); + const boxHeight = Math.max(1, bottom - top); + context.strokeStyle = "#ef8f35"; + context.strokeRect(left, top, boxWidth, boxHeight); + + const label = `${entry.label ?? "face"} ${((entry.score ?? 
0) * 100).toFixed(1)}%`; + const textWidth = context.measureText(label).width + 10; + context.fillStyle = "#182028"; + context.fillRect(left, Math.max(0, top - 24), textWidth, 22); + context.fillStyle = "#fffdfa"; + context.fillText(label, left + 5, Math.max(16, top - 8)); + } +} +"##)] +extern "C" { + #[wasm_bindgen(catch)] + async fn face_attach_stream(stream: JsValue) -> Result; + #[wasm_bindgen] + fn face_detach_stream(); + #[wasm_bindgen] + fn face_set_status(message: &str); + #[wasm_bindgen] + fn face_log(message: &str); + #[wasm_bindgen(catch)] + fn face_capture_input_tensor() -> Result; + #[wasm_bindgen(catch)] + fn face_render(detections: &JsValue) -> Result<(), JsValue>; +} + +#[derive(Clone)] +struct Detection { + label: String, + class_index: i32, + score: f64, + box_coords: [f64; 4], +} + +#[derive(Clone)] +struct DetectionSummary { + detections: Vec, + confidence: f64, + processed_at: String, +} + +struct FaceDetectionRuntime { + client: WsClient, + capture: VideoCapture, + inference_interval_id: i32, + render_interval_id: i32, + _inference_closure: Closure, + _render_closure: Closure, +} + +thread_local! 
{ + static FACE_RUNTIME: RefCell> = const { RefCell::new(None) }; +} + +#[wasm_bindgen(start)] +pub fn init() { + tracing_wasm::set_as_global_default(); + info!("face detection workflow module initialized"); +} + +#[wasm_bindgen] +pub fn is_running() -> bool { + FACE_RUNTIME.with(|runtime| runtime.borrow().is_some()) +} + +#[wasm_bindgen] +pub async fn start() -> Result<(), JsValue> { + if is_running() { + return Ok(()); + } + + face_set_status("face detection: starting"); + log(&format!("loading RetinaFace model from {FACE_MODEL_PATH}"))?; + + let ws_url = websocket_url()?; + let mut client = WsClient::new(WsClientConfig::new(ws_url.clone())); + client.connect()?; + wait_for_connected(&client).await?; + log(&format!("websocket connected with agent_id={}", client.get_client_id()))?; + + let capture = match VideoCapture::request().await { + Ok(capture) => capture, + Err(error) => { + client.disconnect(); + return Err(error); + } + }; + + if let Err(error) = face_attach_stream(capture.raw_stream()).await { + capture.stop(); + client.disconnect(); + return Err(error); + } + + let session = match create_face_session(FACE_MODEL_PATH).await { + Ok(session) => session, + Err(error) => { + capture.stop(); + face_detach_stream(); + client.disconnect(); + return Err(error); + } + }; + + let input_name = first_string_entry(&session, "inputNames")?; + let output_names = string_entries(&session, "outputNames")?; + if output_names.len() < 3 { + capture.stop(); + face_detach_stream(); + client.disconnect(); + return Err(JsValue::from_str( + "RetinaFace session did not expose the expected outputs", + )); + } + + let last_summary: Rc>> = Rc::new(RefCell::new(None)); + let inference_pending = Rc::new(Cell::new(false)); + let last_has_detection = Rc::new(Cell::new(false)); + + let inference_session = session.clone(); + let inference_input_name = input_name.clone(); + let inference_output_names = output_names.clone(); + let inference_client = client.clone(); + let 
inference_last_summary = last_summary.clone(); + let inference_pending_flag = inference_pending.clone(); + let inference_last_has_detection = last_has_detection.clone(); + let inference_closure = Closure::wrap(Box::new(move || { + if inference_pending_flag.get() { + return; + } + + inference_pending_flag.set(true); + let session = inference_session.clone(); + let input_name = inference_input_name.clone(); + let output_names = inference_output_names.clone(); + let client = inference_client.clone(); + let last_summary = inference_last_summary.clone(); + let pending_flag = inference_pending_flag.clone(); + let last_has_detection = inference_last_has_detection.clone(); + + spawn_local(async move { + let outcome = infer_once(&session, &input_name, &output_names, &client, &last_has_detection).await; + + match outcome { + Ok(summary) => { + update_face_status(&input_name, &output_names, &summary); + *last_summary.borrow_mut() = Some(summary); + } + Err(error) => { + let message = describe_js_error(&error); + face_set_status(&format!("face detection: inference error\n{message}")); + let _ = log(&format!("inference error: {message}")); + } + } + + pending_flag.set(false); + }); + }) as Box); + + let render_last_summary = last_summary.clone(); + let render_closure = Closure::wrap(Box::new(move || { + let detections = render_last_summary + .borrow() + .as_ref() + .map(|summary| detections_to_js(&summary.detections)) + .unwrap_or_else(|| Array::new().into()); + let _ = face_render(&detections); + }) as Box); + + let window = web_sys::window().ok_or_else(|| JsValue::from_str("No window available"))?; + let inference_interval_id = window.set_interval_with_callback_and_timeout_and_arguments_0( + inference_closure.as_ref().unchecked_ref(), + FACE_INFERENCE_INTERVAL_MS, + )?; + let render_interval_id = window.set_interval_with_callback_and_timeout_and_arguments_0( + render_closure.as_ref().unchecked_ref(), + FACE_RENDER_INTERVAL_MS, + )?; + + let startup_summary = DetectionSummary 
{ + detections: Vec::new(), + confidence: 0.0, + processed_at: String::from("waiting for first inference"), + }; + update_face_status(&input_name, &output_names, &startup_summary); + log("face detection demo started")?; + + FACE_RUNTIME.with(|runtime| { + *runtime.borrow_mut() = Some(FaceDetectionRuntime { + client, + capture, + inference_interval_id, + render_interval_id, + _inference_closure: inference_closure, + _render_closure: render_closure, + }); + }); + + Ok(()) +} + +#[wasm_bindgen] +pub fn stop() -> Result<(), JsValue> { + FACE_RUNTIME.with(|runtime| { + let Some(mut runtime) = runtime.borrow_mut().take() else { + return Ok(()); + }; + + let window = web_sys::window().ok_or_else(|| JsValue::from_str("No window available"))?; + window.clear_interval_with_handle(runtime.inference_interval_id); + window.clear_interval_with_handle(runtime.render_interval_id); + runtime.capture.stop(); + runtime.client.disconnect(); + face_detach_stream(); + face_set_status("face detection demo stopped."); + log("face detection demo stopped")?; + Ok(()) + }) +} + +async fn infer_once( + session: &JsValue, + input_name: &str, + output_names: &[String], + client: &WsClient, + last_has_detection: &Cell, +) -> Result { + let capture = face_capture_input_tensor()?; + let tensor_data = Reflect::get(&capture, &JsValue::from_str("data"))?; + let resize_ratio = Reflect::get(&capture, &JsValue::from_str("resizeRatio"))? + .as_f64() + .ok_or_else(|| JsValue::from_str("face capture resizeRatio was unavailable"))?; + let source_width = Reflect::get(&capture, &JsValue::from_str("sourceWidth"))? + .as_f64() + .ok_or_else(|| JsValue::from_str("face capture sourceWidth was unavailable"))?; + let source_height = Reflect::get(&capture, &JsValue::from_str("sourceHeight"))? 
+ .as_f64() + .ok_or_else(|| JsValue::from_str("face capture sourceHeight was unavailable"))?; + + let tensor = create_tensor(&Float32Array::new(&tensor_data))?; + let feeds = js_sys::Object::new(); + Reflect::set(&feeds, &JsValue::from_str(input_name), &tensor)?; + + let run_value = method(session, "run")?.call1(session, &feeds)?; + let outputs = JsFuture::from( + run_value + .dyn_into::() + .map_err(|_| JsValue::from_str("InferenceSession.run did not return a Promise"))?, + ) + .await?; + + let summary = decode_retinaface_outputs(&outputs, output_names, resize_ratio, source_width, source_height)?; + let has_detection = !summary.detections.is_empty(); + let changed = last_has_detection.get() != has_detection; + last_has_detection.set(has_detection); + + client.send_client_event( + "face_detection", + "inference", + json!({ + "mode": "detection", + "detected_class": if has_detection { "face" } else { "no_detection" }, + "class_index": if has_detection { 0 } else { -1 }, + "confidence": summary.confidence, + "detections": summary + .detections + .iter() + .map(|entry| json!({ + "label": entry.label, + "class_index": entry.class_index, + "score": entry.score, + "box": entry.box_coords, + })) + .collect::>(), + "changed": changed, + "processed_at": summary.processed_at, + "model_path": FACE_MODEL_PATH, + "input_name": input_name, + "output_names": output_names, + "source_resolution": { + "width": source_width, + "height": source_height, + }, + }), + )?; + + Ok(summary) +} + +fn update_face_status(input_name: &str, output_names: &[String], summary: &DetectionSummary) { + let mut lines = vec![ + String::from("face detection demo"), + format!("model file: {FACE_MODEL_PATH}"), + format!("input: {input_name}"), + format!("outputs: {}", output_names.join(", ")), + format!("detections: {}", summary.detections.len()), + format!("best confidence: {:.4}", summary.confidence), + format!("processed at: {}", summary.processed_at), + ]; + + if let Some(best) = 
summary.detections.first() { + lines.push(String::new()); + lines.push(format!( + "best box: {:.1}, {:.1}, {:.1}, {:.1}", + best.box_coords[0], best.box_coords[1], best.box_coords[2], best.box_coords[3] + )); + } + + face_set_status(&lines.join("\n")); +} + +fn detections_to_js(detections: &[Detection]) -> JsValue { + let array = Array::new(); + + for detection in detections { + let object = js_sys::Object::new(); + let box_values = Array::new(); + for value in detection.box_coords { + box_values.push(&JsValue::from_f64(value)); + } + let _ = Reflect::set( + &object, + &JsValue::from_str("label"), + &JsValue::from_str(&detection.label), + ); + let _ = Reflect::set( + &object, + &JsValue::from_str("score"), + &JsValue::from_f64(detection.score), + ); + let _ = Reflect::set(&object, &JsValue::from_str("box"), &box_values); + array.push(&object); + } + + array.into() +} + +fn decode_retinaface_outputs( + outputs: &JsValue, + output_names: &[String], + resize_ratio: f64, + source_width: f64, + source_height: f64, +) -> Result { + let loc_tensor = Reflect::get(outputs, &JsValue::from_str(&output_names[0]))?; + let conf_tensor = Reflect::get(outputs, &JsValue::from_str(&output_names[1]))?; + let landm_tensor = Reflect::get(outputs, &JsValue::from_str(&output_names[2]))?; + + let loc_values = tensor_f32_values(&loc_tensor)?; + let conf_values = tensor_f32_values(&conf_tensor)?; + let landm_values = tensor_f32_values(&landm_tensor)?; + let prior_count = loc_values.len() / 4; + if prior_count == 0 || conf_values.len() / 2 != prior_count || landm_values.len() / 10 != prior_count { + return Err(JsValue::from_str("RetinaFace outputs had unexpected shapes")); + } + + let priors = build_retinaface_priors(FACE_INPUT_HEIGHT_F64, FACE_INPUT_WIDTH_F64); + if priors.len() != prior_count { + return Err(JsValue::from_str("RetinaFace priors did not match output count")); + } + + let mut detections = Vec::new(); + for index in 0..prior_count { + let score = 
softmax(&[f64::from(conf_values[index * 2]), f64::from(conf_values[index * 2 + 1])])[1]; + if score < RETINAFACE_CONFIDENCE_THRESHOLD { + continue; + } + + let decoded = decode_retinaface_box( + [ + f64::from(loc_values[index * 4]), + f64::from(loc_values[index * 4 + 1]), + f64::from(loc_values[index * 4 + 2]), + f64::from(loc_values[index * 4 + 3]), + ], + priors[index], + ); + let box_coords = [ + clamp((decoded[0] * FACE_INPUT_WIDTH_F64) / resize_ratio, 0.0, source_width), + clamp((decoded[1] * FACE_INPUT_HEIGHT_F64) / resize_ratio, 0.0, source_height), + clamp((decoded[2] * FACE_INPUT_WIDTH_F64) / resize_ratio, 0.0, source_width), + clamp((decoded[3] * FACE_INPUT_HEIGHT_F64) / resize_ratio, 0.0, source_height), + ]; + + detections.push(Detection { + label: String::from("face"), + class_index: 0, + score, + box_coords, + }); + } + + let detections = apply_nms(detections, RETINAFACE_NMS_THRESHOLD); + let confidence = detections.first().map(|entry| entry.score).unwrap_or(0.0); + Ok(DetectionSummary { + detections, + confidence, + processed_at: String::from(js_sys::Date::new_0().to_locale_time_string("en-US")), + }) +} + +fn tensor_f32_values(tensor: &JsValue) -> Result, JsValue> { + let data = Reflect::get(tensor, &JsValue::from_str("data"))?; + Ok(Float32Array::new(&data).to_vec()) +} + +fn build_retinaface_priors(image_height: f64, image_width: f64) -> Vec<[f64; 4]> { + let mut priors = Vec::new(); + + for (index, step) in RETINAFACE_STEPS.into_iter().enumerate() { + let feature_map_height = (image_height / step).ceil() as usize; + let feature_map_width = (image_width / step).ceil() as usize; + let min_sizes = RETINAFACE_MIN_SIZES[index]; + + for row in 0..feature_map_height { + for column in 0..feature_map_width { + for min_size in min_sizes { + priors.push([ + (((column as f64) + 0.5) * step) / image_width, + (((row as f64) + 0.5) * step) / image_height, + min_size / image_width, + min_size / image_height, + ]); + } + } + } + } + + priors +} + +fn 
decode_retinaface_box(loc: [f64; 4], prior: [f64; 4]) -> [f64; 4] { + let center_x = prior[0] + loc[0] * RETINAFACE_VARIANCES[0] * prior[2]; + let center_y = prior[1] + loc[1] * RETINAFACE_VARIANCES[0] * prior[3]; + let width = prior[2] * (loc[2] * RETINAFACE_VARIANCES[1]).exp(); + let height = prior[3] * (loc[3] * RETINAFACE_VARIANCES[1]).exp(); + + [ + center_x - width / 2.0, + center_y - height / 2.0, + center_x + width / 2.0, + center_y + height / 2.0, + ] +} + +fn compute_iou(left: &Detection, right: &Detection) -> f64 { + let x1 = left.box_coords[0].max(right.box_coords[0]); + let y1 = left.box_coords[1].max(right.box_coords[1]); + let x2 = left.box_coords[2].min(right.box_coords[2]); + let y2 = left.box_coords[3].min(right.box_coords[3]); + let width = (x2 - x1 + 1.0).max(0.0); + let height = (y2 - y1 + 1.0).max(0.0); + let intersection = width * height; + let left_area = (left.box_coords[2] - left.box_coords[0] + 1.0).max(0.0) + * (left.box_coords[3] - left.box_coords[1] + 1.0).max(0.0); + let right_area = (right.box_coords[2] - right.box_coords[0] + 1.0).max(0.0) + * (right.box_coords[3] - right.box_coords[1] + 1.0).max(0.0); + + intersection / (left_area + right_area - intersection).max(1e-6) +} + +fn apply_nms(mut detections: Vec, threshold: f64) -> Vec { + detections.sort_by(|left, right| { + right + .score + .partial_cmp(&left.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + let mut kept = Vec::new(); + 'candidates: for candidate in detections { + for accepted in &kept { + if compute_iou(&candidate, accepted) > threshold { + continue 'candidates; + } + } + kept.push(candidate); + } + + kept +} + +fn softmax(values: &[f64]) -> Vec { + if values.is_empty() { + return Vec::new(); + } + + let max_value = values.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let exps: Vec = values.iter().map(|value| (value - max_value).exp()).collect(); + let sum: f64 = exps.iter().sum(); + exps.into_iter().map(|value| value / sum).collect() +} + +fn 
clamp(value: f64, min: f64, max: f64) -> f64 { + value.max(min).min(max) +} + +fn method(target: &JsValue, name: &str) -> Result { + Reflect::get(target, &JsValue::from_str(name))? + .dyn_into::() + .map_err(|_| JsValue::from_str(&format!("{name} is not callable"))) +} + +fn first_string_entry(target: &JsValue, field: &str) -> Result { + let values = Reflect::get(target, &JsValue::from_str(field))?; + let first = Reflect::get(&values, &JsValue::from_f64(0.0))?; + first + .as_string() + .ok_or_else(|| JsValue::from_str(&format!("Missing first entry for {field}"))) +} + +fn string_entries(target: &JsValue, field: &str) -> Result, JsValue> { + let values = Reflect::get(target, &JsValue::from_str(field))?; + let array = Array::from(&values); + let mut entries = Vec::with_capacity(array.length() as usize); + + for value in array.iter() { + let entry = value + .as_string() + .ok_or_else(|| JsValue::from_str(&format!("Invalid entry in {field}")))?; + entries.push(entry); + } + + Ok(entries) +} + +async fn wait_for_connected(client: &WsClient) -> Result<(), JsValue> { + for _ in 0..100 { + if client.get_state() == "connected" { + return Ok(()); + } + sleep_ms(100).await?; + } + + Err(JsValue::from_str("Timed out waiting for websocket connection")) +} + +fn websocket_url() -> Result { + let window = web_sys::window().ok_or_else(|| JsValue::from_str("No window available"))?; + let location = Reflect::get(window.as_ref(), &JsValue::from_str("location"))?; + let protocol = Reflect::get(&location, &JsValue::from_str("protocol"))? + .as_string() + .ok_or_else(|| JsValue::from_str("window.location.protocol is unavailable"))?; + let host = Reflect::get(&location, &JsValue::from_str("host"))? 
+ .as_string() + .ok_or_else(|| JsValue::from_str("window.location.host is unavailable"))?; + let ws_protocol = if protocol == "https:" { "wss:" } else { "ws:" }; + Ok(format!("{ws_protocol}//{host}/ws")) +} + +async fn create_face_session(model_path: &str) -> Result { + let window = web_sys::window().ok_or_else(|| JsValue::from_str("No window available"))?; + let ort = Reflect::get(window.as_ref(), &JsValue::from_str("ort"))?; + if ort.is_null() || ort.is_undefined() { + return Err(JsValue::from_str("onnxruntime-web did not load")); + } + + configure_onnx_runtime_wasm(&window, &ort)?; + + let inference_session = Reflect::get(&ort, &JsValue::from_str("InferenceSession"))?; + let create = method(&inference_session, "create")?; + let options = js_sys::Object::new(); + Reflect::set( + &options, + &JsValue::from_str("executionProviders"), + &Array::of1(&JsValue::from_str("wasm")), + )?; + + let value = create.call2(&inference_session, &JsValue::from_str(model_path), &options)?; + JsFuture::from( + value + .dyn_into::() + .map_err(|_| JsValue::from_str("InferenceSession.create did not return a Promise"))?, + ) + .await +} + +fn configure_onnx_runtime_wasm(window: &web_sys::Window, ort: &JsValue) -> Result<(), JsValue> { + let env = Reflect::get(ort, &JsValue::from_str("env"))?; + let wasm = Reflect::get(&env, &JsValue::from_str("wasm"))?; + if wasm.is_null() || wasm.is_undefined() { + return Err(JsValue::from_str("onnxruntime-web environment is unavailable")); + } + + let versions = Reflect::get(&env, &JsValue::from_str("versions"))?; + let ort_version = Reflect::get(&versions, &JsValue::from_str("web"))? + .as_string() + .ok_or_else(|| JsValue::from_str("onnxruntime-web version is unavailable"))?; + let dist_base_url = format!("https://cdn.jsdelivr.net/npm/onnxruntime-web@{ort_version}/dist"); + + let supports_threads = Reflect::get(window.as_ref(), &JsValue::from_str("crossOriginIsolated"))? 
+ .as_bool() + .unwrap_or(false) + && Reflect::has(window.as_ref(), &JsValue::from_str("SharedArrayBuffer"))?; + + Reflect::set( + &wasm, + &JsValue::from_str("numThreads"), + &JsValue::from_f64(if supports_threads { 0.0 } else { 1.0 }), + )?; + + let wasm_paths = js_sys::Object::new(); + Reflect::set( + &wasm_paths, + &JsValue::from_str("mjs"), + &JsValue::from_str(&format!("{dist_base_url}/ort-wasm-simd-threaded.mjs")), + )?; + Reflect::set( + &wasm_paths, + &JsValue::from_str("wasm"), + &JsValue::from_str(&format!("{dist_base_url}/ort-wasm-simd-threaded.wasm")), + )?; + Reflect::set(&wasm, &JsValue::from_str("wasmPaths"), &wasm_paths)?; + Ok(()) +} + +fn create_tensor(values: &Float32Array) -> Result { + let window = web_sys::window().ok_or_else(|| JsValue::from_str("No window available"))?; + let ort = Reflect::get(window.as_ref(), &JsValue::from_str("ort"))?; + let tensor_ctor = Reflect::get(&ort, &JsValue::from_str("Tensor"))? + .dyn_into::() + .map_err(|_| JsValue::from_str("ort.Tensor is not callable"))?; + + let dims = Array::new(); + dims.push(&JsValue::from_f64(1.0)); + dims.push(&JsValue::from_f64(FACE_INPUT_HEIGHT_F64)); + dims.push(&JsValue::from_f64(FACE_INPUT_WIDTH_F64)); + dims.push(&JsValue::from_f64(3.0)); + + let args = Array::new(); + args.push(&JsValue::from_str("float32")); + args.push(values); + args.push(&dims.into()); + + Reflect::construct(&tensor_ctor, &args) +} + +async fn sleep_ms(duration_ms: i32) -> Result<(), JsValue> { + let window = web_sys::window().ok_or_else(|| JsValue::from_str("No window available"))?; + let promise = Promise::new(&mut |resolve, reject| { + let callback = Closure::once_into_js(move || { + let _ = resolve.call0(&JsValue::NULL); + }); + + if let Err(error) = + window.set_timeout_with_callback_and_timeout_and_arguments_0(callback.unchecked_ref(), duration_ms) + { + let _ = reject.call1(&JsValue::NULL, &error); + } + }); + JsFuture::from(promise).await.map(|_| ()) +} + +fn describe_js_error(error: &JsValue) -> 
String { + error + .as_string() + .or_else(|| js_sys::JSON::stringify(error).ok().map(String::from)) + .unwrap_or_else(|| format!("{error:?}")) +} + +fn log(message: &str) -> Result<(), JsValue> { + face_log(message); + Ok(()) +} diff --git a/services/ws-server/static/app.js b/services/ws-server/static/app.js index e9ebc60..d5d6529 100644 --- a/services/ws-server/static/app.js +++ b/services/ws-server/static/app.js @@ -15,6 +15,7 @@ import init, { const logEl = document.getElementById("log"); const runHarButton = document.getElementById("run-har-button"); +const runFaceDetectionButton = document.getElementById("run-face-detection-button"); const micButton = document.getElementById("mic-button"); const videoButton = document.getElementById("video-button"); const bluetoothButton = document.getElementById("bluetooth-button"); @@ -25,7 +26,6 @@ const gpuInfoButton = document.getElementById("gpu-info-button"); const speechButton = document.getElementById("speech-button"); const nfcButton = document.getElementById("nfc-button"); const sensorsButton = document.getElementById("sensors-button"); -const videoModelButton = document.getElementById("video-model-button"); const videoOutputButton = document.getElementById("video-output-button"); const agentStatusEl = document.getElementById("agent-status"); const agentIdEl = document.getElementById("agent-id"); @@ -54,6 +54,8 @@ let videoOverlayContext = videoOutputCanvas.getContext("2d"); let videoOutputVisible = false; let videoRenderFrameId = null; let lastVideoInferenceSummary = null; +let faceDetectionModule = null; +let faceDetectionModuleInitialized = false; let sendClientEvent = () => {}; const VIDEO_INFERENCE_INTERVAL_MS = 750; const VIDEO_RENDER_SCORE_THRESHOLD = 0.35; @@ -1607,49 +1609,42 @@ try { } }); - videoModelButton.addEventListener("click", async () => { + runFaceDetectionButton.addEventListener("click", async () => { + runFaceDetectionButton.disabled = true; + try { - if (!window.ort) { - throw new 
Error("onnxruntime-web did not load."); + if (!faceDetectionModule) { + const cacheBust = Date.now(); + const moduleUrl = `/modules/face-detection/pkg/et_ws_face_detection.js?v=${cacheBust}`; + const wasmUrl = `/modules/face-detection/pkg/et_ws_face_detection_bg.wasm?v=${cacheBust}`; + append(`face detection module: importing ${moduleUrl}`); + faceDetectionModule = await import(moduleUrl); + append(`face detection module: initializing ${wasmUrl}`); + await faceDetectionModule.default(wasmUrl); + faceDetectionModuleInitialized = true; } - configureOnnxRuntimeWasm(); - - videoModelButton.disabled = true; - videoModelButton.textContent = "Loading video model..."; - updateVideoStatus(["loading model..."]); + if (!faceDetectionModuleInitialized) { + throw new Error("face detection module failed to initialize"); + } - videoCvSession = await window.ort.InferenceSession.create( - VIDEO_MODEL_PATH, - { - executionProviders: ["wasm"], - }, - ); + if (faceDetectionModule.is_running()) { + faceDetectionModule.stop(); + append("face detection module stopped"); + runFaceDetectionButton.textContent = "face demo"; + return; + } - videoCvInputName = selectVideoModelInputName(videoCvSession); - videoCvOutputName = selectVideoModelOutputName(videoCvSession); - lastVideoCvLabel = null; - lastVideoInferenceSummary = null; - append( - `video cv model loaded: input=${videoCvInputName} output=${videoCvOutputName} input_dims=${ - JSON.stringify(videoCvSession.inputMetadata?.[videoCvInputName]?.dimensions ?? []) - }`, - ); - syncVideoCvLoop(); + append("face detection module: calling start()"); + await faceDetectionModule.start(); + append("face detection module started"); + runFaceDetectionButton.textContent = "stop face demo"; } catch (error) { - videoCvSession = null; - videoCvInputName = null; - videoCvOutputName = null; - stopVideoCvLoop(); - lastVideoInferenceSummary = null; - updateVideoStatus([ - `model load error: ${error instanceof Error ? 
error.message : String(error)}`, - ]); - append(`video cv error: ${error instanceof Error ? error.message : String(error)}`); + append(`face detection module error: ${describeError(error)}`); console.error(error); + runFaceDetectionButton.textContent = "face demo"; } finally { - videoModelButton.disabled = false; - videoModelButton.textContent = videoCvSession ? "Reload video CV model" : "Load video CV model"; + runFaceDetectionButton.disabled = false; } }); @@ -1661,6 +1656,7 @@ try { window.client = client; window.sendAlive = () => client.send_alive(); window.runHarModule = () => runHarButton.click(); + window.runFaceDetectionModule = () => runFaceDetectionButton.click(); } catch (error) { append(`error: ${error instanceof Error ? error.message : String(error)}`); console.error(error); diff --git a/services/ws-server/static/index.html b/services/ws-server/static/index.html index 73335b0..158c0fd 100644 --- a/services/ws-server/static/index.html +++ b/services/ws-server/static/index.html @@ -122,6 +122,7 @@

WASM web agent

+ -

+ +
@@ -152,6 +152,8 @@

WASM web agent

>Waiting for device motion/orientation data… +
From a2fbff1af9605dc18cec31ed3bf516106d0aa54d Mon Sep 17 00:00:00 2001 From: Pierre Tenedero Date: Thu, 9 Apr 2026 18:28:31 +0800 Subject: [PATCH 3/4] Refactor inline js to Rust --- services/ws-modules/face-detection/Cargo.toml | 6 + services/ws-modules/face-detection/src/lib.rs | 493 +++++++++--------- services/ws-server/static/app.js | 4 +- 3 files changed, 269 insertions(+), 234 deletions(-) diff --git a/services/ws-modules/face-detection/Cargo.toml b/services/ws-modules/face-detection/Cargo.toml index 59725b4..a56553f 100644 --- a/services/ws-modules/face-detection/Cargo.toml +++ b/services/ws-modules/face-detection/Cargo.toml @@ -17,11 +17,17 @@ wasm-bindgen = "0.2" wasm-bindgen-futures = "0.4" web-sys = { version = "0.3", features = [ "BinaryType", + "CanvasRenderingContext2d", "Document", "Event", "EventTarget", + "HtmlCanvasElement", + "HtmlVideoElement", + "ImageData", + "MediaStream", "MessageEvent", "Storage", + "TextMetrics", "WebSocket", "Window", "console", diff --git a/services/ws-modules/face-detection/src/lib.rs b/services/ws-modules/face-detection/src/lib.rs index 980d784..5289b93 100644 --- a/services/ws-modules/face-detection/src/lib.rs +++ b/services/ws-modules/face-detection/src/lib.rs @@ -8,6 +8,7 @@ use tracing::info; use wasm_bindgen::JsCast; use wasm_bindgen::prelude::*; use wasm_bindgen_futures::{JsFuture, spawn_local}; +use web_sys::{CanvasRenderingContext2d, HtmlCanvasElement, HtmlVideoElement, ImageData, MediaStream}; const FACE_MODEL_PATH: &str = "/static/models/video_cv.onnx"; const FACE_INPUT_WIDTH: usize = 640; @@ -22,193 +23,6 @@ const RETINAFACE_VARIANCES: [f64; 2] = [0.1, 0.2]; const RETINAFACE_MIN_SIZES: [&[f64]; 3] = [&[16.0, 32.0], &[64.0, 128.0], &[256.0, 512.0]]; const RETINAFACE_STEPS: [f64; 3] = [8.0, 16.0, 32.0]; -#[wasm_bindgen(inline_js = r##" -export async function face_attach_stream(stream) { - const video = document.getElementById("face-video-preview"); - if (!video) { - throw new Error("Missing #face-video-preview 
element"); - } - - video.srcObject = stream; - video.hidden = false; - - if (!video.videoWidth || !video.videoHeight) { - await new Promise((resolve, reject) => { - const onLoaded = () => { - cleanup(); - resolve(); - }; - const onError = () => { - cleanup(); - reject(new Error("Video stream metadata did not load")); - }; - const cleanup = () => { - video.removeEventListener("loadedmetadata", onLoaded); - video.removeEventListener("error", onError); - }; - video.addEventListener("loadedmetadata", onLoaded, { once: true }); - video.addEventListener("error", onError, { once: true }); - }); - } - - const playResult = video.play?.(); - if (playResult?.catch) { - try { - await playResult; - } catch { - // Browsers may reject autoplay even after a gesture; metadata is enough for capture. - } - } -} - -export function face_detach_stream() { - const video = document.getElementById("face-video-preview"); - const canvas = document.getElementById("face-video-output-canvas"); - if (video) { - video.pause?.(); - video.srcObject = null; - video.hidden = true; - } - if (canvas) { - canvas.hidden = true; - const context = canvas.getContext("2d"); - context?.clearRect(0, 0, canvas.width, canvas.height); - } -} - -export function face_set_status(message) { - const output = document.getElementById("face-output"); - if (output) { - output.value = String(message); - } -} - -export function face_log(message) { - const line = `[face-detection] ${message}`; - console.log(line); - const logEl = document.getElementById("log"); - if (!logEl) { - return; - } - const current = logEl.textContent ?? ""; - logEl.textContent = current ? 
`${current}\n${line}` : line; -} - -export function face_capture_input_tensor() { - const video = document.getElementById("face-video-preview"); - if (!video?.videoWidth || !video?.videoHeight) { - throw new Error("Video stream is not ready yet."); - } - - const width = 640; - const height = 608; - const mean = [104, 117, 123]; - const canvas = globalThis.__etFacePreprocessCanvas ?? document.createElement("canvas"); - globalThis.__etFacePreprocessCanvas = canvas; - const context = canvas.getContext("2d", { willReadFrequently: true }); - if (!context) { - throw new Error("Unable to create face preprocessing canvas context."); - } - - canvas.width = width; - canvas.height = height; - - const sourceWidth = video.videoWidth; - const sourceHeight = video.videoHeight; - const targetRatio = height / width; - let resizeRatio; - if (sourceHeight / sourceWidth <= targetRatio) { - resizeRatio = width / sourceWidth; - } else { - resizeRatio = height / sourceHeight; - } - - const resizedWidth = Math.max(1, Math.min(width, Math.round(sourceWidth * resizeRatio))); - const resizedHeight = Math.max(1, Math.min(height, Math.round(sourceHeight * resizeRatio))); - context.clearRect(0, 0, width, height); - context.drawImage(video, 0, 0, resizedWidth, resizedHeight); - - const rgba = context.getImageData(0, 0, width, height).data; - const tensorData = new Float32Array(width * height * 3); - - for (let pixelIndex = 0; pixelIndex < width * height; pixelIndex += 1) { - const rgbaIndex = pixelIndex * 4; - const red = rgba[rgbaIndex]; - const green = rgba[rgbaIndex + 1]; - const blue = rgba[rgbaIndex + 2]; - const tensorIndex = pixelIndex * 3; - tensorData[tensorIndex] = blue - mean[0]; - tensorData[tensorIndex + 1] = green - mean[1]; - tensorData[tensorIndex + 2] = red - mean[2]; - } - - return { - data: tensorData, - resizeRatio, - sourceWidth, - sourceHeight, - }; -} - -export function face_render(detections) { - const video = document.getElementById("face-video-preview"); - const canvas 
= document.getElementById("face-video-output-canvas"); - if (!video?.videoWidth || !video?.videoHeight || !canvas) { - return; - } - - const context = canvas.getContext("2d"); - if (!context) { - throw new Error("Unable to create face output canvas context."); - } - - const width = video.videoWidth; - const height = video.videoHeight; - if (canvas.width !== width || canvas.height !== height) { - canvas.width = width; - canvas.height = height; - } - - canvas.hidden = false; - context.drawImage(video, 0, 0, width, height); - context.lineWidth = 3; - context.font = "16px ui-monospace, monospace"; - - for (const entry of detections ?? []) { - const [x1, y1, x2, y2] = entry.box ?? []; - const left = Number(x1 ?? 0); - const top = Number(y1 ?? 0); - const right = Number(x2 ?? 0); - const bottom = Number(y2 ?? 0); - const boxWidth = Math.max(1, right - left); - const boxHeight = Math.max(1, bottom - top); - context.strokeStyle = "#ef8f35"; - context.strokeRect(left, top, boxWidth, boxHeight); - - const label = `${entry.label ?? "face"} ${((entry.score ?? 
0) * 100).toFixed(1)}%`; - const textWidth = context.measureText(label).width + 10; - context.fillStyle = "#182028"; - context.fillRect(left, Math.max(0, top - 24), textWidth, 22); - context.fillStyle = "#fffdfa"; - context.fillText(label, left + 5, Math.max(16, top - 8)); - } -} -"##)] -extern "C" { - #[wasm_bindgen(catch)] - async fn face_attach_stream(stream: JsValue) -> Result; - #[wasm_bindgen] - fn face_detach_stream(); - #[wasm_bindgen] - fn face_set_status(message: &str); - #[wasm_bindgen] - fn face_log(message: &str); - #[wasm_bindgen(catch)] - fn face_capture_input_tensor() -> Result; - #[wasm_bindgen(catch)] - fn face_render(detections: &JsValue) -> Result<(), JsValue>; -} - #[derive(Clone)] struct Detection { label: String, @@ -224,6 +38,13 @@ struct DetectionSummary { processed_at: String, } +struct FaceCaptureTensor { + data: Vec, + resize_ratio: f64, + source_width: f64, + source_height: f64, +} + struct FaceDetectionRuntime { client: WsClient, capture: VideoCapture, @@ -235,6 +56,7 @@ struct FaceDetectionRuntime { thread_local! 
{ static FACE_RUNTIME: RefCell> = const { RefCell::new(None) }; + static FACE_PREPROCESS_CANVAS: RefCell> = const { RefCell::new(None) }; } #[wasm_bindgen(start)] @@ -249,7 +71,7 @@ pub fn is_running() -> bool { } #[wasm_bindgen] -pub async fn start() -> Result<(), JsValue> { +pub async fn run() -> Result<(), JsValue> { if is_running() { return Ok(()); } @@ -347,8 +169,8 @@ pub async fn start() -> Result<(), JsValue> { let detections = render_last_summary .borrow() .as_ref() - .map(|summary| detections_to_js(&summary.detections)) - .unwrap_or_else(|| Array::new().into()); + .map(|summary| summary.detections.clone()) + .unwrap_or_default(); let _ = face_render(&detections); }) as Box); @@ -384,6 +206,11 @@ pub async fn start() -> Result<(), JsValue> { Ok(()) } +#[wasm_bindgen] +pub async fn start() -> Result<(), JsValue> { + run().await +} + #[wasm_bindgen] pub fn stop() -> Result<(), JsValue> { FACE_RUNTIME.with(|runtime| { @@ -411,18 +238,7 @@ async fn infer_once( last_has_detection: &Cell, ) -> Result { let capture = face_capture_input_tensor()?; - let tensor_data = Reflect::get(&capture, &JsValue::from_str("data"))?; - let resize_ratio = Reflect::get(&capture, &JsValue::from_str("resizeRatio"))? - .as_f64() - .ok_or_else(|| JsValue::from_str("face capture resizeRatio was unavailable"))?; - let source_width = Reflect::get(&capture, &JsValue::from_str("sourceWidth"))? - .as_f64() - .ok_or_else(|| JsValue::from_str("face capture sourceWidth was unavailable"))?; - let source_height = Reflect::get(&capture, &JsValue::from_str("sourceHeight"))? 
- .as_f64() - .ok_or_else(|| JsValue::from_str("face capture sourceHeight was unavailable"))?; - - let tensor = create_tensor(&Float32Array::new(&tensor_data))?; + let tensor = create_tensor(&Float32Array::from(capture.data.as_slice()))?; let feeds = js_sys::Object::new(); Reflect::set(&feeds, &JsValue::from_str(input_name), &tensor)?; @@ -434,7 +250,13 @@ async fn infer_once( ) .await?; - let summary = decode_retinaface_outputs(&outputs, output_names, resize_ratio, source_width, source_height)?; + let summary = decode_retinaface_outputs( + &outputs, + output_names, + capture.resize_ratio, + capture.source_width, + capture.source_height, + )?; let has_detection = !summary.detections.is_empty(); let changed = last_has_detection.get() != has_detection; last_has_detection.set(has_detection); @@ -463,8 +285,8 @@ async fn infer_once( "input_name": input_name, "output_names": output_names, "source_resolution": { - "width": source_width, - "height": source_height, + "width": capture.source_width, + "height": capture.source_height, }, }), )?; @@ -494,32 +316,6 @@ fn update_face_status(input_name: &str, output_names: &[String], summary: &Detec face_set_status(&lines.join("\n")); } -fn detections_to_js(detections: &[Detection]) -> JsValue { - let array = Array::new(); - - for detection in detections { - let object = js_sys::Object::new(); - let box_values = Array::new(); - for value in detection.box_coords { - box_values.push(&JsValue::from_f64(value)); - } - let _ = Reflect::set( - &object, - &JsValue::from_str("label"), - &JsValue::from_str(&detection.label), - ); - let _ = Reflect::set( - &object, - &JsValue::from_str("score"), - &JsValue::from_f64(detection.score), - ); - let _ = Reflect::set(&object, &JsValue::from_str("box"), &box_values); - array.push(&object); - } - - array.into() -} - fn decode_retinaface_outputs( outputs: &JsValue, output_names: &[String], @@ -844,6 +640,239 @@ fn describe_js_error(error: &JsValue) -> String { } fn log(message: &str) -> Result<(), 
JsValue> { - face_log(message); + let line = format!("[face-detection] {message}"); + web_sys::console::log_1(&JsValue::from_str(&line)); + + if let Some(window) = web_sys::window() + && let Some(document) = window.document() + && let Some(log_el) = document.get_element_by_id("log") + { + let current = log_el.text_content().unwrap_or_default(); + let next = if current.is_empty() { + line + } else { + format!("{current}\n{line}") + }; + log_el.set_text_content(Some(&next)); + } + Ok(()) } + +async fn face_attach_stream(stream: JsValue) -> Result<(), JsValue> { + let video = face_video_element()?; + let stream = stream + .dyn_into::() + .map_err(|_| JsValue::from_str("Video capture stream was not a MediaStream"))?; + + Reflect::set(video.as_ref(), &JsValue::from_str("srcObject"), stream.as_ref())?; + set_hidden(video.as_ref(), false)?; + + for _ in 0..50 { + if video.video_width() > 0 && video.video_height() > 0 { + break; + } + sleep_ms(100).await?; + } + + if video.video_width() == 0 || video.video_height() == 0 { + return Err(JsValue::from_str("Video stream metadata did not load")); + } + + if let Ok(play_result) = method(video.as_ref(), "play").and_then(|play| play.call0(video.as_ref())) { + if let Ok(play_promise) = play_result.dyn_into::() { + let _ = JsFuture::from(play_promise).await; + } + } + + Ok(()) +} + +fn face_detach_stream() { + if let Ok(video) = face_video_element() { + if let Ok(pause) = method(video.as_ref(), "pause") { + let _ = pause.call0(video.as_ref()); + } + let _ = Reflect::set(video.as_ref(), &JsValue::from_str("srcObject"), &JsValue::NULL); + let _ = set_hidden(video.as_ref(), true); + } + + if let Ok(canvas) = face_output_canvas_element() { + let _ = set_hidden(canvas.as_ref(), true); + if let Ok(context) = canvas_2d_context(&canvas) { + context.clear_rect(0.0, 0.0, f64::from(canvas.width()), f64::from(canvas.height())); + } + } +} + +fn face_set_status(message: &str) { + let _ = set_textarea_value("face-output", message); +} + +fn 
face_capture_input_tensor() -> Result { + let video = face_video_element()?; + let source_width = f64::from(video.video_width()); + let source_height = f64::from(video.video_height()); + if source_width <= 0.0 || source_height <= 0.0 { + return Err(JsValue::from_str("Video stream is not ready yet.")); + } + + let canvas = face_preprocess_canvas()?; + canvas.set_width(FACE_INPUT_WIDTH as u32); + canvas.set_height(FACE_INPUT_HEIGHT as u32); + let context = canvas_2d_context(&canvas)?; + + let target_ratio = FACE_INPUT_HEIGHT_F64 / FACE_INPUT_WIDTH_F64; + let resize_ratio = if source_height / source_width <= target_ratio { + FACE_INPUT_WIDTH_F64 / source_width + } else { + FACE_INPUT_HEIGHT_F64 / source_height + }; + + let resized_width = (source_width * resize_ratio).round().clamp(1.0, FACE_INPUT_WIDTH_F64); + let resized_height = (source_height * resize_ratio) + .round() + .clamp(1.0, FACE_INPUT_HEIGHT_F64); + context.clear_rect(0.0, 0.0, FACE_INPUT_WIDTH_F64, FACE_INPUT_HEIGHT_F64); + context.draw_image_with_html_video_element_and_dw_and_dh(&video, 0.0, 0.0, resized_width, resized_height)?; + + let image_data = context.get_image_data(0.0, 0.0, FACE_INPUT_WIDTH_F64, FACE_INPUT_HEIGHT_F64)?; + let tensor_data = image_data_to_tensor(&image_data); + + Ok(FaceCaptureTensor { + data: tensor_data, + resize_ratio, + source_width, + source_height, + }) +} + +fn face_render(detections: &[Detection]) -> Result<(), JsValue> { + let video = face_video_element()?; + let width = video.video_width(); + let height = video.video_height(); + if width == 0 || height == 0 { + return Ok(()); + } + + let canvas = face_output_canvas_element()?; + let context = canvas_2d_context(&canvas)?; + if canvas.width() != width || canvas.height() != height { + canvas.set_width(width); + canvas.set_height(height); + } + + set_hidden(canvas.as_ref(), false)?; + context.draw_image_with_html_video_element_and_dw_and_dh(&video, 0.0, 0.0, f64::from(width), f64::from(height))?; + 
context.set_line_width(3.0); + context.set_font("16px ui-monospace, monospace"); + + for detection in detections { + let left = detection.box_coords[0]; + let top = detection.box_coords[1]; + let right = detection.box_coords[2]; + let bottom = detection.box_coords[3]; + let box_width = (right - left).max(1.0); + let box_height = (bottom - top).max(1.0); + + context.set_stroke_style_str("#ef8f35"); + context.stroke_rect(left, top, box_width, box_height); + + let label = format!("{} {:.1}%", detection.label, detection.score * 100.0); + let text_width = context.measure_text(&label)?.width() + 10.0; + context.set_fill_style_str("#182028"); + context.fill_rect(left, (top - 24.0).max(0.0), text_width, 22.0); + context.set_fill_style_str("#fffdfa"); + context.fill_text(&label, left + 5.0, (top - 8.0).max(16.0))?; + } + + Ok(()) +} + +fn image_data_to_tensor(image_data: &ImageData) -> Vec { + const CHANNEL_MEAN: [f32; 3] = [104.0, 117.0, 123.0]; + + let rgba = image_data.data().to_vec(); + let mut tensor_data = vec![0.0_f32; FACE_INPUT_WIDTH * FACE_INPUT_HEIGHT * 3]; + + for pixel_index in 0..(FACE_INPUT_WIDTH * FACE_INPUT_HEIGHT) { + let rgba_index = pixel_index * 4; + let tensor_index = pixel_index * 3; + let red = rgba[rgba_index] as f32; + let green = rgba[rgba_index + 1] as f32; + let blue = rgba[rgba_index + 2] as f32; + + tensor_data[tensor_index] = blue - CHANNEL_MEAN[0]; + tensor_data[tensor_index + 1] = green - CHANNEL_MEAN[1]; + tensor_data[tensor_index + 2] = red - CHANNEL_MEAN[2]; + } + + tensor_data +} + +fn set_textarea_value(element_id: &str, message: &str) -> Result<(), JsValue> { + if let Some(window) = web_sys::window() + && let Some(document) = window.document() + && let Some(output) = document.get_element_by_id(element_id) + { + Reflect::set( + output.as_ref(), + &JsValue::from_str("value"), + &JsValue::from_str(message), + )?; + } + + Ok(()) +} + +fn face_video_element() -> Result { + let document = web_sys::window() + .and_then(|window| 
window.document()) + .ok_or_else(|| JsValue::from_str("No document available"))?; + document + .get_element_by_id("face-video-preview") + .ok_or_else(|| JsValue::from_str("Missing #face-video-preview element"))? + .dyn_into::() + .map_err(|_| JsValue::from_str("#face-video-preview was not a video element")) +} + +fn face_output_canvas_element() -> Result { + let document = web_sys::window() + .and_then(|window| window.document()) + .ok_or_else(|| JsValue::from_str("No document available"))?; + document + .get_element_by_id("face-video-output-canvas") + .ok_or_else(|| JsValue::from_str("Missing #face-video-output-canvas element"))? + .dyn_into::() + .map_err(|_| JsValue::from_str("#face-video-output-canvas was not a canvas element")) +} + +fn face_preprocess_canvas() -> Result { + FACE_PREPROCESS_CANVAS.with(|slot| { + if let Some(canvas) = slot.borrow().as_ref() { + return Ok(canvas.clone()); + } + + let document = web_sys::window() + .and_then(|window| window.document()) + .ok_or_else(|| JsValue::from_str("No document available"))?; + let canvas = document + .create_element("canvas")? + .dyn_into::() + .map_err(|_| JsValue::from_str("Unable to create preprocessing canvas"))?; + *slot.borrow_mut() = Some(canvas.clone()); + Ok(canvas) + }) +} + +fn canvas_2d_context(canvas: &HtmlCanvasElement) -> Result { + canvas + .get_context("2d")? + .ok_or_else(|| JsValue::from_str("2d canvas context was unavailable"))? 
+ .dyn_into::() + .map_err(|_| JsValue::from_str("Canvas context was not 2d")) +} + +fn set_hidden(target: &JsValue, hidden: bool) -> Result<(), JsValue> { + Reflect::set(target, &JsValue::from_str("hidden"), &JsValue::from_bool(hidden)).map(|_| ()) +} diff --git a/services/ws-server/static/app.js b/services/ws-server/static/app.js index d5d6529..96a60bf 100644 --- a/services/ws-server/static/app.js +++ b/services/ws-server/static/app.js @@ -1635,8 +1635,8 @@ try { return; } - append("face detection module: calling start()"); - await faceDetectionModule.start(); + append("face detection module: calling run()"); + await faceDetectionModule.run(); append("face detection module started"); runFaceDetectionButton.textContent = "stop face demo"; } catch (error) { From 51ea4220189cf31e5ac304aa4c5503117d0adccb Mon Sep 17 00:00:00 2001 From: Pierre Tenedero Date: Thu, 9 Apr 2026 18:38:47 +0800 Subject: [PATCH 4/4] Refactor for CI --- .mise.toml | 13 ++++++++++--- services/ws-modules/face-detection/src/lib.rs | 12 +++++------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/.mise.toml b/.mise.toml index 6d21ef5..96b7d5e 100644 --- a/.mise.toml +++ b/.mise.toml @@ -17,16 +17,23 @@ rust = { version = "latest", components = "clippy" } typos = "latest" [tasks] -dprint-check = "dprint check" -dprint-fmt = "dprint fmt" editorconfig-check = "ec" +[tasks.dprint-check] +run = "dprint check" + +[tasks.dprint-check.env] +DPRINT_CACHE_DIR = "/tmp/dprint-cache" + +[tasks.dprint-fmt] +run = "dprint fmt" + [tasks.fmt] depends = ["cargo-clippy-fix", "cargo-fmt", "dprint-fmt", "taplo-fmt"] description = "Run repository formatters" [tasks.install-nightly] -run = "rustup toolchain install nightly --component rustfmt" +run = "cargo +nightly fmt --version >/dev/null 2>&1 || rustup toolchain install nightly --component rustfmt" [tasks.check] depends = [ diff --git a/services/ws-modules/face-detection/src/lib.rs b/services/ws-modules/face-detection/src/lib.rs index 5289b93..2dbd874 
100644 --- a/services/ws-modules/face-detection/src/lib.rs +++ b/services/ws-modules/face-detection/src/lib.rs @@ -679,10 +679,10 @@ async fn face_attach_stream(stream: JsValue) -> Result<(), JsValue> { return Err(JsValue::from_str("Video stream metadata did not load")); } - if let Ok(play_result) = method(video.as_ref(), "play").and_then(|play| play.call0(video.as_ref())) { - if let Ok(play_promise) = play_result.dyn_into::() { - let _ = JsFuture::from(play_promise).await; - } + if let Ok(play_result) = method(video.as_ref(), "play").and_then(|play| play.call0(video.as_ref())) + && let Ok(play_promise) = play_result.dyn_into::() + { + let _ = JsFuture::from(play_promise).await; } Ok(()) @@ -730,9 +730,7 @@ fn face_capture_input_tensor() -> Result { }; let resized_width = (source_width * resize_ratio).round().clamp(1.0, FACE_INPUT_WIDTH_F64); - let resized_height = (source_height * resize_ratio) - .round() - .clamp(1.0, FACE_INPUT_HEIGHT_F64); + let resized_height = (source_height * resize_ratio).round().clamp(1.0, FACE_INPUT_HEIGHT_F64); context.clear_rect(0.0, 0.0, FACE_INPUT_WIDTH_F64, FACE_INPUT_HEIGHT_F64); context.draw_image_with_html_video_element_and_dw_and_dh(&video, 0.0, 0.0, resized_width, resized_height)?;