From 922c9c62ff3dd07162a2fd8cd8ea8a13e7898e44 Mon Sep 17 00:00:00 2001 From: Gentle Date: Sat, 7 Feb 2026 19:23:47 +0100 Subject: [PATCH] Add incremental snapshotting support to wizer Add a `keep_instrumentation` option that preserves `__wizer_*` exports in the output module, and a `parse_instrumented` method that re-parses a previously snapshotted module so it can be snapshotted again. This enables savestate-style workflows where a module is initialized, mutated at runtime, and re-snapshotted without re-instrumenting from the original source. --- crates/wizer/src/component/parse.rs | 2 +- crates/wizer/src/component/rewrite.rs | 9 +++- crates/wizer/src/lib.rs | 56 +++++++++++++++++-- crates/wizer/src/parse.rs | 61 ++++++++++++++++++--- crates/wizer/src/rewrite.rs | 21 ++++++++ crates/wizer/tests/all/tests.rs | 78 ++++++++++++++++++++++++++- 6 files changed, 213 insertions(+), 14 deletions(-) diff --git a/crates/wizer/src/component/parse.rs b/crates/wizer/src/component/parse.rs index 2e1449c509e1..7179e026242c 100644 --- a/crates/wizer/src/component/parse.rs +++ b/crates/wizer/src/component/parse.rs @@ -39,7 +39,7 @@ fn parse_into<'a>( // Module sections get parsed with wizer's core wasm support. Payload::ModuleSection { .. } => match &mut cx { Some(component) => { - let info = crate::parse::parse_with(&full_wasm, &mut iter)?; + let info = crate::parse::parse_with(&full_wasm, &mut iter, false)?; component.push_module_section(info); } None => { diff --git a/crates/wizer/src/component/rewrite.rs b/crates/wizer/src/component/rewrite.rs index 4f7c5fc631e0..1039459c2fa0 100644 --- a/crates/wizer/src/component/rewrite.rs +++ b/crates/wizer/src/component/rewrite.rs @@ -37,8 +37,13 @@ impl Wizer { // and the results of that are spliced into the // component. Some(snapshot) => { - let rewritten_wasm = - self.rewrite(module, snapshot, &FuncRenames::default(), false); + let rewritten_wasm = self.rewrite( + module, + snapshot, + &FuncRenames::default(), + false, + false, + ); encoder.section(&wasm_encoder::RawSection { id: wasm_encoder::ComponentSectionId::CoreModule as u8, data: &rewritten_wasm, diff --git a/crates/wizer/src/lib.rs b/crates/wizer/src/lib.rs index 9dfe5f980037..9dcc17baaf6d 100644 --- a/crates/wizer/src/lib.rs +++ b/crates/wizer/src/lib.rs @@ -29,6 +29,7 @@ use std::collections::{HashMap, HashSet}; pub use wasmparser::ValType; const DEFAULT_KEEP_INIT_FUNC: bool = false; +const DEFAULT_KEEP_INSTRUMENTATION: bool = false; /// Wizer: the WebAssembly pre-initializer! /// @@ -98,6 +99,18 @@ pub struct Wizer { arg(long, require_equals = true, value_name = "true|false") )] keep_init_func: Option>, + + /// After initialization, should the Wasm module still contain the + /// instrumentation needed for further snapshotting? + /// + /// This is `false` by default, meaning that instrumentation is stripped + /// from the output module. Set to `true` to enable incremental + /// snapshotting (savestate) workflows. + #[cfg_attr( + feature = "clap", + arg(long, require_equals = true, value_name = "true|false") + )] + keep_instrumentation: Option>, } #[cfg(feature = "clap")] @@ -151,6 +164,7 @@ impl Wizer { init_func: "wizer-initialize".to_string(), func_renames: vec![], keep_init_func: None, + keep_instrumentation: None, } } @@ -184,6 +198,17 @@ impl Wizer { self } + /// After initialization, should the Wasm module still contain the + /// instrumentation needed for further snapshotting? + /// + /// This is `false` by default, meaning that instrumentation is stripped + /// from the output module. Set to `true` to enable incremental + /// snapshotting (savestate) workflows. + pub fn keep_instrumentation(&mut self, keep: bool) -> &mut Self { + self.keep_instrumentation = Some(Some(keep)); + self + } + /// First half of [`Self::run`] which instruments the provided `wasm` and /// produces a new wasm module which should be run by a runtime. /// @@ -193,7 +218,7 @@ impl Wizer { // Make sure we're given valid Wasm from the get go. self.wasm_validate(&wasm)?; - let mut cx = parse::parse(wasm)?; + let mut cx = parse::parse(wasm, false)?; // When wizening core modules directly some imports aren't supported, // so check for those here. @@ -220,6 +245,18 @@ impl Wizer { Ok((cx, instrumented_wasm)) } + /// Parse a previously instrumented Wasm module, returning its + /// [`ModuleContext`]. + /// + /// This is used in incremental snapshotting workflows where the module + /// was already instrumented by a prior call to [`Self::instrument`] or + /// [`Self::snapshot`] (with [`Self::keep_instrumentation`] enabled). The returned context can be + /// passed to [`Self::snapshot`] to produce a new snapshot. + pub fn parse_instrumented<'a>(&self, wasm: &'a [u8]) -> Result> { + self.wasm_validate(wasm)?; + parse::parse(wasm, true) + } + /// Second half of [`Self::run`] which takes the [`ModuleContext`] returned /// by [`Self::instrument`] and the state of the `instance` after it has /// possibly executed its initialization function. @@ -233,12 +270,16 @@ impl Wizer { ) -> Result> { // Parse rename spec. let renames = FuncRenames::parse(&self.func_renames)?; - let snapshot = snapshot::snapshot(&cx, instance).await; - let rewritten_wasm = self.rewrite(&mut cx, &snapshot, &renames, true); + let rewritten_wasm = self.rewrite( + &mut cx, + &snapshot, + &renames, + true, + self.get_keep_instrumentation(), + ); self.debug_assert_valid_wasm(&rewritten_wasm); - Ok(rewritten_wasm) } @@ -369,6 +410,13 @@ impl Wizer { None => DEFAULT_KEEP_INIT_FUNC, } } + + fn get_keep_instrumentation(&self) -> bool { + match self.keep_instrumentation { + Some(keep) => keep.unwrap_or(true), + None => DEFAULT_KEEP_INSTRUMENTATION, + } + } } /// Abstract ability to load state from a WebAssembly instance after it's been diff --git a/crates/wizer/src/parse.rs b/crates/wizer/src/parse.rs index 90e1487aa980..7c06dbebec14 100644 --- a/crates/wizer/src/parse.rs +++ b/crates/wizer/src/parse.rs @@ -3,13 +3,25 @@ use wasmparser::{Encoding, Parser}; use wasmtime::{bail, error::Context as _}; /// Parse the given Wasm bytes into a `ModuleInfo` tree. -pub(crate) fn parse<'a>(full_wasm: &'a [u8]) -> wasmtime::Result> { - parse_with(full_wasm, &mut Parser::new(0).parse_all(full_wasm)) +/// +/// When `instrumented` is true, `__wizer_*` exports are required and used to +/// populate the `defined_global_exports` and `defined_memory_exports` fields +/// rather than being rejected. +pub(crate) fn parse<'a>( + full_wasm: &'a [u8], + instrumented: bool, +) -> wasmtime::Result> { + parse_with( + full_wasm, + &mut Parser::new(0).parse_all(full_wasm), + instrumented, + ) } pub(crate) fn parse_with<'a>( full_wasm: &'a [u8], payloads: &mut impl Iterator>>, + instrumented: bool, ) -> wasmtime::Result> { log::debug!("Parsing the input Wasm"); @@ -36,7 +48,7 @@ pub(crate) fn parse_with<'a>( TableSection(tables) => table_section(&mut module, tables)?, MemorySection(mems) => memory_section(&mut module, mems)?, GlobalSection(globals) => global_section(&mut module, globals)?, - ExportSection(exports) => export_section(&mut module, exports)?, + ExportSection(exports) => export_section(&mut module, exports, instrumented)?, End { .. } => break, _ => {} } @@ -107,14 +119,33 @@ fn global_section<'a>( fn export_section<'a>( module: &mut ModuleContext<'a>, exports: wasmparser::ExportSectionReader<'a>, + instrumented: bool, ) -> wasmtime::Result<()> { + let mut has_instrumentation: bool = false; + let mut defined_global_exports = Vec::new(); + let mut defined_memory_exports = Vec::new(); + for export in exports { let export = export?; if export.name.starts_with("__wizer_") { - wasmtime::bail!( - "input Wasm module already exports entities named with the `__wizer_*` prefix" - ); + if !instrumented { + wasmtime::bail!( + "input Wasm module already exports entities named with the `__wizer_*` prefix" + ); + } + + has_instrumentation = true; + if export.name.starts_with("__wizer_global_") + && export.kind == wasmparser::ExternalKind::Global + { + defined_global_exports.push((export.index, export.name.to_string())); + } else if export.name.starts_with("__wizer_memory_") + && export.kind == wasmparser::ExternalKind::Memory + { + defined_memory_exports.push(export.name.to_string()); + } + continue; } match export.kind { @@ -128,5 +159,23 @@ fn export_section<'a>( } } } + + if instrumented { + if !has_instrumentation { + wasmtime::bail!("input Wasm module is not instrumented") + } + // Sort to match the order expected by defined_globals() and + // defined_memories(). + defined_global_exports.sort_by_key(|(idx, _)| *idx); + defined_memory_exports.sort_by_key(|name| { + name.strip_prefix("__wizer_memory_") + .and_then(|n| n.parse::().ok()) + .unwrap_or(0) + }); + + module.defined_global_exports = Some(defined_global_exports); + module.defined_memory_exports = Some(defined_memory_exports); + } + Ok(()) } diff --git a/crates/wizer/src/rewrite.rs b/crates/wizer/src/rewrite.rs index 33083b5467ce..c7db5ff648bc 100644 --- a/crates/wizer/src/rewrite.rs +++ b/crates/wizer/src/rewrite.rs @@ -10,12 +10,15 @@ impl Wizer { /// Given the initialized snapshot, rewrite the Wasm so that it is already /// initialized. /// + /// When `preserve_instrumentation` is true, the `__wizer_*` exports are + /// preserved so the output can be snapshotted again. pub(crate) fn rewrite( &self, module: &mut ModuleContext<'_>, snapshot: &Snapshot, renames: &FuncRenames, remove_wasi_initialize: bool, + preserve_instrumentation: bool, ) -> Vec { log::debug!("Rewriting input Wasm to pre-initialized state"); @@ -154,6 +157,24 @@ impl Wizer { let kind = RoundtripReencoder.export_kind(export.kind).unwrap(); exports.export(field, kind, export.index); } + + // Re-add __wizer_* exports so the output remains + // instrumentable for future snapshots. + if preserve_instrumentation { + if let Some(ref global_exports) = module.defined_global_exports { + for (idx, name) in global_exports { + exports.export(name, wasm_encoder::ExportKind::Global, *idx); + } + } + if let Some(ref memory_exports) = module.defined_memory_exports { + for ((mem_idx, _), name) in + module.defined_memories().zip(memory_exports) + { + exports.export(name, wasm_encoder::ExportKind::Memory, mem_idx); + } + } + } + encoder.section(&exports); } diff --git a/crates/wizer/tests/all/tests.rs b/crates/wizer/tests/all/tests.rs index a0fe6d9715fc..75f1002f4292 100644 --- a/crates/wizer/tests/all/tests.rs +++ b/crates/wizer/tests/all/tests.rs @@ -5,7 +5,7 @@ use wasmtime::{ error::Context as _, }; use wasmtime_wasi::{WasiCtxBuilder, p1}; -use wasmtime_wizer::Wizer; +use wasmtime_wizer::{WasmtimeWizer, Wizer}; use wat::parse_str as wat_to_wasm; async fn run_wat(args: &[wasmtime::Val], expected: i32, wat: &str) -> Result<()> { @@ -1018,3 +1018,79 @@ async fn memory64() -> Result<()> { let wizer = get_wizer(); wizen_and_run_wasm(&[], 10, &wasm, wizer).await } + +#[tokio::test] +async fn keep_instrumentation_incremental_snapshot() -> Result<()> { + let _ = env_logger::try_init(); + + let wat = r#" +(module + (global $g (mut i32) i32.const 99) + (func (export "wizer-initialize") + i32.const 0 + global.set $g) + (func (export "get") (result i32) + global.get $g) + (func (export "set") (param i32) + local.get 0 + global.set $g)) + "#; + + let wasm = wat_to_wasm(wat)?; + + let mut wizer = Wizer::new(); + wizer.keep_instrumentation(true); + + let (cx, instrumented_wasm) = wizer.instrument(&wasm)?; + + let mut s = store()?; + let module = Module::new(s.engine(), &instrumented_wasm)?; + let instance = instantiate(&mut s, &module).await?; + + let init = instance + .get_typed_func::<(), ()>(&mut s, "wizer-initialize") + .unwrap(); + init.call_async(&mut s, ()).await?; + + let snapshot1 = wizer + .snapshot( + cx, + &mut WasmtimeWizer { + store: &mut s, + instance, + }, + ) + .await?; + + let mut s = store()?; + let module = Module::new(s.engine(), &snapshot1)?; + let instance = instantiate(&mut s, &module).await?; + + let get = instance.get_typed_func::<(), i32>(&mut s, "get")?; + let val = get.call_async(&mut s, ()).await?; + assert_eq!(val, 0, "after first snapshot, global should be 0"); + + let set = instance.get_typed_func::<(i32,), ()>(&mut s, "set")?; + set.call_async(&mut s, (42,)).await?; + + let cx2 = wizer.parse_instrumented(&snapshot1)?; + let snapshot2 = wizer + .snapshot( + cx2, + &mut WasmtimeWizer { + store: &mut s, + instance, + }, + ) + .await?; + + let mut s = store()?; + let module = Module::new(s.engine(), &snapshot2)?; + let instance = instantiate(&mut s, &module).await?; + + let get = instance.get_typed_func::<(), i32>(&mut s, "get")?; + let val = get.call_async(&mut s, ()).await?; + assert_eq!(val, 42, "after second snapshot, global should be 42"); + + Ok(()) +}