diff --git a/Cargo.lock b/Cargo.lock index 4d5ad6b65..7b062d31a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1428,12 +1428,14 @@ dependencies = [ "anyhow", "clap", "heck", + "indexmap", "test-helpers", "wasm-encoder 0.249.0", "wasm-metadata 0.249.0", "wit-bindgen-c", "wit-bindgen-core", "wit-component", + "wit-parser", ] [[package]] diff --git a/crates/cpp/Cargo.toml b/crates/cpp/Cargo.toml index 04a7603f9..d74076ba2 100644 --- a/crates/cpp/Cargo.toml +++ b/crates/cpp/Cargo.toml @@ -26,6 +26,8 @@ wit-bindgen-c = { workspace = true } anyhow = { workspace = true } heck = { workspace = true } clap = { workspace = true, optional = true } +indexmap = { workspace = true } +wit-parser = { workspace = true } [dev-dependencies] test-helpers = { path = '../test-helpers' } diff --git a/crates/cpp/src/lib.rs b/crates/cpp/src/lib.rs index 0348ff7f9..82bcf0764 100644 --- a/crates/cpp/src/lib.rs +++ b/crates/cpp/src/lib.rs @@ -1,5 +1,6 @@ use anyhow::bail; use heck::{ToPascalCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase}; +use indexmap::{IndexMap, IndexSet}; use std::{ collections::{HashMap, HashSet}, fmt::{self, Display, Write as FmtWrite}, @@ -20,6 +21,7 @@ use wit_bindgen_core::{ WorldKey, }, }; +use wit_parser::TypeIdVisitor; // mod wamr; mod symbol_name; @@ -878,22 +880,12 @@ impl CppInterfaceGenerator<'_> { fn types(&mut self, iface: InterfaceId) { let iface_data = &self.resolve().interfaces[iface]; - // First pass: emit forward declarations for all resources - // This ensures resources can reference each other in method signatures - for (name, id) in iface_data.types.iter() { - let ty = &self.resolve().types[*id]; - if matches!(&ty.kind, TypeDefKind::Resource) { - let pascal = name.to_upper_camel_case(); - let guest_import = self.r#gen.imported_interfaces.contains(&iface); - let namespc = namespace(self.resolve, &ty.owner, !guest_import, &self.r#gen.opts); - self.r#gen.h_src.change_namespace(&namespc); - uwriteln!(self.r#gen.h_src.src, "class {pascal};"); - } - } - - // Second pass: emit full type definitions - for (name, id) in iface_data.types.iter() { - self.define_type(name, *id); + // Here we sort the types topologically before emitting code for them, + // taking into consideration the parameter and return types of functions + // associated with resource types. This ensures each type is declared + // before any uses of it. + for (name, id) in sort_types(self.resolve, &iface_data.types) { + self.define_type(name, id); } } @@ -3629,3 +3621,79 @@ fn is_special_method(func: &Function) -> SpecialMethod { SpecialMethod::None } } + +/// Sort the specified types topologically, taking the parameter and result +/// types of resource functions into consideration. +fn sort_types<'a>( + resolve: &Resolve, + types: &'a IndexMap, +) -> IndexMap<&'a str, TypeId> { + struct Visitor<'a> { + resolve: &'a Resolve, + sorted: IndexSet, + visited: HashSet, + } + + impl TypeIdVisitor for Visitor<'_> { + fn before_visit_type_id(&mut self, id: TypeId) -> bool { + let ty = &self.resolve.types[id]; + if let TypeDefKind::Resource = &ty.kind { + // Here we avoid infinite recursion by remembering if we've seen + // this type already. + if self.visited.contains(&id) { + false + } else { + self.visited.insert(id); + + // TypeIdVisitor does not consider the parameter and return + // types of resource associated functions to be transitive + // types, so we need to handle that ourselves here. + if let TypeOwner::Interface(interface) = ty.owner { + for function in self.resolve.interfaces[interface].functions.values() { + if let Some(resource) = function.kind.resource() + && resource == id + { + for parameter in &function.params { + self.visit_type(self.resolve, ¶meter.ty); + } + + if let Some(ty) = function.result { + self.visit_type(self.resolve, &ty); + } + } + } + } + + true + } + } else { + true + } + } + + fn after_visit_type_id(&mut self, id: TypeId) { + self.sorted.insert(id); + } + } + + let mut visitor = Visitor { + resolve, + sorted: Default::default(), + visited: Default::default(), + }; + + for &id in types.values() { + visitor.visit_type_id(resolve, id); + } + + let names = types + .iter() + .map(|(k, v)| (*v, k.as_str())) + .collect::>(); + + visitor + .sorted + .into_iter() + .filter_map(|v| names.get(&v).map(|k| (*k, v))) + .collect() +} diff --git a/tests/codegen/issue1615-declaration-order.wit b/tests/codegen/issue1615-declaration-order.wit new file mode 100644 index 000000000..84a6aec06 --- /dev/null +++ b/tests/codegen/issue1615-declaration-order.wit @@ -0,0 +1,40 @@ +package test:test; + +interface sqlite { + resource connection { + open: static func(database: string) -> result; + execute: func(statement: string, parameters: list) -> result; + last-insert-rowid: func() -> s64; + changes: func() -> u64; + } + + variant error { + no-such-database, + access-denied, + invalid-connection, + database-full, + io(string) + } + + record query-result { + columns: list, + rows: list, + } + + record row-result { + values: list + } + + variant value { + integer(s64), + real(f64), + text(string), + blob(list), + null + } +} + +world test { + import sqlite; + export foo: func(); +}