Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/wizer/src/component/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ fn parse_into<'a>(
// Module sections get parsed with wizer's core wasm support.
Payload::ModuleSection { .. } => match &mut cx {
Some(component) => {
let info = crate::parse::parse_with(&full_wasm, &mut iter)?;
let info = crate::parse::parse_with(&full_wasm, &mut iter, false)?;
component.push_module_section(info);
}
None => {
Expand Down
9 changes: 7 additions & 2 deletions crates/wizer/src/component/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,13 @@ impl Wizer {
// and the results of that are spliced into the
// component.
Some(snapshot) => {
let rewritten_wasm =
self.rewrite(module, snapshot, &FuncRenames::default(), false);
let rewritten_wasm = self.rewrite(
module,
snapshot,
&FuncRenames::default(),
false,
false,
);
encoder.section(&wasm_encoder::RawSection {
id: wasm_encoder::ComponentSectionId::CoreModule as u8,
data: &rewritten_wasm,
Expand Down
56 changes: 52 additions & 4 deletions crates/wizer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use std::collections::{HashMap, HashSet};
pub use wasmparser::ValType;

/// Fallback for the `keep_init_func` setting; returned by the `None` arm of
/// the corresponding getter (presumably when the option was never
/// configured — confirm against the full getter).
const DEFAULT_KEEP_INIT_FUNC: bool = false;
/// Fallback for the `keep_instrumentation` setting, used by
/// `get_keep_instrumentation` when the option was never configured.
const DEFAULT_KEEP_INSTRUMENTATION: bool = false;

/// Wizer: the WebAssembly pre-initializer!
///
Expand Down Expand Up @@ -98,6 +99,18 @@ pub struct Wizer {
arg(long, require_equals = true, value_name = "true|false")
)]
keep_init_func: Option<Option<bool>>,

/// After initialization, should the Wasm module still contain the
/// instrumentation needed for further snapshotting?
///
/// This is `false` by default, meaning that instrumentation is stripped
/// from the output module. Set to `true` to enable incremental
/// snapshotting (savestate) workflows.
#[cfg_attr(
feature = "clap",
arg(long, require_equals = true, value_name = "true|false")
)]
keep_instrumentation: Option<Option<bool>>,
}

#[cfg(feature = "clap")]
Expand Down Expand Up @@ -151,6 +164,7 @@ impl Wizer {
init_func: "wizer-initialize".to_string(),
func_renames: vec![],
keep_init_func: None,
keep_instrumentation: None,
}
}

Expand Down Expand Up @@ -184,6 +198,17 @@ impl Wizer {
self
}

/// Choose whether the snapshotted output module keeps the instrumentation
/// that further snapshotting relies on.
///
/// The default is `false`: the instrumentation is stripped from the output
/// module. Pass `true` to support incremental snapshotting (savestate)
/// workflows.
///
/// Returns `self` so that configuration calls can be chained.
pub fn keep_instrumentation(&mut self, keep: bool) -> &mut Self {
    // The outer `Some` records that a choice was made; the inner one carries
    // the explicit value.
    let explicit_choice = Some(keep);
    self.keep_instrumentation = Some(explicit_choice);
    self
}

/// First half of [`Self::run`] which instruments the provided `wasm` and
/// produces a new wasm module which should be run by a runtime.
///
Expand All @@ -193,7 +218,7 @@ impl Wizer {
// Make sure we're given valid Wasm from the get go.
self.wasm_validate(&wasm)?;

let mut cx = parse::parse(wasm)?;
let mut cx = parse::parse(wasm, false)?;

// When wizening core modules directly some imports aren't supported,
// so check for those here.
Expand All @@ -220,6 +245,18 @@ impl Wizer {
Ok((cx, instrumented_wasm))
}

/// Validate and parse a Wasm module that has already been instrumented,
/// producing its [`ModuleContext`].
///
/// Intended for incremental snapshotting workflows, where the module was
/// instrumented earlier by [`Self::instrument`] or produced by
/// [`Self::snapshot`] with [`Self::keep_instrumentation`] enabled. The
/// returned context can be handed to [`Self::snapshot`] to take a new
/// snapshot.
///
/// # Errors
///
/// Fails if `wasm` does not pass validation or if parsing it as an
/// instrumented module fails.
pub fn parse_instrumented<'a>(&self, wasm: &'a [u8]) -> Result<ModuleContext<'a>> {
    self.wasm_validate(wasm)?;

    // `true` tells the parser to expect `__wizer_*` exports instead of
    // rejecting them.
    let expect_instrumentation = true;
    parse::parse(wasm, expect_instrumentation)
}

/// Second half of [`Self::run`] which takes the [`ModuleContext`] returned
/// by [`Self::instrument`] and the state of the `instance` after it has
/// possibly executed its initialization function.
Expand All @@ -233,12 +270,16 @@ impl Wizer {
) -> Result<Vec<u8>> {
// Parse rename spec.
let renames = FuncRenames::parse(&self.func_renames)?;

let snapshot = snapshot::snapshot(&cx, instance).await;
let rewritten_wasm = self.rewrite(&mut cx, &snapshot, &renames, true);
let rewritten_wasm = self.rewrite(
&mut cx,
&snapshot,
&renames,
true,
self.get_keep_instrumentation(),
);

self.debug_assert_valid_wasm(&rewritten_wasm);

Ok(rewritten_wasm)
}

Expand Down Expand Up @@ -369,6 +410,13 @@ impl Wizer {
None => DEFAULT_KEEP_INIT_FUNC,
}
}

/// Resolve the effective `keep_instrumentation` setting.
///
/// * never configured (`None`) -> `DEFAULT_KEEP_INSTRUMENTATION`
/// * configured without a value (`Some(None)`, presumably the bare CLI
///   flag — confirm against the clap attribute) -> `true`
/// * configured with a value (`Some(Some(v))`) -> `v`
fn get_keep_instrumentation(&self) -> bool {
    self.keep_instrumentation
        .map_or(DEFAULT_KEEP_INSTRUMENTATION, |explicit| {
            explicit.unwrap_or(true)
        })
}
}

/// Abstract ability to load state from a WebAssembly instance after it's been
Expand Down
61 changes: 55 additions & 6 deletions crates/wizer/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,25 @@ use wasmparser::{Encoding, Parser};
use wasmtime::{bail, error::Context as _};

/// Parse the given Wasm bytes into a `ModuleInfo` tree.
pub(crate) fn parse<'a>(full_wasm: &'a [u8]) -> wasmtime::Result<ModuleContext<'a>> {
parse_with(full_wasm, &mut Parser::new(0).parse_all(full_wasm))
///
/// The `instrumented` flag controls how `__wizer_*` exports are treated:
/// when it is `true` they must be present and are recorded in the
/// `defined_global_exports` and `defined_memory_exports` fields; when it is
/// `false` their presence is an error.
pub(crate) fn parse<'a>(
    full_wasm: &'a [u8],
    instrumented: bool,
) -> wasmtime::Result<ModuleContext<'a>> {
    // Start a fresh parser at offset 0 over the entire binary and delegate
    // the actual walk to `parse_with`.
    let mut payloads = Parser::new(0).parse_all(full_wasm);
    parse_with(full_wasm, &mut payloads, instrumented)
}

pub(crate) fn parse_with<'a>(
full_wasm: &'a [u8],
payloads: &mut impl Iterator<Item = wasmparser::Result<wasmparser::Payload<'a>>>,
instrumented: bool,
) -> wasmtime::Result<ModuleContext<'a>> {
log::debug!("Parsing the input Wasm");

Expand All @@ -36,7 +48,7 @@ pub(crate) fn parse_with<'a>(
TableSection(tables) => table_section(&mut module, tables)?,
MemorySection(mems) => memory_section(&mut module, mems)?,
GlobalSection(globals) => global_section(&mut module, globals)?,
ExportSection(exports) => export_section(&mut module, exports)?,
ExportSection(exports) => export_section(&mut module, exports, instrumented)?,
End { .. } => break,
_ => {}
}
Expand Down Expand Up @@ -107,14 +119,33 @@ fn global_section<'a>(
fn export_section<'a>(
module: &mut ModuleContext<'a>,
exports: wasmparser::ExportSectionReader<'a>,
instrumented: bool,
) -> wasmtime::Result<()> {
let mut has_instrumentation: bool = false;
let mut defined_global_exports = Vec::new();
let mut defined_memory_exports = Vec::new();

for export in exports {
let export = export?;

if export.name.starts_with("__wizer_") {
wasmtime::bail!(
"input Wasm module already exports entities named with the `__wizer_*` prefix"
);
if !instrumented {
wasmtime::bail!(
"input Wasm module already exports entities named with the `__wizer_*` prefix"
);
}

has_instrumentation = true;
if export.name.starts_with("__wizer_global_")
&& export.kind == wasmparser::ExternalKind::Global
{
defined_global_exports.push((export.index, export.name.to_string()));
} else if export.name.starts_with("__wizer_memory_")
&& export.kind == wasmparser::ExternalKind::Memory
{
defined_memory_exports.push(export.name.to_string());
}
continue;
}

match export.kind {
Expand All @@ -128,5 +159,23 @@ fn export_section<'a>(
}
}
}

if instrumented {
if !has_instrumentation {
wasmtime::bail!("input Wasm module is not instrumented")
}
// Sort to match the order expected by defined_globals() and
// defined_memories().
defined_global_exports.sort_by_key(|(idx, _)| *idx);
defined_memory_exports.sort_by_key(|name| {
name.strip_prefix("__wizer_memory_")
.and_then(|n| n.parse::<u32>().ok())
.unwrap_or(0)
});

module.defined_global_exports = Some(defined_global_exports);
module.defined_memory_exports = Some(defined_memory_exports);
}

Ok(())
}
21 changes: 21 additions & 0 deletions crates/wizer/src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ impl Wizer {
/// Given the initialized snapshot, rewrite the Wasm so that it is already
/// initialized.
///
/// When `preserve_instrumentation` is true, the `__wizer_*` exports are
/// preserved so the output can be snapshotted again.
pub(crate) fn rewrite(
&self,
module: &mut ModuleContext<'_>,
snapshot: &Snapshot,
renames: &FuncRenames,
remove_wasi_initialize: bool,
preserve_instrumentation: bool,
) -> Vec<u8> {
log::debug!("Rewriting input Wasm to pre-initialized state");

Expand Down Expand Up @@ -154,6 +157,24 @@ impl Wizer {
let kind = RoundtripReencoder.export_kind(export.kind).unwrap();
exports.export(field, kind, export.index);
}

// Re-add __wizer_* exports so the output remains
// instrumentable for future snapshots.
if preserve_instrumentation {
if let Some(ref global_exports) = module.defined_global_exports {
for (idx, name) in global_exports {
exports.export(name, wasm_encoder::ExportKind::Global, *idx);
}
}
if let Some(ref memory_exports) = module.defined_memory_exports {
for ((mem_idx, _), name) in
module.defined_memories().zip(memory_exports)
{
exports.export(name, wasm_encoder::ExportKind::Memory, mem_idx);
}
}
}

encoder.section(&exports);
}

Expand Down
78 changes: 77 additions & 1 deletion crates/wizer/tests/all/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use wasmtime::{
error::Context as _,
};
use wasmtime_wasi::{WasiCtxBuilder, p1};
use wasmtime_wizer::Wizer;
use wasmtime_wizer::{WasmtimeWizer, Wizer};
use wat::parse_str as wat_to_wasm;

async fn run_wat(args: &[wasmtime::Val], expected: i32, wat: &str) -> Result<()> {
Expand Down Expand Up @@ -1018,3 +1018,79 @@ async fn memory64() -> Result<()> {
let wizer = get_wizer();
wizen_and_run_wasm(&[], 10, &wasm, wizer).await
}

/// Checks that a snapshot taken with `keep_instrumentation(true)` can itself
/// be re-parsed and snapshotted again, carrying state changes made after the
/// first snapshot into the second one.
#[tokio::test]
async fn keep_instrumentation_incremental_snapshot() -> Result<()> {
    let _ = env_logger::try_init();

    let wat = r#"
(module
(global $g (mut i32) i32.const 99)
(func (export "wizer-initialize")
i32.const 0
global.set $g)
(func (export "get") (result i32)
global.get $g)
(func (export "set") (param i32)
local.get 0
global.set $g))
"#;

    let wasm = wat_to_wasm(wat)?;

    let mut wizer = Wizer::new();
    wizer.keep_instrumentation(true);

    // Phase 1: instrument the module, run its initializer, and snapshot.
    let (cx, instrumented_wasm) = wizer.instrument(&wasm)?;

    let mut s = store()?;
    let module = Module::new(s.engine(), &instrumented_wasm)?;
    let instance = instantiate(&mut s, &module).await?;

    // Propagate a missing/mistyped export as an error, like the other
    // `get_typed_func` lookups in this test, rather than panicking.
    let init = instance.get_typed_func::<(), ()>(&mut s, "wizer-initialize")?;
    init.call_async(&mut s, ()).await?;

    let snapshot1 = wizer
        .snapshot(
            cx,
            &mut WasmtimeWizer {
                store: &mut s,
                instance,
            },
        )
        .await?;

    // Phase 2: instantiate the first snapshot and confirm the initializer's
    // effect was baked in (the global starts at 99 and init sets it to 0).
    let mut s = store()?;
    let module = Module::new(s.engine(), &snapshot1)?;
    let instance = instantiate(&mut s, &module).await?;

    let get = instance.get_typed_func::<(), i32>(&mut s, "get")?;
    let val = get.call_async(&mut s, ()).await?;
    assert_eq!(val, 0, "after first snapshot, global should be 0");

    // Mutate state after the first snapshot, then snapshot again using a
    // context re-parsed from the still-instrumented output.
    let set = instance.get_typed_func::<(i32,), ()>(&mut s, "set")?;
    set.call_async(&mut s, (42,)).await?;

    let cx2 = wizer.parse_instrumented(&snapshot1)?;
    let snapshot2 = wizer
        .snapshot(
            cx2,
            &mut WasmtimeWizer {
                store: &mut s,
                instance,
            },
        )
        .await?;

    // Phase 3: the second snapshot must reflect the post-snapshot mutation.
    let mut s = store()?;
    let module = Module::new(s.engine(), &snapshot2)?;
    let instance = instantiate(&mut s, &module).await?;

    let get = instance.get_typed_func::<(), i32>(&mut s, "get")?;
    let val = get.call_async(&mut s, ()).await?;
    assert_eq!(val, 42, "after second snapshot, global should be 42");

    Ok(())
}