diff --git a/.agents/docs-and-formatting.md b/.agents/docs-and-formatting.md
index 753c2ac815..a8b7e54250 100644
--- a/.agents/docs-and-formatting.md
+++ b/.agents/docs-and-formatting.md
@@ -45,10 +45,9 @@ Load this file when changing documentation, public APIs, protocol specs, benchma
 - Python code, including `compiler/`, `benchmarks/`, `integration_tests/`, and `python/`: `python -m ruff format ` and `python -m ruff check --fix `
-- JavaScript/TypeScript under `javascript/`: use the package's ESLint-owned formatting path
-  (`npm run lint -- --fix` when fixing style, `npm run lint -- --quiet` when checking). Do not run
-  Prettier on JavaScript or TypeScript files unless that package has an explicit Prettier config or
-  script; otherwise it creates unrelated formatting churn.
+- JavaScript/TypeScript under `javascript/`: use the package formatter scripts
+  (`npm run format` when fixing style, `npm run format-check` when checking). They run Prettier
+  before ESLint; do not use raw ESLint as the formatting gate.
 - Repo-wide format and lint sweep: `bash ci/format.sh --all`
 
 When code changes touch `compiler/` or `benchmarks/`, format those changed source files with the
diff --git a/.agents/languages/cpp.md b/.agents/languages/cpp.md
index 8a28fe0d03..d1047934b2 100644
--- a/.agents/languages/cpp.md
+++ b/.agents/languages/cpp.md
@@ -17,6 +17,29 @@ Load this file when changing `cpp/`, Cython build plumbing, or C++ xlang behavio
 - Do not redesign alias-based or low-level public type shapes to add convenience methods unless the user explicitly asks for that API change.
 - For cross-language feature ports, match protocol behavior but use idiomatic C++ ownership and layering instead of mirroring Java structure literally.
 - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization.
+- Root deserialization graph budgets are owned by `ReadContext` and initialized by the root
+  `Fory::deserialize` overload. Keep `max_graph_memory_bytes` as a fixed-default graph limit:
+  unset/default is `128 MiB`, positive explicit values override it, and explicit non-positive
+  values are invalid at config creation. Byte and stream roots use the same
+  configured/default budget behavior. Root `Fory` overloads reset the budget only; they must not
+  pre-reserve root type or root self bytes.
+  Do not mirror the configured max into a second active-limit field; use config plus mutable
+  remaining budget.
+  Reserve estimated shallow graph-owner memory before allocation while preserving existing
+  byte-availability checks and their non-empty metadata ordering. `ReadContext` may expose only raw
+  byte reservation; collection, map, array, struct, and object
+  formulas belong in serializer owners. Skip dedicated string, binary, primitive scalar, primitive
+  vector, and primitive dense-array leaf owners; `std::vector<bool>` charges rounded packed-bit
+  storage. General `std::vector<T>` for non-primitive `T` is inline value storage and must be
+  reserved by the vector owner.
+- C++ graph budget formulas must be portable lower-bound estimates, not STL heap-layout accounting.
+  Generic collection-like containers reserve `count_or_capacity * sizeof(value_type)`, map-like
+  containers reserve `count * (sizeof(key_type) + sizeof(mapped_type))`, and set-like containers
+  reserve `count * sizeof(key_type)`. Root struct/product owners and smart-pointer/box allocation
+  owners reserve shallow self storage exactly once; nested value serializers reserve only dynamic
+  storage they allocate, not their own inline self storage again. Do not add guessed
+  node/header/debug-STL overhead, red-black-tree fields, allocator probing, object-layout
+  inspection, generic per-entry pointer overhead, or unordered bucket-table guesses.
 
 ## Key Paths
 
diff --git a/.agents/languages/csharp.md b/.agents/languages/csharp.md
index 098f9a50fe..09924e7637 100644
--- a/.agents/languages/csharp.md
+++ b/.agents/languages/csharp.md
@@ -12,6 +12,26 @@ Load this file when changing `csharp/` or C# xlang behavior.
 - Generated C# gRPC service companions are compiler-owned files that depend on application-provided gRPC packages, not `csharp/src/Fory`. Keep gRPC package references out of the Fory runtime package.
 - C# generated schema modules are source-file owners. Service companions must use that module's `ThreadSafeFory` and must not introduce namespace-owned aliases or duplicate serializer registration paths.
 - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization.
+- Root deserialization graph memory budget state belongs to `ReadContext`. C# public roots are
+  memory-backed today, but the graph budget uses the same fixed default for every root shape.
+  Root APIs reset the budget only; they must not pre-reserve root type or root self bytes.
+  Do not mirror the configured max into a second active-limit field; use config plus mutable
+  remaining budget.
+  `ReadContext` may expose only raw byte reservation; concrete serializers and generated
+  serializers must compute list, array, map, struct, and object byte formulas before calling it.
+- `ReadContext` must not expose ref-publication pause/resume APIs or any non-budget owner
+  controls. Concrete serializers and generated serializers own ref publication timing directly,
+  and must not publish temporary owners.
+- For C# graph budget formulas, distinguish inline value storage from reference storage: use cheap
+  value-type size for `List<T>`/`T[]` value paths and the 4-byte reference fallback for reference
+  paths. Class/reference serializers reserve their own shallow self cost plus field storage when
+  materialized; struct/value serializers reserve self storage only on standalone serializer,
+  dynamic/boxed, or root materialization entries because field, list, array, map, set, and box
+  holders reserve the inline storage they own. Maps reserve key plus value storage, linked/hash/tree
+  conversions must not add guessed node or entry overhead, and independently materialized
+  collection/map/array owners reserve nonzero shallow self cost.
+  Dedicated string, binary, primitive scalar, and primitive dense-array serializers stay skipped and
+  rely on byte availability checks.
 - When extending C# tests from Java references, prioritize xlang spec behavior and the public C# contract before adding complex Java-specific parity cases.
 
 ## Commands
diff --git a/.agents/languages/dart.md b/.agents/languages/dart.md
index d4f6bebbbc..b6e96751ff 100644
--- a/.agents/languages/dart.md
+++ b/.agents/languages/dart.md
@@ -14,6 +14,21 @@ Load this file when changing `dart/`.
 - Keep root numeric wrapper defaults separate from generated field metadata. Root wrapper resolution belongs in the builtin resolver, while annotations and generated metadata choose fixed, tagged, or declared-field encodings.
 - Dart 64-bit carriers are optimized for each platform. Do not replace native extension-type wrappers with allocation-heavy classes or route web/native hot paths through `BigInt` unless the user approves a representation change.
 - In `Buffer`, cursor, serializer, and generated-code hot paths, prefer direct byte/local integer operations and conditional import/export files over callbacks, records, holder objects, wrapper round-trips, or runtime platform branches.
+- Root deserialization graph memory budgets are owned by `ReadContext`;
+  `maxGraphMemoryBytes` defaults to fixed `128 MiB`, positive explicit values override it, and
+  explicit non-positive values are invalid at config creation. Do not derive the budget from
+  `buffer.readableBytes`. `ReadContext` may expose only raw byte reservation; list, set, map, array,
+  struct, and object formulas
+  belong in serializer owners. Reserve Dart list/set/object-array reference
+  slots plus nonzero owner self cost, map key/value slots plus nonzero owner
+  self cost, compatible array-to-list materialization, and generated object reads before
+  allocation. Compatible list-to-typed-array reads skip the dense primitive-array leaf owner while
+  preserving byte checks. Object/struct
+  owners reserve nonzero shallow self memory plus shallow field storage. Skip
+  only dedicated string, binary, primitive scalar, `BoolList`, and typed-array
+  dense owner paths with byte checks. Do not add stream bytes-read accounting,
+  per-element accounting, extra hot-path allocations, or stale narrower-scope
+  formulas.
 - Do not add parallel header-low/header-high slot caches or multi-slot recent caches in TypeMeta hot paths to chase benchmark gaps. Header-cache hits must use the concrete checked cache owner directly; if a hit hint is needed, cache one TypeInfo/TypeMeta object and compare the validated header identity on that object, not separate low/high header fields or benchmark-pattern state.
 - If Dart TypeMeta cache ownership changes, keep the invariant in a source comment near the hit path: a checked metadata-cache hit skips the body and must not grow low-bit sentinels, accepted-header fields, parallel header slots, or benchmark-pattern state.
 - Dart expected-type TypeDef reads should compare the expected `TypeInfo` object's cached local TypeDef header before consulting the parsed-metadata map. A match is a direct local-schema hit: skip the remote body, add the expected type to the per-read shared type table, and do not publish to `ParsedTypeMetaCache`, record a remote schema version, or parse/hash the body.
diff --git a/.agents/languages/go.md b/.agents/languages/go.md
index 94d47fe94c..bad052077a 100644
--- a/.agents/languages/go.md
+++ b/.agents/languages/go.md
@@ -7,6 +7,23 @@ Load this file when changing `go/fory/` or Go xlang behavior.
 - Run Go commands from within `go/fory/`.
 - Changes under `go/` must pass formatting and tests.
 - The Go implementation focuses on reflection-based and codegen-based serialization.
+- Root deserialization graph memory budgets are owned by `ReadContext`.
+  `WithMaxGraphMemoryBytes` uses a fixed `128 MiB` default; positive explicit
+  values override it, and explicit non-positive values are invalid at config creation.
+  Byte-slice and stream roots use the same
+  configured/default budget behavior. Root APIs reset the budget only; they must not pre-reserve
+  root type or root self bytes. Do not mirror the configured max into a second active-limit field;
+  root setup should update only the mutable remaining budget. `ReadContext` may expose only raw byte
+  reservation; slice, map, array, struct, and object
+  formulas belong in handwritten or generated serializer owners. Reserve Go
+  slices as `len * elemBytes`, maps as `len * (keyBytes + valueBytes)`,
+  map-backed sets, and LIST-encoded inline/value slices in the owner that
+  allocates that storage. Struct root materialization paths, pointer allocations, and generated
+  allocation entries reserve shallow value storage exactly once; nested inline
+  struct serializers do not charge their own self storage again. Fixed arrays
+  are caller-owned unless a read path materializes a temporary owner. Skip
+  dedicated string, binary, BufferObject, primitive scalar, primitive ARRAY
+  slice, and primitive array owners with byte checks.
 - Set `FORY_PANIC_ON_ERROR=1` when debugging a failing Go test so you get the full call stack.
 - Do not set `FORY_PANIC_ON_ERROR=1` when running the full Go test suite, because some tests assert on error contents.
 
diff --git a/.agents/languages/java.md b/.agents/languages/java.md
index 41b19b206d..7021dc18be 100644
--- a/.agents/languages/java.md
+++ b/.agents/languages/java.md
@@ -14,6 +14,25 @@ Load this file when changing anything under `java/` or when Java drives a cross-
   values; use qualified names only when a real name conflict requires it.
 - If you run temporary tests with `java -cp`, run `mvn -T16 install -DskipTests` first so local Fory jars are current.
 - `WriteContext`, `ReadContext`, and `CopyContext` must stay explicit. Do not reintroduce `ThreadLocal` or ambient runtime-context patterns.
+- Java root deserialization graph memory budgeting belongs to `ReadContext`
+  and is initialized by `Fory` root APIs. Public config is `maxGraphMemoryBytes`
+  with fixed `128 MiB` default. Positive explicit values override the default;
+  explicit non-positive values are invalid and must be rejected at config creation.
+  Byte-array, memory-buffer, and stream roots use the same configured/default
+  budget behavior. Root APIs reset the budget only; they must not pre-reserve
+  root type or root self bytes. Do not mirror the configured max into a second
+  active-limit field; use config plus mutable remaining budget. `ReadContext`
+  may expose only raw byte reservation;
+  collection, map, array, struct, and object formulas belong in the concrete
+  serializer or generated serializer owner. Java collection, map, and
+  object-array owners reserve nonzero shallow self cost plus reference storage;
+  referenced object serializers reserve their own nonzero shallow self memory
+  plus shallow field storage when materialized.
+  Reference fields use the 4-byte fallback when the JVM reference size is not
+  queried cheaply; primitive fields use their encoded storage width. Preserve
+  existing `checkReadableBytes` guards before backing allocation or capacity
+  reservation. Do not add nested serializer-path `try/finally`, per-element
+  work, dynamic stream bytes-read accounting, or stale narrower-scope formulas.
 - Generated serializers must not retain runtime context fields. `Fory` should stay a root-operation facade rather than accumulating serializer or convenience state.
 - When the serializer class and constructor shape are known at the call site, prefer direct constructor lambdas or direct instantiation over reflective `Serializers.newSerializer(...)`.
 - For GraalVM, use `fory codegen` to generate serializers when building native images. Do not add reflection configuration except for JDK `proxy`.
diff --git a/.agents/languages/javascript.md b/.agents/languages/javascript.md
index e6d62fa494..9bf3f62b3f 100644
--- a/.agents/languages/javascript.md
+++ b/.agents/languages/javascript.md
@@ -14,6 +14,18 @@ Load this file when changing `javascript/`.
 - Runtime value carriers such as decimal or reduced-precision numeric types belong under the core `types/` ownership boundary, with imports, exports, and codegen externals updated together.
 - Keep `TypeInfo` as schema metadata. Compatibility-sensitive decisions belong on `TypeResolver` or explicit operations, not as retained resolver state on metadata objects.
 - Normalize optional boolean config values at config construction; do not carry `null` through runtime paths when it means `false`.
+- JavaScript root deserialization graph memory budgeting belongs to `ReadContext`.
+  `maxGraphMemoryBytes` uses a fixed `128 MiB` default, positive explicit limits override it, and
+  explicit non-positive values are invalid at config creation. Do not derive the budget from the
+  `Uint8Array` root length. `ReadContext` may expose only raw
+  byte reservation; generated and dynamic
+  list/set/map/array/struct/object readers must reserve before allocation while preserving existing
+  byte checks. Lists/sets/object arrays reserve nonzero owner self cost plus 4-byte reference slots,
+  maps reserve nonzero owner self cost plus key/value reference storage, object/struct readers
+  reserve nonzero shallow self memory plus shallow field storage, compatible array-to-list reads
+  reserve target list materialization, and compatible list-to-typed-array reads skip the dense
+  primitive-array leaf owner while preserving byte checks. Keep dedicated string, binary, primitive
+  scalar, and dense typed-array leaf owners out of this budget.
 - Regenerated compatible read serializers are remote-schema-specific. After classification marks a field as direct, compatible scalar, or skip, generated JavaScript should emit straight-line remote-field-order code. Do not add an outer matched-id switch unless the current regenerated shape cannot preserve those semantics.
 - Compatible scalar codegen must decide the exact remote/local scalar pair before emitting source. Generate the concrete `reader.readXxx()` call plus inline trivial conversions such as boolean-to-string or numeric widening, and keep helpers only for semantic validation such as range checks, exactness checks, decimal parsing/formatting, and string-to-bool. Do not call a generic hot-path converter that redispatches on `remoteTypeId`, `localTypeId`, field descriptors, or field names.
 - Compatible scalar conversion is immediate-field-only. Recursive schema comparison for collection elements, array elements, map keys, and map values must reject scalar mismatches instead of applying the top-level scalar conversion matrix.
diff --git a/.agents/languages/python.md b/.agents/languages/python.md
index 7365632c37..5e8005c9b2 100644
--- a/.agents/languages/python.md
+++ b/.agents/languages/python.md
@@ -13,6 +13,20 @@ Load this file when changing `python/`, Cython serialization, or Python xlang be
 - Cython mode owns the hot runtime path. Do not duplicate core runtime types between Python and Cython, tunnel Python facade methods into hidden Cython internals, or keep dead shims unless the user explicitly needs a compatibility module path.
 - Use explicit Cython fields and methods for fixed hot-path shapes. Avoid `__getattr__`, generic `object` fields, public bridge internals, or `Fory` backreferences where ownership can stay explicit.
 - Keep Python and Cython context/ref-tracking branch conditions and stack mutations semantically aligned unless a documented intentional difference exists.
+- Root deserialization graph memory budgets are owned by pure-Python and Cython `ReadContext`.
+  Keep `max_graph_memory_bytes` public on `pyfory.Fory`/`Config`; the default effective limit is
+  fixed `128 MiB`, positive explicit values override it, and explicit non-positive values are
+  invalid at config creation. Byte and stream roots use the same
+  configured/default budget behavior. Do not mirror the configured max into a
+  second active-limit field; keep one configured max plus mutable remaining
+  budget. `ReadContext` may expose only raw
+  byte reservation; collection, dict, array, struct, and object
+  formulas belong in the pure-Python or Cython serializer owner. Lists, tuples, sets, and
+  object-dtype ndarray item storage reserve nonzero owner self cost plus `count * PyObject*`; dicts
+  reserve nonzero owner self cost plus `entryCount * 2 * PyObject*`. Python object owners reserve a
+  nonzero shallow self cost plus shallow field/reference storage. Keep string, bytes, primitive
+  scalar, `array.array`, primitive dense array, and primitive ndarray owners skipped, and preserve
+  byte-availability checks after budget reservation.
 - Public value constructors should accept normal Python values. Raw-bit, raw-buffer, and memoryview entry points should be explicit low-level APIs, and packed carriers should expose the buffer protocol from the actual storage owner when appropriate.
 - When debugging runtime or benchmark behavior, install the local package into the exact interpreter under test instead of relying on mixed `PYTHONPATH` state.
 - For wheel or extension pipeline changes, derive extension-module paths from current build targets, packaging config, or wheel payload discovery rather than historical module names.
diff --git a/.agents/languages/rust.md b/.agents/languages/rust.md
index ffe5648330..ea57f19624 100644
--- a/.agents/languages/rust.md
+++ b/.agents/languages/rust.md
@@ -18,6 +18,27 @@ Load this file when changing `rust/` or Rust xlang behavior.
 - If breakage is explicitly acceptable during a Rust module refactor, rewire macros, tests, and sibling crates directly to the new boundaries instead of adding compatibility re-exports.
 - For panic-safety in hot paths, preserve TLS context reuse. Add scoped guards or owned fallbacks rather than per-call context allocation, and reset reused contexts at entry and successful exit.
 - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Keep recursive matched-field shape classification owned by `fory-core/src/meta/type_meta.rs`; collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization.
+- Root deserialization graph memory budget state belongs to `ReadContext` and is initialized by the
+  root `Fory` read methods before the header is consumed. Use the fixed `128 MiB` default unless a
+  positive explicit value overrides it; explicit non-positive values are invalid at config
+  creation. Root `Fory` read methods reset the budget only; they must not pre-reserve root type or
+  root self bytes. Do not derive the budget from root input size or add dynamic bytes-read
+  accounting.
+  Do not mirror the configured max into a second active-limit field; keep one configured max plus
+  mutable remaining budget.
+  `ReadContext` may expose only raw byte reservation; `Vec<T>`,
+  collection, map, array, struct, object, and derive codec formulas belong in their serializer
+  owners.
+- Rust `Vec<T>` stores inline element storage, so general LIST paths reserve
+  `len * size_of::<T>()`, including `Vec` and `Vec`. Maps reserve
+  `len * (size_of::<K>() + size_of::<V>())`. Root/product/box owners reserve shallow value storage
+  exactly once; nested inline value serializers do not charge their own self storage again.
+  Dedicated primitive dense ARRAY `Vec<T>` readers, strings, binary, primitive scalars, and
+  primitive fixed-array owners stay skipped and keep their byte checks.
+- Direct `Serializer` collection/map paths and derive `Codec` collection/map paths are separate
+  allocation owners. Keep reservations in both before `Vec::with_capacity`,
+  `HashMap::with_capacity`, or collection materialization. Empty non-leaf owners that allocate an
+  independent owner object or storage reserve nonzero shallow self cost.
 
 ## Key Paths
 
diff --git a/.agents/languages/scala.md b/.agents/languages/scala.md
index 03f22b70b6..cdc9459bed 100644
--- a/.agents/languages/scala.md
+++ b/.agents/languages/scala.md
@@ -16,6 +16,8 @@ sbt compile
 # Run tests
 sbt test
 
-# Format code
-sbt scalafmt
+# Repo-owned formatter pass for changed files
+cd .. && ci/format.sh
 ```
+
+The Scala module does not currently wire a `scalafmt` sbt command.
diff --git a/.agents/languages/swift.md b/.agents/languages/swift.md
index ec493ea1ac..2b73998a52 100644
--- a/.agents/languages/swift.md
+++ b/.agents/languages/swift.md
@@ -6,6 +6,8 @@ Load this file when changing `swift/` or Swift xlang behavior.
 
 - Run Swift commands from within `swift/`.
 - Changes under `swift/` must pass lint and tests.
+- SwiftLint is mandatory but not sufficient for readability; manually clean touched Swift source
+  formatting before committing because this repo does not run a Swift formatter in CI.
 - Swift code must compile without compiler warnings. Treat warnings as blockers, including warnings in generated Swift code.
 - Swift lint uses `swift/.swiftlint.yml`.
 - Swift formatting uses `swift/.swift-format`; do not rely on SwiftLint for indentation or source
@@ -15,6 +17,21 @@ Load this file when changing `swift/` or Swift xlang behavior.
 - Preserve distinct temporal semantics. Timestamp values and day-only local dates should have protocol-accurate helper names and no stale aliases after a refactor.
 - When temporal or public-type refactors touch generated Swift code, sweep message fields, union payloads, macros, xlang harnesses, and integration fixtures together.
 - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization.
+- Root deserialization graph memory budget state belongs to `ReadContext`. Swift public roots are
+  `Data` and `ByteBuffer`, and both use the same fixed default graph budget; do not add stream
+  bytes-read accounting or serializer-local budget state. Root APIs reset the budget only; they must
+  not pre-reserve root type or root self bytes. `ReadContext` may expose only raw byte reservation;
+  array, set, map, struct, and object formulas belong in serializer and field-codec owners.
+- For Swift graph budget formulas, distinguish inline/value storage from reference storage: use
+  `MemoryLayout<T>.stride` for value arrays/lists/sets/maps and the 4-byte reference fallback for
+  `Serializer.isRefType` / `FieldCodec.isRefType` paths. Class/reference paths reserve their own
+  shallow self cost plus field storage when materialized; value serializers reserve self storage
+  only on standalone serializer, generated, or root materialization entries because field, array,
+  set, map, and box holders reserve inline storage exactly once. Independently materialized
+  collection/map/array owners reserve nonzero shallow self cost plus backing/reference/inline
+  storage. Dedicated `String`, `Data`/binary,
+  primitive scalar, and primitive packed-array owners stay skipped, except compatible
+  packed-array-to-list reads must reserve the target list materialization before allocation.
 
 ## Commands
 
diff --git a/.agents/repo-reference.md b/.agents/repo-reference.md
index 4d420b1f05..f110bf9bb3 100644
--- a/.agents/repo-reference.md
+++ b/.agents/repo-reference.md
@@ -80,6 +80,16 @@ Apache Fory is a multi-language serialization framework with multiple wire forma
   copy/decompression and before field-list allocation, and never add cache-hit or generated-reader
   hot-path work for them.
 
+## Root Graph Memory Budget Ownership
+
+Root graph memory budgeting is a read-state accounting feature only. Read context or equivalent
+read state may expose raw byte reservation and, when a runtime cannot reasonably avoid it,
+root-operation budget setup/reset. Root facades may reset the per-operation budget, but must not
+pre-reserve root type or root self bytes. It must not grow semantic APIs for collection, map, array,
+struct, object, temporary-owner, serializer-owner, conversion, counted-allocation, or
+ref-publication control. Concrete serializers and generated serializers own allocation formulas,
+overflow checks, root value storage reservation, and reference publication timing.
+
 ## Runtime Map
 
 ### Java
diff --git a/.agents/skills/fory-performance-optimization/SKILL.md b/.agents/skills/fory-performance-optimization/SKILL.md
index 2bbe17f801..a66955da8f 100644
--- a/.agents/skills/fory-performance-optimization/SKILL.md
+++ b/.agents/skills/fory-performance-optimization/SKILL.md
@@ -15,6 +15,8 @@ Deliver measurable performance improvements in Apache Fory without protocol drif
 - Profile before changing hot code.
 - Change one bottleneck at a time.
- Benchmark sequentially on the same machine state (one benchmark process at a time). +- Compare old/new benchmark results case-by-case in adjacent pairs: run one case on `apache/main`, + then immediately run that same case on the current branch before moving to the next case. - Keep only measured wins or explicitly requested architecture cleanups. - Revert speculative changes that do not pay off. - Align with reference runtimes (usually C++ first, then Rust/Java) when behavior and ownership models differ. @@ -73,6 +75,8 @@ Deliver measurable performance improvements in Apache Fory without protocol drif 7. Benchmark and compare. - Run targeted benchmark at least twice sequentially. +- Pair each baseline case with the matching current-branch case before starting another case, so + both measurements see closer machine load conditions. - Use longer duration when signal is noisy. - Run one short full-suite sanity benchmark to catch collateral regressions. diff --git a/.agents/skills/fory-performance-optimization/references/language-command-matrix.md b/.agents/skills/fory-performance-optimization/references/language-command-matrix.md index abd1fb8a6a..cf22665b8b 100644 --- a/.agents/skills/fory-performance-optimization/references/language-command-matrix.md +++ b/.agents/skills/fory-performance-optimization/references/language-command-matrix.md @@ -79,7 +79,8 @@ Canonical runtime-specific rules now live under `../../../languages/*.md` and `. - Build: `sbt compile` - Tests: `sbt test` -- Format: `sbt scalafmt` +- Format: run the repo formatter from the repository root with `ci/format.sh`; the Scala module + does not currently wire a `scalafmt` sbt command. ## Cross-Language Xlang Verification diff --git a/AGENTS.md b/AGENTS.md index 5cd2581d5c..421ea8c81f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -32,7 +32,8 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Respect ownership. 
Keep logic, state, and helpers in their natural owner, and do not move serializer-local, context-local, runtime-type-local, or protocol-local problems into global utilities. - Check the spec before implementation. For wire behavior and xlang mapping, use the specs as the source of truth and never copy one runtime's bug into another runtime just to make tests pass. - Do not make assumptions about runtime behavior, ownership, registration, metadata construction, protocol semantics, or test coverage. Read the current code, owning docs/specs, and relevant tests before making a design judgment or implementation decision. If the evidence is incomplete, inspect more or state the uncertainty explicitly instead of filling gaps from memory or analogy with another runtime. -- For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. +- For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve backing/output capacity from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. Root graph memory reservation is accounting only and may happen before that byte check, but it must not replace the byte check. +- Root deserialization graph memory budgets estimate lower-bound shallow memory for materialized graph owners, not exact heap accounting, input byte accounting, or raw element counts. 
`maxGraphMemoryBytes` defaults to fixed `128 MiB`; positive values override the default; explicit non-positive values are invalid and must be rejected at config/Fory creation. Do not add a disabled-budget sentinel path, derive this budget from root input size, or split known-length and stream root behavior. Read context/read state owns only raw byte reservation with `reserveGraphMemory(bytes)`; it must not expose counted arithmetic helpers or collection, map, array, struct, or object semantic reservation APIs. Do not add any non-memory-budget read-context/read-state API for this feature, including ref-publication controls, temporary-owner controls, serializer-owner controls, conversion helpers, or APIs that encode what kind of value is being materialized. Root facades may set/reset the per-operation budget, but they must not pre-reserve root type or root self bytes; the serializer or materialization owner reserves root value storage. Because the budget is fixed per root, read state must not mirror the configured max into a second active-limit field; use the existing config or one configured max field plus the mutable remaining budget. Concrete serializers and generated serializers own counted formulas, overflow checks, allocation-owner decisions, and reference publication timing for their allocation path. Reserve self storage exactly once at the owner that stores or allocates the value: reference/object runtimes reserve parent owner self cost plus reference storage and every referenced heap owner reserves its own shallow self cost when materialized; inline/value runtimes reserve value storage in the struct/product, collection, map, set, array, smart-pointer, box, or root materialization owner that actually owns that storage, and nested value serializers do not unconditionally reserve their own `sizeof(T)` again. Collection, map, set, and reference-array owners reserve nonzero shallow self cost when independently materialized, plus backing/reference/inline storage. 
Struct, record, POJO, tuple/product, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Reference fields use 4 bytes when reference size is not cheap or reliable to query; primitive/value fields use their storage width. Parents do not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive array, and primitive dense-array leaf owners. Do not guess allocator, bucket-table, node, debug, or map-entry overhead unless it is a documented lower-bound owner allocation. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. - For remote TypeDef/TypeMeta reads, the checked metadata cache is the only owner of remote "already validated" state. Cache hit means the header was previously parsed, body/hash-validated, policy-checked, and published by that cache, so the hot path must skip the body and use cached metadata without extra validation, hashing, limit checks, exact-local checks, allocation, or policy work. A known expected local TypeDef/TypeMeta header/hash match is a local-schema hit, not a remote cache miss: it may skip the body and use the local TypeInfo/TypeMeta without schema-version counting or cache publish. Cache miss is the only path that parses and validates non-local metadata, enforces limits, performs exact-local byte comparison when needed, and publishes remote metadata to the cache. Do not add nullable accepted-header fields, sentinel headers, per-TypeInfo markers, pending metadata state, parallel header-low/header-high slots, or parallel acceptance state for this decision. If a runtime needs a metadata hit hint, cache the concrete checked metadata owner object, such as the TypeInfo, TypeDef, or TypeMeta used by that runtime, and compare its validated header identity directly. 
 - When a user corrects a non-obvious invariant, encode it in the nearest source comment before continuing, and also update `AGENTS.md`, `.agents/**`, docs, or specs when the rule is reusable beyond one file. Do not rely only on chat history, task notes, commit messages, or benchmark logs for corrections that protect security, protocol behavior, ownership, naming, or hot-path performance.
 - Reject semantic hacks. Do not bypass broken semantics by deleting cases, simplifying callers, adding coercion hooks, or using workaround fallbacks; fix the underlying bug and prove it with focused tests.
@@ -111,6 +112,19 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th
 - `docs/DEVELOPMENT.md` plus updates under `docs/guide/` and `docs/benchmarks/` are synced to `apache/fory-site`; other website content belongs there.
 - When benchmark logic, scripts, configuration, or compared serializers change, rerun the relevant benchmarks and refresh the artifacts under `docs/benchmarks/**`.

+## Network Error Command Log
+
+- 2026-06-26: `cargo check` from `rust/` failed while updating crates.io through
+  `127.0.0.1:7890`; retried as
+  `env -u all_proxy -u http_proxy -u https_proxy -u ALL_PROXY -u HTTP_PROXY -u HTTPS_PROXY cargo check`,
+  which still used the configured proxy. `cargo check --offline` succeeded using the local Cargo
+  cache.
+- 2026-06-26: `cmake -S . -B ../tasks/cpp-cmake-build -DFORY_BUILD_TESTS=ON
+-DFORY_BUILD_SHARED=OFF -DFORY_BUILD_STATIC=ON` from `cpp/` failed while FetchContent tried to
+  clone googletest through `127.0.0.1:7890`; retried as
+  `env -u all_proxy -u http_proxy -u https_proxy -u ALL_PROXY -u HTTP_PROXY -u HTTPS_PROXY cmake -S . -B ../tasks/cpp-cmake-build -DFORY_BUILD_TESTS=ON -DFORY_BUILD_SHARED=OFF -DFORY_BUILD_STATIC=ON`,
+  which still used the configured proxy in the nested clone.
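Reviewer sketch, not part of the patch: the graph memory budget guidance earlier in this file splits responsibility between a read context that only reserves raw bytes and serializer owners that compute counted formulas with overflow checks. The standalone C++ below illustrates that split under stated assumptions. `GraphBudget`, `reserve_counted`, and `reserve_packed_bits` are hypothetical names chosen for illustration and are not Fory APIs; the real code in this change uses `ReadContext::reserve_graph_memory` and `reserve_collection`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>

// Hypothetical stand-in for the read context. It holds only the mutable
// remaining budget and exposes raw byte reservation; it does not know what
// kind of value is being materialized.
class GraphBudget {
public:
  explicit GraphBudget(int64_t max_bytes)
      : remaining_(static_cast<size_t>(max_bytes)) {
    // Assumption: non-positive limits are rejected at creation time.
    assert(max_bytes > 0 && "budget must be a positive byte limit");
  }

  // Raw reservation. On failure the remaining budget is left unchanged.
  bool reserve(size_t bytes) {
    if (bytes > remaining_) {
      return false;
    }
    remaining_ -= bytes;
    return true;
  }

  size_t remaining() const { return remaining_; }

private:
  size_t remaining_;
};

// Serializer-owned counted formula: count * elem_bytes, with the overflow
// check performed before the multiplication. This is a portable lower-bound
// estimate, not an accounting of allocator or node overhead.
inline bool reserve_counted(GraphBudget &budget, size_t count,
                            size_t elem_bytes) {
  if (elem_bytes != 0 &&
      count > std::numeric_limits<size_t>::max() / elem_bytes) {
    return false; // the estimate itself would overflow size_t
  }
  return budget.reserve(count * elem_bytes);
}

// Packed-bit estimate for a bool vector: one bit per element, rounded up to
// whole bytes.
inline bool reserve_packed_bits(GraphBudget &budget, size_t count) {
  if (count > std::numeric_limits<size_t>::max() - 7) {
    return false; // rounding up would overflow size_t
  }
  return budget.reserve((count + 7) / 8);
}
```

A caller would construct one `GraphBudget` per root operation and have each owner call `reserve_counted` or `reserve_packed_bits` before allocating. Keeping the arithmetic outside the budget class is the design point being illustrated: the budget holder stays free of collection, map, or struct semantics.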
+ ## Shared Engineering Expectations - Favor zero-copy techniques, JIT or codegen opportunities, and cache-friendly memory access patterns in performance-critical paths. @@ -162,6 +176,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - When comparing benchmark results against `apache/main`, use a separate sibling worktree named `fory-benchmark-baseline` by default. Before creating a new worktree, check whether `../fory-benchmark-baseline` already exists and reuse it to avoid repeated benchmark dependency rebuilds. Always fetch and sync that baseline worktree to the latest `apache/main` before measuring it, and store benchmark result files under that worktree so older runs remain available as reference data. Treat stored benchmark results as historical references, not truth, because machine load and benchmark variance change over time. Create a different baseline worktree only when explicitly requested. - Before benchmarking a checked-out version, install or build the required Fory packages for that version, such as the Java artifacts, Python package, and the target runtime package needed by the benchmark. - Run and close old/new benchmark comparisons for exactly one language at a time. If that language has a slowdown greater than 1%, keep working only on that language until the slowdown is within 1% before moving to the next language. +- Within one language, compare benchmarks case-by-case in adjacent old/new pairs: run one case on fresh `apache/main`, then immediately run the same case on the current branch before moving to the next case. Do not batch all baseline cases and then all current cases, because machine load drift makes that comparison noisier. - Treat a same-benchmark slowdown greater than 1% as unresolved until the retained median is within 1% of the baseline. Faster results are acceptable only after verifying that generated code, benchmark shape, safety checks, and protocol semantics did not skip required work. 
Do not add artificial slowdowns or benchmark-shape changes to force a match. - Do not change protocol behavior, benchmark payloads, or public APIs solely to manufacture performance wins. - For performance work, run the relevant benchmark immediately after each change and report the command plus before/after numbers. diff --git a/cpp/fory/meta/field_info.h b/cpp/fory/meta/field_info.h index 0da5d22b07..4e066c72a8 100644 --- a/cpp/fory/meta/field_info.h +++ b/cpp/fory/meta/field_info.h @@ -740,8 +740,8 @@ constexpr auto concat_tuples_from_tuple(const Tuple &tuple) { static inline constexpr size_t Size = 0; \ static inline constexpr std::string_view Name = #type; \ static inline constexpr std::array Names = {}; \ - static constexpr bool has_config = false; \ - static inline constexpr auto entries = std::tuple{}; \ + [[maybe_unused]] static constexpr bool has_config = false; \ + [[maybe_unused]] static inline constexpr auto entries = std::tuple{}; \ [[maybe_unused]] static constexpr size_t field_count = 0; \ using PtrsType = decltype(std::tuple{}); \ static constexpr PtrsType ptrs() { return {}; } \ diff --git a/cpp/fory/serialization/BUILD b/cpp/fory/serialization/BUILD index b74c356a2b..67df6188ae 100644 --- a/cpp/fory/serialization/BUILD +++ b/cpp/fory/serialization/BUILD @@ -109,6 +109,16 @@ cc_test( ], ) +cc_test( + name = "graph_memory_budget_test", + srcs = ["graph_memory_budget_test.cc"], + deps = [ + ":fory_serialization", + "@googletest//:gtest", + "@googletest//:gtest_main", + ], +) + cc_test( name = "variant_serializer_test", srcs = ["variant_serializer_test.cc"], diff --git a/cpp/fory/serialization/CMakeLists.txt b/cpp/fory/serialization/CMakeLists.txt index 0b88f0e5e3..a3e5cd58bc 100644 --- a/cpp/fory/serialization/CMakeLists.txt +++ b/cpp/fory/serialization/CMakeLists.txt @@ -102,6 +102,11 @@ if(FORY_BUILD_TESTS) target_link_libraries(fory_serialization_map_test fory_serialization GTest::gtest GTest::gtest_main) 
gtest_discover_tests(fory_serialization_map_test) + add_executable(fory_serialization_graph_memory_budget_test graph_memory_budget_test.cc) + fory_configure_target(fory_serialization_graph_memory_budget_test) + target_link_libraries(fory_serialization_graph_memory_budget_test fory_serialization GTest::gtest GTest::gtest_main) + gtest_discover_tests(fory_serialization_graph_memory_budget_test) + add_executable(fory_serialization_variant_test variant_serializer_test.cc) fory_configure_target(fory_serialization_variant_test) target_link_libraries(fory_serialization_variant_test fory_serialization GTest::gtest GTest::gtest_main) diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index 473f9d6950..417b958727 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -388,6 +389,21 @@ inline bool reserve_collection(Container &result, ReadContext &ctx, if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } + using Elem = typename Container::value_type; + constexpr size_t elem_bytes = sizeof(Elem); + // Portable lower-bound estimate only: STL node/header/debug-layout details + // differ across implementations, so generic collections charge value storage. 
+ if (FORY_PREDICT_FALSE(static_cast(length) > + std::numeric_limits::max() / elem_bytes)) { + ctx.set_error(Error::invalid_data( + "graph memory estimate overflows: length=" + std::to_string(length) + + " elementBytes=" + std::to_string(elem_bytes))); + return false; + } + if (FORY_PREDICT_FALSE(!ctx.reserve_graph_memory(static_cast(length) * + elem_bytes))) { + return false; + } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { return false; } @@ -397,10 +413,37 @@ inline bool reserve_collection(Container &result, ReadContext &ctx, return true; } +template +inline bool reserve_collection(std::vector &result, + ReadContext &ctx, uint32_t length) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return false; + } + const size_t length_bytes = static_cast(length); + if (FORY_PREDICT_FALSE(length_bytes > + std::numeric_limits::max() - 7)) { + ctx.set_error(Error::invalid_data( + "graph memory estimate overflows: length=" + std::to_string(length) + + " elementBytes=1")); + return false; + } + const size_t packed_bytes = (length_bytes + 7) / 8; + if (FORY_PREDICT_FALSE(!ctx.reserve_graph_memory(packed_bytes))) { + return false; + } + if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { + return false; + } + result.reserve(length); + return true; +} + // Helper to insert element into container (vector or set) template inline void collection_insert(Container &result, T &&elem) { - if constexpr (has_push_back_v) { + if constexpr (is_forward_list_v) { + result.push_front(std::forward(elem)); + } else if constexpr (has_push_back_v) { result.push_back(std::forward(elem)); } else { result.insert(std::forward(elem)); @@ -414,9 +457,6 @@ inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { if (length == 0) { return result; } - if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { - return result; - } constexpr bool elem_is_polymorphic = is_polymorphic_v; @@ -443,6 +483,10 @@ inline 
Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { } } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } + // Read elements if (is_same_type) { if (track_ref) { @@ -482,8 +526,9 @@ inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { } bool has_value = read_null_only_flag(ctx, RefMode::NullOnly); if (!has_value) { - if constexpr (has_push_back_v) { - result.push_back(T{}); + if constexpr (has_push_back_v || + is_forward_list_v) { + collection_insert(result, T{}); } // For sets, skip null elements } else { @@ -508,8 +553,9 @@ inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { } bool has_value = read_null_only_flag(ctx, RefMode::NullOnly); if (!has_value) { - if constexpr (has_push_back_v) { - result.push_back(T{}); + if constexpr (has_push_back_v || + is_forward_list_v) { + collection_insert(result, T{}); } } else { // Read type info + data without ref flag @@ -530,6 +576,9 @@ inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { } } + if constexpr (is_forward_list_v) { + result.reverse(); + } return result; } @@ -902,6 +951,11 @@ struct Serializer< if (ctx.has_error() || !has_value) { return std::vector(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE( + (!reserve_allocated_value_owner>(ctx)))) { + return std::vector(); + } // Optional type info for polymorphic containers if (read_type) { @@ -922,16 +976,16 @@ struct Serializer< return std::vector(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (length == 0) { - return std::vector(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, length); } else { + std::vector result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (length == 0) { + return result; 
+ } // Fast path for non-polymorphic, non-shared-ref elements // Elements header bitmap (CollectionFlags) @@ -961,7 +1015,6 @@ struct Serializer< } } - std::vector result; if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { return result; } @@ -1151,11 +1204,11 @@ template struct Serializer> { return std::vector(); } - if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(size, ctx.error()))) { - return std::vector(); + std::vector result; + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; } - - std::vector result(size); + result.resize(size); Buffer &buffer = ctx.buffer(); if (size > 0) { const uint8_t *src = buffer.data() + buffer.reader_index(); @@ -1197,6 +1250,11 @@ template struct Serializer> { if (ctx.has_error() || !has_value) { return std::list(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE( + (!reserve_allocated_value_owner>(ctx)))) { + return std::list(); + } // Optional type info for polymorphic containers if (read_type) { @@ -1217,16 +1275,16 @@ template struct Serializer> { return std::list(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (length == 0) { - return std::list(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, length); } else { + std::list result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (length == 0) { + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Elements header bitmap (CollectionFlags) @@ -1256,7 +1314,9 @@ template struct Serializer> { } } - std::list result; + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } // Fast path: no tracking, no nulls, elements have declared type if (!track_ref && !has_null && is_same_type) { @@ -1349,6 +1409,12 @@ template struct Serializer> { } 
std::list result; + if (size == 0) { + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; @@ -1389,6 +1455,11 @@ template struct Serializer> { if (ctx.has_error() || !has_value) { return std::deque(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE( + (!reserve_allocated_value_owner>(ctx)))) { + return std::deque(); + } // Optional type info for polymorphic containers if (read_type) { @@ -1409,16 +1480,16 @@ template struct Serializer> { return std::deque(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (length == 0) { - return std::deque(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, length); } else { + std::deque result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (length == 0) { + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Elements header bitmap (CollectionFlags) @@ -1448,7 +1519,9 @@ template struct Serializer> { } } - std::deque result; + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } // Fast path: no tracking, no nulls, elements have declared type if (!track_ref && !has_null && is_same_type) { @@ -1541,6 +1614,12 @@ template struct Serializer> { } std::deque result; + if (size == 0) { + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; @@ -1582,6 +1661,12 @@ struct Serializer> { if (ctx.has_error() || !has_value) { return std::forward_list(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE( + (!reserve_allocated_value_owner>( + ctx)))) { + 
return std::forward_list(); + } // Optional type info for polymorphic containers if (read_type) { @@ -1602,25 +1687,25 @@ struct Serializer> { return std::forward_list(); } + std::forward_list result; // Per xlang spec: header and type_info are omitted when length is 0 if (length == 0) { - return std::forward_list(); + return result; } // Dispatch to slow path for polymorphic/shared-ref elements - // Read elements into a temporary vector then build forward_list - // (forward_list doesn't have push_back, only push_front) - std::vector temp; constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { - temp = read_collection_data_slow>(ctx, length); + return read_collection_data_slow>(ctx, + length); } else { + auto tail = result.before_begin(); // Fast path for non-polymorphic, non-shared-ref elements // Elements header bitmap (CollectionFlags) uint8_t bitmap = ctx.read_uint8(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::forward_list(); + return result; } bool track_ref = (bitmap & COLL_TRACKING_REF) != 0; bool has_null = (bitmap & COLL_HAS_NULL) != 0; @@ -1632,7 +1717,7 @@ struct Serializer> { if (is_same_type && !is_decl_type) { const TypeInfo *elem_type_info = ctx.read_any_type_info(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::forward_list(); + return result; } using ElemType = nullable_element_t; uint32_t expected = @@ -1644,54 +1729,48 @@ struct Serializer> { } } - if (FORY_PREDICT_FALSE(!reserve_collection(temp, ctx, length))) { - return std::forward_list(); + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; } // Fast path: no tracking, no nulls, elements have declared type if (!track_ref && !has_null && is_same_type) { for (uint32_t i = 0; i < length; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { - break; + return result; } auto elem = Serializer::read(ctx, RefMode::None, false); - temp.push_back(std::move(elem)); + tail = 
result.insert_after(tail, std::move(elem)); } } else { // General path: handle HAS_NULL and/or TRACKING_REF for (uint32_t i = 0; i < length; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { - break; + return result; } if (track_ref) { auto elem = Serializer::read(ctx, RefMode::Tracking, false); - temp.push_back(std::move(elem)); + tail = result.insert_after(tail, std::move(elem)); } else if (has_null) { bool has_value_elem = read_null_only_flag(ctx, RefMode::NullOnly); if (!has_value_elem) { - temp.emplace_back(); + tail = result.emplace_after(tail); } else { if constexpr (is_nullable_v) { using Inner = nullable_element_t; auto inner = Serializer::read(ctx, RefMode::None, false); - temp.emplace_back(std::move(inner)); + tail = result.emplace_after(tail, std::move(inner)); } else { auto elem = Serializer::read(ctx, RefMode::None, false); - temp.push_back(std::move(elem)); + tail = result.insert_after(tail, std::move(elem)); } } } else { auto elem = Serializer::read(ctx, RefMode::None, false); - temp.push_back(std::move(elem)); + tail = result.insert_after(tail, std::move(elem)); } } } } - - // Build forward_list in reverse order using push_front - std::forward_list result; - for (auto it = temp.rbegin(); it != temp.rend(); ++it) { - result.push_front(std::move(*it)); - } return result; } @@ -1968,21 +2047,20 @@ struct Serializer> { return std::forward_list(); } - std::vector temp; - if (FORY_PREDICT_FALSE(!reserve_collection(temp, ctx, size))) { - return std::forward_list(); + std::forward_list result; + if (size == 0) { + return result; } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } + auto tail = result.before_begin(); for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { - break; + return result; } auto elem = Serializer::read_data(ctx); - temp.push_back(std::move(elem)); - } - // Build forward_list in reverse order - std::forward_list result; - for (auto it = temp.rbegin(); it != temp.rend(); 
++it) { - result.push_front(std::move(*it)); + tail = result.insert_after(tail, std::move(elem)); } return result; } @@ -2049,6 +2127,11 @@ struct Serializer> { if (ctx.has_error() || !has_value) { return std::set(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE( + (!reserve_allocated_value_owner>(ctx)))) { + return std::set(); + } // Read type info if (read_type) { @@ -2069,16 +2152,16 @@ struct Serializer> { return std::set(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (size == 0) { - return std::set(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, size); } else { + std::set result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (size == 0) { + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Read elements header bitmap (CollectionFlags) in xlang mode @@ -2094,17 +2177,20 @@ struct Serializer> { if (is_same_type && !is_decl_type) { const TypeInfo *elem_type_info = ctx.read_any_type_info(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::set(); + return result; } uint32_t expected = static_cast(Serializer::type_id); if (!type_id_matches(elem_type_info->type_id, expected)) { ctx.set_error( Error::type_mismatch(elem_type_info->type_id, expected)); - return std::set(); + return result; } } - std::set result; + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } + if (!track_ref && !has_null && is_same_type) { for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { @@ -2151,6 +2237,12 @@ struct Serializer> { } std::set result; + if (size == 0) { + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) 
{ return result; @@ -2224,6 +2316,12 @@ struct Serializer> { if (ctx.has_error() || !has_value) { return std::unordered_set(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE( + (!reserve_allocated_value_owner>( + ctx)))) { + return std::unordered_set(); + } // Read type info if (read_type) { @@ -2244,17 +2342,17 @@ struct Serializer> { return std::unordered_set(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (size == 0) { - return std::unordered_set(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, size); } else { + std::unordered_set result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (size == 0) { + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Read elements header bitmap (CollectionFlags) in xlang mode @@ -2270,20 +2368,20 @@ struct Serializer> { if (is_same_type && !is_decl_type) { const TypeInfo *elem_type_info = ctx.read_any_type_info(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::unordered_set(); + return result; } uint32_t expected = static_cast(Serializer::type_id); if (!type_id_matches(elem_type_info->type_id, expected)) { ctx.set_error( Error::type_mismatch(elem_type_info->type_id, expected)); - return std::unordered_set(); + return result; } } - std::unordered_set result; if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { return result; } + if (!track_ref && !has_null && is_same_type) { for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { @@ -2330,6 +2428,9 @@ struct Serializer> { } std::unordered_set result; + if (size == 0) { + return result; + } if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { return result; } diff --git a/cpp/fory/serialization/config.h b/cpp/fory/serialization/config.h index 
eac9c14436..84d3b89a96 100644 --- a/cpp/fory/serialization/config.h +++ b/cpp/fory/serialization/config.h @@ -52,6 +52,10 @@ struct Config { /// When enabled, avoids duplicating shared objects and handles cycles. bool track_ref = true; + /// Maximum estimated graph memory accepted during one root deserialization. + /// Value must be a positive byte limit. + int64_t max_graph_memory_bytes = 128LL * 1024LL * 1024LL; + /// Maximum accepted field count in one received struct TypeMeta. uint32_t max_type_fields = 512; diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index deff5ee16c..7e8569d218 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -434,7 +434,10 @@ uint32_t WriteContext::get_type_id_for_cache(const std::type_index &type_idx) { ReadContext::ReadContext(const Config &config, std::unique_ptr type_resolver) : buffer_(nullptr), config_(&config), - type_resolver_(std::move(type_resolver)), current_dyn_depth_(0) {} + type_resolver_(std::move(type_resolver)), current_dyn_depth_(0) { + FORY_CHECK(config.max_graph_memory_bytes > 0) + << "max_graph_memory_bytes must be positive"; +} ReadContext::~ReadContext() = default; @@ -739,6 +742,14 @@ const TypeInfo *ReadContext::read_any_type_info(Error &error) { return result.value(); } +bool ReadContext::set_graph_memory_exceeded(size_t bytes, size_t remaining) { + set_error(Error::invalid_data( + "estimated graph memory request " + std::to_string(bytes) + + " bytes exceeds max_graph_memory_bytes remaining budget " + + std::to_string(remaining) + " bytes")); + return false; +} + void ReadContext::reset() { // Clear error state first error_ = Error(); @@ -747,6 +758,7 @@ void ReadContext::reset() { } reading_type_infos_.clear(); current_dyn_depth_ = 0; + remaining_graph_memory_bytes_ = 0; if (meta_string_table_active_) { meta_string_table_.reset(); meta_string_table_active_ = false; diff --git a/cpp/fory/serialization/context.h 
b/cpp/fory/serialization/context.h index 6af99c4ccc..3ddbf8b4af 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -32,6 +32,7 @@ #include "fory/util/result.h" #include +#include #include #include @@ -504,6 +505,15 @@ class ReadContext { } } + FORY_ALWAYS_INLINE bool reserve_graph_memory(size_t bytes) { + const size_t remaining = remaining_graph_memory_bytes_; + if (FORY_PREDICT_FALSE(bytes > remaining)) { + return set_graph_memory_exceeded(bytes, remaining); + } + remaining_graph_memory_bytes_ = remaining - bytes; + return true; + } + // =========================================================================== // Read methods with Error& parameter // All methods accept Error& as parameter for reduced overhead. @@ -659,9 +669,12 @@ class ReadContext { inline const Config &config() const { return *config_; } private: + friend class Fory; + FORY_NOINLINE Result check_remote_type_meta_limit(const TypeMeta &type_meta); void record_remote_type_meta(const std::string &type_key); + FORY_NOINLINE bool set_graph_memory_exceeded(size_t bytes, size_t remaining); // Error state - accumulated during deserialization, checked at the end Error error_; @@ -671,6 +684,7 @@ class ReadContext { std::unique_ptr type_resolver_; RefReader ref_reader_; uint32_t current_dyn_depth_; + size_t remaining_graph_memory_bytes_ = 0; // Meta sharing state (for compatible mode) // Persistent cache storage for TypeInfo objects keyed by meta header. diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index 36ef992d17..527a385e32 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -40,6 +40,7 @@ #include "fory/util/result.h" #include "fory/util/stream.h" #include +#include #include #include #include @@ -109,6 +110,15 @@ class ForyBuilder { return *this; } + /// Set maximum estimated graph memory for one root deserialization. + /// + /// Defaults to 128 MiB. Values must be positive byte limits. 
+ ForyBuilder &max_graph_memory_bytes(int64_t max_bytes) { + FORY_CHECK(max_bytes > 0) << "max_graph_memory_bytes must be positive"; + config_.max_graph_memory_bytes = max_bytes; + return *this; + } + /// Set maximum accepted field count in one received struct TypeMeta. ForyBuilder &max_type_fields(uint32_t max_fields) { FORY_CHECK(max_fields > 0) << "max_type_fields must be positive"; @@ -673,19 +683,7 @@ class Fory : public BaseFory { Buffer buffer(const_cast(data), static_cast(size), false); - - Error header_error; - const uint8_t header = buffer.read_uint8(header_error); - if (FORY_PREDICT_FALSE(!header_error.ok())) { - return Unexpected(std::move(header_error)); - } - if (FORY_PREDICT_FALSE(header != precomputed_header_)) { - return Unexpected(invalid_root_header(header)); - } - - read_ctx_->attach(buffer); - ReadContextGuard guard(*read_ctx_); - return deserialize_impl(buffer); + return deserialize_buffer(buffer); } /// Deserialize an object from a byte vector. @@ -711,18 +709,7 @@ class Fory : public BaseFory { if (FORY_PREDICT_FALSE(!finalized_)) { ensure_finalized(); } - Error header_error; - const uint8_t header = buffer.read_uint8(header_error); - if (FORY_PREDICT_FALSE(!header_error.ok())) { - return Unexpected(std::move(header_error)); - } - if (FORY_PREDICT_FALSE(header != precomputed_header_)) { - return Unexpected(invalid_root_header(header)); - } - - read_ctx_->attach(buffer); - ReadContextGuard guard(*read_ctx_); - return deserialize_impl(buffer); + return deserialize_buffer(buffer); } /// Deserialize an object from an input stream. @@ -745,7 +732,10 @@ class Fory : public BaseFory { }; StreamShrinkGuard shrink_guard{&input_stream}; Buffer &buffer = input_stream.get_buffer(); - return deserialize(buffer); + if (FORY_PREDICT_FALSE(!finalized_)) { + ensure_finalized(); + } + return deserialize_buffer(buffer); } /// Deserialize an object from StdInputStream. 
@@ -883,6 +873,26 @@ class Fory : public BaseFory { return result; } + template + FORY_ALWAYS_INLINE Result deserialize_buffer(Buffer &buffer) { + Error header_error; + const uint8_t header = buffer.read_uint8(header_error); + if (FORY_PREDICT_FALSE(!header_error.ok())) { + read_ctx_->reset(); + return Unexpected(std::move(header_error)); + } + if (FORY_PREDICT_FALSE(header != precomputed_header_)) { + read_ctx_->reset(); + return Unexpected(invalid_root_header(header)); + } + + read_ctx_->attach(buffer); + read_ctx_->remaining_graph_memory_bytes_ = + static_cast(read_ctx_->config_->max_graph_memory_bytes); + ReadContextGuard guard(*read_ctx_); + return deserialize_impl(buffer); + } + template Result cached_write_root_type_info() { constexpr uint64_t ctid = type_index(); diff --git a/cpp/fory/serialization/graph_memory_budget_test.cc b/cpp/fory/serialization/graph_memory_budget_test.cc new file mode 100644 index 0000000000..5cbe55d8f1 --- /dev/null +++ b/cpp/fory/serialization/graph_memory_budget_test.cc @@ -0,0 +1,431 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "fory/serialization/fory.h" +#include "gtest/gtest.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fory { +namespace serialization { +namespace { + +constexpr int64_t kDefaultGraphMemoryBytes = 128LL * 1024LL * 1024LL; + +struct BudgetItem { + int32_t id = 0; + std::string name; + + bool operator==(const BudgetItem &other) const { + return id == other.id && name == other.name; + } + + FORY_STRUCT(BudgetItem, id, name); +}; + +struct BudgetEmpty { + bool operator==(const BudgetEmpty &) const { return true; } + + FORY_STRUCT(BudgetEmpty); +}; + +struct BudgetSiblings { + std::vector left; + std::vector right; + + bool operator==(const BudgetSiblings &other) const { + return left == other.left && right == other.right; + } + + FORY_STRUCT(BudgetSiblings, left, right); +}; + +struct BudgetFixedArrayOwner { + std::array prefix{}; + std::vector items; + + bool operator==(const BudgetFixedArrayOwner &other) const { + return prefix == other.prefix && items == other.items; + } + + FORY_STRUCT(BudgetFixedArrayOwner, prefix, items); +}; + +template auto with_fory(int64_t max_graph_memory_bytes, Fn &&fn) { + auto fory = Fory::builder() + .xlang(true) + .compatible(false) + .track_ref(false) + .max_graph_memory_bytes(max_graph_memory_bytes) + .build(); + fory.register_struct(1); + fory.register_struct(2); + fory.register_struct(3); + fory.register_struct(4); + return std::forward(fn)(fory); +} + +template std::vector serialize_value(const T &value) { + auto bytes = with_fory(kDefaultGraphMemoryBytes, + [&](Fory &fory) { return fory.serialize(value); }); + EXPECT_TRUE(bytes.ok()) << bytes.error().to_string(); + return std::move(bytes).value(); +} + +size_t nested_empty_budget(size_t count) { + using Outer = std::vector>; + using Inner = std::vector; + return sizeof(Outer) + count * sizeof(Inner); +} + +template +void expect_budget_boundary(const T &value, 
size_t required) { + ASSERT_GT(required, 0u); + auto bytes = serialize_value(value); + + if (required > 1) { + auto small_result = + with_fory(static_cast(required - 1), + [&](Fory &fory) { return fory.deserialize(bytes); }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + } + + auto exact_result = + with_fory(static_cast(required), + [&](Fory &fory) { return fory.deserialize(bytes); }); + ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + EXPECT_EQ(exact_result.value(), value); +} + +TEST(GraphMemoryBudgetTest, FixedDefaultBudgetAndValidation) { + Config config; + EXPECT_EQ(config.max_graph_memory_bytes, kDefaultGraphMemoryBytes); + + EXPECT_DEATH((void)Fory::builder().max_graph_memory_bytes(0), + "max_graph_memory_bytes"); +} + +TEST(GraphMemoryBudgetTest, RootKindsShareConfiguredBudget) { + constexpr size_t count = 3; + std::vector> value(count); + auto bytes = serialize_value(value); + const size_t required = nested_empty_budget(count); + + auto byte_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_FALSE(byte_result.ok()); + EXPECT_EQ(byte_result.error().code(), ErrorCode::InvalidData); + + std::string input(reinterpret_cast(bytes.data()), bytes.size()); + std::istringstream source(input); + StdInputStream stream(source, 8); + auto stream_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>>(stream); + }); + ASSERT_FALSE(stream_result.ok()); + EXPECT_EQ(stream_result.error().code(), ErrorCode::InvalidData); + + std::istringstream exact_source(input); + StdInputStream exact_stream(exact_source, 8); + auto exact_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>>( + exact_stream); + }); + ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + EXPECT_EQ(exact_result.value(), value); +} + +TEST(GraphMemoryBudgetTest, ExplicitOverride) 
{ + std::vector value(8); + auto bytes = serialize_value(value); + const size_t required = + sizeof(std::vector) + value.size() * sizeof(BudgetItem); + + auto small_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto exact_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + EXPECT_EQ(exact_result.value(), value); +} + +TEST(GraphMemoryBudgetTest, SmartPointerStructOwners) { + auto shared_value = std::make_shared(); + shared_value->id = 7; + shared_value->name = "shared"; + auto shared_bytes = serialize_value(shared_value); + constexpr size_t shared_required = + sizeof(std::shared_ptr) + sizeof(BudgetItem); + + auto shared_small = + with_fory(static_cast(shared_required - 1), [&](Fory &fory) { + return fory.deserialize>(shared_bytes); + }); + ASSERT_FALSE(shared_small.ok()); + EXPECT_EQ(shared_small.error().code(), ErrorCode::InvalidData); + + auto shared_exact = + with_fory(static_cast(shared_required), [&](Fory &fory) { + return fory.deserialize>(shared_bytes); + }); + ASSERT_TRUE(shared_exact.ok()) << shared_exact.error().to_string(); + ASSERT_NE(shared_exact.value(), nullptr); + EXPECT_EQ(*shared_exact.value(), *shared_value); + + auto unique_value = std::make_unique(); + unique_value->id = 9; + unique_value->name = "unique"; + auto unique_bytes = serialize_value(unique_value); + + constexpr size_t unique_required = + sizeof(std::unique_ptr) + sizeof(BudgetItem); + auto unique_small = + with_fory(static_cast(unique_required - 1), [&](Fory &fory) { + return fory.deserialize>(unique_bytes); + }); + ASSERT_FALSE(unique_small.ok()); + EXPECT_EQ(unique_small.error().code(), ErrorCode::InvalidData); + + auto unique_exact = + with_fory(static_cast(unique_required), [&](Fory &fory) { + 
return fory.deserialize>(unique_bytes); + }); + ASSERT_TRUE(unique_exact.ok()) << unique_exact.error().to_string(); + ASSERT_NE(unique_exact.value(), nullptr); + EXPECT_EQ(*unique_exact.value(), *unique_value); +} + +TEST(GraphMemoryBudgetTest, SmartPointerVectorOwner) { + auto value = std::make_shared>(3); + auto bytes = serialize_value(value); + const size_t required = sizeof(std::shared_ptr>) + + sizeof(std::vector) + + value->size() * sizeof(BudgetItem); + + auto small_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>>( + bytes); + }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto exact_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>>( + bytes); + }); + ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + ASSERT_NE(exact_result.value(), nullptr); + EXPECT_EQ(*exact_result.value(), *value); +} + +TEST(GraphMemoryBudgetTest, EmptyStructRootChargesOwner) { + BudgetEmpty value; + expect_budget_boundary(value, sizeof(BudgetEmpty)); +} + +TEST(GraphMemoryBudgetTest, NestedEmptyContainers) { + std::vector> value(1); + auto bytes = serialize_value(value); + const size_t required = sizeof(std::vector>) + + sizeof(std::vector); + + auto small_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto exact_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + EXPECT_EQ(exact_result.value(), value); +} + +TEST(GraphMemoryBudgetTest, SiblingCumulativeBudget) { + BudgetSiblings value; + value.left.resize(16); + value.right.resize(16); + auto bytes = serialize_value(value); + const size_t root_owner = sizeof(BudgetSiblings); + const 
size_t one_vector = value.left.size() * sizeof(BudgetItem); + + auto small_result = + with_fory(static_cast(root_owner + one_vector), [&](Fory &fory) { + return fory.deserialize(bytes); + }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto enough_result = with_fory( + static_cast(root_owner + one_vector * 2), + [&](Fory &fory) { return fory.deserialize(bytes); }); + ASSERT_TRUE(enough_result.ok()) << enough_result.error().to_string(); + EXPECT_EQ(enough_result.value(), value); +} + +TEST(GraphMemoryBudgetTest, MapBudget) { + std::map value{{"a", 1}, {"b", 2}, {"c", 3}}; + const size_t entry_bytes = sizeof(std::string) + sizeof(int32_t); + const size_t required = + sizeof(std::map) + value.size() * entry_bytes; + + expect_budget_boundary(value, required); +} + +TEST(GraphMemoryBudgetTest, CollectionLowerBounds) { + std::deque deque_value(4); + expect_budget_boundary(deque_value, + sizeof(std::deque) + + deque_value.size() * sizeof(BudgetItem)); + + std::list list_value(4); + expect_budget_boundary(list_value, + sizeof(std::list) + + list_value.size() * sizeof(BudgetItem)); + + std::forward_list forward_value(4); + expect_budget_boundary(forward_value, sizeof(std::forward_list) + + size_t{4} * sizeof(BudgetItem)); +} + +TEST(GraphMemoryBudgetTest, VectorBoolChargesPackedStorage) { + std::vector value(33); + value[0] = true; + value[32] = true; + expect_budget_boundary(value, size_t{5}); +} + +TEST(GraphMemoryBudgetTest, OrderedSetAndMapLowerBounds) { + std::set set_value{1, 2, 3, 4}; + expect_budget_boundary(set_value, sizeof(std::set) + + set_value.size() * sizeof(int32_t)); + + std::map map_value{{"a", 1}, {"b", 2}}; + expect_budget_boundary(map_value, + sizeof(std::map) + + map_value.size() * + (sizeof(std::string) + sizeof(int32_t))); +} + +TEST(GraphMemoryBudgetTest, UnorderedContainersLowerBounds) { + std::unordered_set set_value{1, 2, 3, 4}; + expect_budget_boundary(set_value, 
sizeof(std::unordered_set) + + set_value.size() * sizeof(int32_t)); + + std::unordered_map map_value{{"a", 1}, {"b", 2}}; + expect_budget_boundary(map_value, + sizeof(std::unordered_map) + + map_value.size() * + (sizeof(std::string) + sizeof(int32_t))); +} + +TEST(GraphMemoryBudgetTest, ArrayHasNoStandaloneReservation) { + std::array value{{1, 2, 3, 4}}; + auto bytes = serialize_value(value); + auto result = with_fory(1, [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_TRUE(result.ok()) << result.error().to_string(); + EXPECT_EQ(result.value(), value); +} + +TEST(GraphMemoryBudgetTest, FixedInlineOwnerChargesNestedVector) { + BudgetFixedArrayOwner value; + value.prefix = {{1, 2, 3, 4}}; + value.items.resize(3); + const size_t required = + sizeof(BudgetFixedArrayOwner) + value.items.size() * sizeof(BudgetItem); + + expect_budget_boundary(value, required); +} + +TEST(GraphMemoryBudgetTest, DensePathsSkipped) { + { + std::string value = "graph-budget-string"; + auto bytes = serialize_value(value); + auto result = with_fory( + 1, [&](Fory &fory) { return fory.deserialize(bytes); }); + ASSERT_TRUE(result.ok()) << result.error().to_string(); + EXPECT_EQ(result.value(), value); + } + { + std::vector value(256, 7); + auto bytes = serialize_value(value); + auto result = with_fory(1, [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_TRUE(result.ok()) << result.error().to_string(); + EXPECT_EQ(result.value(), value); + } + { + std::vector value(256, 42); + auto bytes = serialize_value(value); + auto result = with_fory(1, [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_TRUE(result.ok()) << result.error().to_string(); + EXPECT_EQ(result.value(), value); + } +} + +TEST(GraphMemoryBudgetTest, ByteCheckStillRejectsLargeLength) { + auto bytes = serialize_value(std::vector{}); + ASSERT_FALSE(bytes.empty()); + bytes.back() = 64; + bytes.push_back(0); + + auto result = with_fory(kDefaultGraphMemoryBytes, [&](Fory &fory) { + 
return fory.deserialize>(bytes); + }); + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.error().code(), ErrorCode::BufferOutOfBound); +} + +} // namespace +} // namespace serialization +} // namespace fory diff --git a/cpp/fory/serialization/map_serializer.h b/cpp/fory/serialization/map_serializer.h index 830e5fbae5..28db52cc53 100644 --- a/cpp/fory/serialization/map_serializer.h +++ b/cpp/fory/serialization/map_serializer.h @@ -21,6 +21,7 @@ #include "fory/serialization/serializer.h" #include +#include #include #include #include @@ -81,6 +82,19 @@ struct MapReserver +inline bool reserve_map_storage(ReadContext &ctx, uint32_t length) { + if (FORY_PREDICT_FALSE(elem_bytes != 0 && + static_cast(length) > + std::numeric_limits::max() / elem_bytes)) { + ctx.set_error(Error::invalid_data( + "graph memory estimate overflows: length=" + std::to_string(length) + + " elementBytes=" + std::to_string(elem_bytes))); + return false; + } + return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); +} + template inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { // Lazy error propagation may continue into later readers; do not let that @@ -88,6 +102,17 @@ inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } + using Key = typename MapType::key_type; + using Value = typename MapType::mapped_type; + // Portable lower-bound estimate only: ordered and unordered map node layouts + // vary across STL implementations, allocators, and debug modes. 
+ static_assert(sizeof(Key) <= + std::numeric_limits::max() - sizeof(Value), + "map entry memory estimate overflows"); + constexpr size_t elem_bytes = sizeof(Key) + sizeof(Value); + if (FORY_PREDICT_FALSE((!reserve_map_storage(ctx, length)))) { + return false; + } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { return false; } @@ -1019,6 +1044,10 @@ struct Serializer> { if (!has_value) { return MapType{}; } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return MapType{}; + } if (read_type) { uint32_t type_id_read = ctx.read_uint8(ctx.error()); @@ -1127,6 +1156,10 @@ struct Serializer> { if (!has_value) { return MapType{}; } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return MapType{}; + } if (read_type) { uint32_t type_id_read = ctx.read_uint8(ctx.error()); diff --git a/cpp/fory/serialization/serializer_traits.h b/cpp/fory/serialization/serializer_traits.h index ad07c4aa6d..854a083a4a 100644 --- a/cpp/fory/serialization/serializer_traits.h +++ b/cpp/fory/serialization/serializer_traits.h @@ -23,7 +23,9 @@ #include "fory/meta/type_index.h" #include "fory/meta/type_traits.h" #include "fory/type/type.h" +#include "fory/util/macros.h" #include +#include #include #include #include @@ -244,6 +246,97 @@ struct is_fory_serializable< template inline constexpr bool is_fory_serializable_v = is_fory_serializable::value; +// ============================================================================ +// Graph budget reachability +// ============================================================================ + +template +struct needs_graph_budget : std::false_type {}; + +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct 
needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; + +template +struct needs_graph_budget< + T, std::enable_if_t || is_list_v || is_deque_v || + is_forward_list_v || is_set_like_v || + is_map_like_v>> : std::true_type {}; + +template +struct needs_graph_budget, void> + : std::bool_constant>>::value> {}; + +template +struct needs_graph_budget, void> + : std::bool_constant>>::value> {}; + +template +struct needs_graph_budget, void> : std::true_type {}; + +template +struct needs_graph_budget, void> : std::true_type {}; + +template +struct needs_graph_budget, void> : std::true_type {}; + +template +struct needs_graph_budget, void> + : std::bool_constant<(needs_graph_budget>>::value || + ...)> {}; + +template +struct needs_graph_budget>> + : std::true_type {}; + +template +inline constexpr bool needs_graph_budget_v = + needs_graph_budget>>::value; + +template struct is_dense_primitive_vector : std::false_type {}; + +template +struct is_dense_primitive_vector> + : std::bool_constant> {}; + +template +inline constexpr bool is_dense_primitive_vector_v = is_dense_primitive_vector< + std::remove_cv_t>>::value; + +template constexpr size_t graph_value_owner_self_bytes() { + using Value = std::remove_cv_t>; + if constexpr (!needs_graph_budget_v || + is_dense_primitive_vector_v) { + return size_t{0}; + } else if constexpr (std::is_empty_v) { + return size_t{1}; + } else { + return sizeof(Value); + } +} + +template +FORY_ALWAYS_INLINE bool reserve_allocated_value_owner(Context &ctx) { + constexpr size_t bytes = graph_value_owner_self_bytes(); + if constexpr (bytes == 0) 
{ + return true; + } else { + return ctx.reserve_graph_memory(bytes); + } +} + // ============================================================================ // Generic Type Detection // ============================================================================ diff --git a/cpp/fory/serialization/smart_ptr_serializers.h b/cpp/fory/serialization/smart_ptr_serializers.h index 2ffe0ca98d..91e506325a 100644 --- a/cpp/fory/serialization/smart_ptr_serializers.h +++ b/cpp/fory/serialization/smart_ptr_serializers.h @@ -444,6 +444,9 @@ template struct Serializer> { "Cannot use monomorphic deserialization for abstract type")); return nullptr; } else { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, false); if (ctx.has_error()) { return nullptr; @@ -454,6 +457,9 @@ template struct Serializer> { } else { // T is guaranteed to be a value type (not pointer or nullable wrapper) // by static_assert, so no inner ref metadata needed. 
+ if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, read_type); if (ctx.has_error()) { return nullptr; @@ -505,6 +511,10 @@ template struct Serializer> { } reserved_ref_id = ctx.ref_reader().reserve_ref_id(); } + if (FORY_PREDICT_FALSE( + !ctx.reserve_graph_memory(sizeof(std::shared_ptr)))) { + return nullptr; + } // For polymorphic types, read type info AFTER handling ref flags if constexpr (is_polymorphic) { @@ -546,6 +556,9 @@ template struct Serializer> { } else { // For circular references: pre-allocate and store BEFORE reading if (is_first_occurrence) { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } auto result = std::make_shared(); ctx.ref_reader().store_shared_ref_at(reserved_ref_id, result); T value = Serializer::read(ctx, RefMode::None, false); @@ -555,6 +568,9 @@ template struct Serializer> { *result = std::move(value); return result; } else { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, false); if (ctx.has_error()) { return nullptr; @@ -572,6 +588,9 @@ template struct Serializer> { // references (like self_ref pointing back to the parent) to resolve. 
if (is_first_occurrence) { // Pre-allocate with default construction and store immediately + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } auto result = std::make_shared(); ctx.ref_reader().store_shared_ref_at(reserved_ref_id, result); // Read struct data - forward refs can now find this object @@ -584,6 +603,9 @@ template struct Serializer> { return result; } else { // Not first occurrence, just read and wrap + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, read_type); if (ctx.has_error()) { return nullptr; @@ -610,6 +632,9 @@ template struct Serializer> { return std::shared_ptr(obj_ptr); } else { // T is guaranteed to be a value type by static_assert. + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read_with_type_info(ctx, RefMode::None, type_info); if (ctx.has_error()) { @@ -689,6 +714,9 @@ template struct Serializer> { // For circular references: pre-allocate and store BEFORE reading const bool is_first_occurrence = flag == REF_VALUE_FLAG; if (is_first_occurrence) { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } auto result = std::make_shared(); ctx.ref_reader().store_shared_ref_at(reserved_ref_id, result); T value = @@ -699,6 +727,9 @@ template struct Serializer> { *result = std::move(value); return result; } else { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read_with_type_info(ctx, RefMode::None, type_info); if (ctx.has_error()) { @@ -710,6 +741,9 @@ template struct Serializer> { } static inline std::shared_ptr read_data(ReadContext &ctx) { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read_data(ctx); if (ctx.has_error()) { return nullptr; @@ -894,6 +928,9 @@ template struct Serializer> { "Cannot use monomorphic 
deserialization for abstract type")); return nullptr; } else { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, false); if (ctx.has_error()) { return nullptr; @@ -903,6 +940,9 @@ template struct Serializer> { } } else { // T is guaranteed to be a value type by static_assert. + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, read_type); if (ctx.has_error()) { return nullptr; @@ -925,6 +965,10 @@ template struct Serializer> { std::to_string(static_cast(flag)))); return nullptr; } + if (FORY_PREDICT_FALSE( + !ctx.reserve_graph_memory(sizeof(std::unique_ptr)))) { + return nullptr; + } // For polymorphic types, read type info AFTER handling ref flags if constexpr (is_polymorphic) { @@ -960,6 +1004,9 @@ template struct Serializer> { "Cannot use monomorphic deserialization for abstract type")); return nullptr; } else { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, false); if (ctx.has_error()) { return nullptr; @@ -969,6 +1016,9 @@ template struct Serializer> { } } else { // T is guaranteed to be a value type by static_assert. + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, read_type); if (ctx.has_error()) { return nullptr; @@ -994,6 +1044,9 @@ template struct Serializer> { return std::unique_ptr(obj_ptr); } else { // T is guaranteed to be a value type by static_assert. + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read_with_type_info(ctx, RefMode::None, type_info); if (ctx.has_error()) { @@ -1037,6 +1090,9 @@ template struct Serializer> { return std::unique_ptr(obj_ptr); } else { // T is guaranteed to be a value type by static_assert. 
+ if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read_with_type_info(ctx, RefMode::None, type_info); if (ctx.has_error()) { @@ -1047,6 +1103,9 @@ template struct Serializer> { } static inline std::unique_ptr read_data(ReadContext &ctx) { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read_data(ctx); if (ctx.has_error()) { return nullptr; diff --git a/cpp/fory/serialization/struct_serializer.h b/cpp/fory/serialization/struct_serializer.h index 4da4f7751c..f2c3488031 100644 --- a/cpp/fory/serialization/struct_serializer.h +++ b/cpp/fory/serialization/struct_serializer.h @@ -899,9 +899,6 @@ Container read_configured_list_data(ReadContext &ctx) { if (length == 0) { return result; } - if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { - return result; - } uint8_t bitmap = ctx.read_uint8(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; @@ -916,6 +913,9 @@ Container read_configured_list_data(ReadContext &ctx) { return result; } } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } const RefMode elem_ref_mode = track_ref ? RefMode::Tracking : (has_null ? 
RefMode::NullOnly : RefMode::None); @@ -939,7 +939,10 @@ FORY_NOINLINE Container read_configured_list_data_as_array_field( using Elem = element_type_t; uint32_t length = ctx.read_var_uint32(ctx.error()); Container result; - if (FORY_PREDICT_FALSE(ctx.has_error()) || length == 0) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return result; + } + if (length == 0) { return result; } uint8_t bitmap = ctx.read_uint8(ctx.error()); @@ -4581,6 +4584,10 @@ struct Serializer>> { if (ctx.track_ref() && ref_flag == ref_value_flag) { ctx.ref_reader().reserve_ref_id(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return T{}; + } // In compatible mode: use meta sharing (matches Rust behavior) if (ctx.is_compatible()) { // In compatible mode: always use remote TypeMeta for schema evolution diff --git a/cpp/fory/serialization/type_resolver.h b/cpp/fory/serialization/type_resolver.h index 1a14a8570f..fa3b61f5af 100644 --- a/cpp/fory/serialization/type_resolver.h +++ b/cpp/fory/serialization/type_resolver.h @@ -1573,6 +1573,9 @@ template inline std::any any_read_adapter(ReadContext &ctx) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return std::any(); } + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return std::any(); + } if constexpr (std::is_copy_constructible::value) { return std::any(std::move(value)); } @@ -2127,6 +2130,9 @@ void *TypeResolver::harness_read_adapter(ReadContext &ctx, RefMode ref_mode, if (FORY_PREDICT_FALSE(ctx.has_error())) { return nullptr; } + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } return new T(std::move(value)); } @@ -2151,6 +2157,9 @@ void *TypeResolver::harness_read_data_adapter(ReadContext &ctx) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return nullptr; } + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } return new T(std::move(value)); } @@ -2173,6 +2182,9 @@ void 
*TypeResolver::harness_read_compatible_adapter(ReadContext &ctx, if (FORY_PREDICT_FALSE(ctx.has_error())) { return nullptr; } + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } return new T(std::move(value)); } diff --git a/cpp/fory/serialization/union_serializer.h b/cpp/fory/serialization/union_serializer.h index d5247d431f..6a2c4db55d 100644 --- a/cpp/fory/serialization/union_serializer.h +++ b/cpp/fory/serialization/union_serializer.h @@ -468,9 +468,6 @@ Container read_union_configured_list_data(ReadContext &ctx) { if (length == 0) { return result; } - if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { - return result; - } uint8_t bitmap = ctx.read_uint8(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; @@ -483,6 +480,9 @@ Container read_union_configured_list_data(ReadContext &ctx) { return result; } } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } for (uint32_t i = 0; i < length; ++i) { if constexpr (ElemNode >= 0) { auto elem = read_union_configured_value( diff --git a/csharp/src/Fory.Generator/ForyModelGenerator.cs b/csharp/src/Fory.Generator/ForyModelGenerator.cs index 8e051da478..91c7f7fa81 100644 --- a/csharp/src/Fory.Generator/ForyModelGenerator.cs +++ b/csharp/src/Fory.Generator/ForyModelGenerator.cs @@ -212,6 +212,20 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" private static bool __ForyRefMetaMatches;"); sb.AppendLine( $" private const bool __ForyAllFieldsBuiltIn = {BoolLiteral(model.SortedMembers.All(m => m.DynamicAnyKind == DynamicAnyKind.None && m.Classification.IsBuiltIn))};"); + if (model.Kind == DeclKind.Class) + { + string graphMemoryExpr = ModelGraphMemoryExpr(model); + bool constGraphMemory = IsConstGraphMemoryExpr(graphMemoryExpr); + string graphMemoryStorage = constGraphMemory ? "const" : "static readonly"; + string graphMemoryType = constGraphMemory ? 
"int" : "long"; + if (constGraphMemory) + { + graphMemoryExpr = graphMemoryExpr.Replace("L", string.Empty); + } + + sb.AppendLine($" private {graphMemoryStorage} {graphMemoryType} __ForyGraphMemoryBytes = {graphMemoryExpr};"); + } + sb.AppendLine( " private static global::System.Collections.Generic.IReadOnlyList? __ForyNoRefTypeMetaFields;"); sb.AppendLine( @@ -446,14 +460,27 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" }"); sb.AppendLine(); - sb.AppendLine($" private {model.TypeName} ReadDataWithoutTypeMeta(global::Apache.Fory.ReadContext context)"); + EmitReadDataWithoutTypeMeta(sb, model, "ReadDataWithoutTypeMeta"); + EmitReadDataMethod(sb, model, "ReadData", "ReadDataWithoutTypeMeta", "public"); + + sb.AppendLine("}"); + } + + private static void EmitReadDataWithoutTypeMeta( + StringBuilder sb, + TypeModel model, + string methodName) + { + sb.AppendLine($" private {model.TypeName} {methodName}(global::Apache.Fory.ReadContext context)"); sb.AppendLine(" {"); - sb.AppendLine($" {model.TypeName} valueNoTypeMeta = new {model.TypeName}();"); if (model.Kind == DeclKind.Class) { - sb.AppendLine(" context.StoreRef(valueNoTypeMeta);"); + sb.AppendLine(" context.ReserveGraphMemory(__ForyGraphMemoryBytes);"); } + sb.AppendLine($" {model.TypeName} valueNoTypeMeta = new {model.TypeName}();"); + EmitStoreRef(sb, model, "valueNoTypeMeta", 2); + foreach (MemberModel member in model.SortedMembers) { EmitReadMemberAssignment( @@ -470,7 +497,16 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" return valueNoTypeMeta;"); sb.AppendLine(" }"); sb.AppendLine(); - sb.AppendLine($" public override {model.TypeName} ReadData(global::Apache.Fory.ReadContext context)"); + } + + private static void EmitReadDataMethod( + StringBuilder sb, + TypeModel model, + string methodName, + string noTypeMetaMethodName, + string accessibility) + { + sb.AppendLine($" {accessibility} override {model.TypeName} 
{methodName}(global::Apache.Fory.ReadContext context)"); sb.AppendLine(" {"); sb.AppendLine(" if (context.Compatible)"); sb.AppendLine(" {"); @@ -478,16 +514,18 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) $" global::Apache.Fory.TypeMeta? maybeTypeMeta = context.GetTypeMeta<{model.TypeName}>();"); sb.AppendLine(" if (maybeTypeMeta is null)"); sb.AppendLine(" {"); - sb.AppendLine(" return ReadDataWithoutTypeMeta(context);"); + sb.AppendLine($" return {noTypeMetaMethodName}(context);"); sb.AppendLine(" }"); sb.AppendLine(); sb.AppendLine(" global::Apache.Fory.TypeMeta typeMeta = maybeTypeMeta;"); - sb.AppendLine($" {model.TypeName} value = new {model.TypeName}();"); if (model.Kind == DeclKind.Class) { - sb.AppendLine(" context.StoreRef(value);"); + sb.AppendLine(" context.ReserveGraphMemory(__ForyGraphMemoryBytes);"); } + sb.AppendLine($" {model.TypeName} value = new {model.TypeName}();"); + EmitStoreRef(sb, model, "value", 3); + sb.AppendLine(" bool __ForyExactTypeMeta = __ForyMatchesCachedTypeMeta(typeMeta, context.TrackRef, context.TypeResolver);"); sb.AppendLine(" if (__ForyAllFieldsBuiltIn && __ForyExactTypeMeta)"); sb.AppendLine(" {"); @@ -592,12 +630,14 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" }"); sb.AppendLine(" }"); sb.AppendLine(); - sb.AppendLine($" {model.TypeName} valueSchema = new {model.TypeName}();"); if (model.Kind == DeclKind.Class) { - sb.AppendLine(" context.StoreRef(valueSchema);"); + sb.AppendLine(" context.ReserveGraphMemory(__ForyGraphMemoryBytes);"); } + sb.AppendLine($" {model.TypeName} valueSchema = new {model.TypeName}();"); + EmitStoreRef(sb, model, "valueSchema", 2); + foreach (MemberModel member in model.SortedMembers) { EmitReadMemberAssignment(sb, member, BuildWriteRefModeExpression(member), "false", "valueSchema", "Schema", 2, true); @@ -605,7 +645,22 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" 
return valueSchema;"); sb.AppendLine(" }"); - sb.AppendLine("}"); + sb.AppendLine(); + } + + private static void EmitStoreRef( + StringBuilder sb, + TypeModel model, + string valueName, + int indentLevel) + { + if (model.Kind != DeclKind.Class) + { + return; + } + + string indent = new(' ', indentLevel * 4); + sb.AppendLine($"{indent}context.StoreRef({valueName});"); } private static void EmitUnionSerializer(StringBuilder sb, TypeModel model) @@ -693,7 +748,8 @@ private static void EmitUnionSerializer(StringBuilder sb, TypeModel model) sb.AppendLine($" case {caseId}:"); sb.AppendLine(" {"); EmitReadUnionCasePayload(sb, unionCase, valueVar, 4); - sb.AppendLine($" return new {unionCase.TypeName}({valueVar});"); + sb.AppendLine($" {model.TypeName} __foryUnion = new {unionCase.TypeName}({valueVar});"); + sb.AppendLine(" return __foryUnion;"); sb.AppendLine(" }"); } @@ -705,7 +761,8 @@ private static void EmitUnionSerializer(StringBuilder sb, TypeModel model) } else { - sb.AppendLine($" return new {unknownCase.TypeName}(global::Apache.Fory.UnknownCaseSerializer.ReadPayload(context, caseId));"); + sb.AppendLine($" {model.TypeName} __foryUnion = new {unknownCase.TypeName}(global::Apache.Fory.UnknownCaseSerializer.ReadPayload(context, caseId));"); + sb.AppendLine(" return __foryUnion;"); } sb.AppendLine(" }"); @@ -1161,12 +1218,15 @@ private static void EmitReadCompatibleListArrayPayload( sb.AppendLine($"{indent}}}"); string elementTypeName = codec.CarrierKind == CarrierKind.Array ? ElementTypeName(codec.TypeName) : PackedArrayElementTypeName(codec.TypeId); uint elementTypeId = PackedArrayElementTypeId(codec.TypeId); + string elementBytesExpr = + $"(typeof({elementTypeName}).IsValueType ? 
global::System.Runtime.CompilerServices.Unsafe.SizeOf<{elementTypeName}>() : 4)"; if (codec.CarrierKind == CarrierKind.Array) { sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new {ElementTypeName(codec.TypeName)}[{lengthVar}];"); } else { + sb.AppendLine($"{indent}context.ReserveGraphMemory(1L + (long){lengthVar} * {elementBytesExpr});"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new({lengthVar});"); } @@ -1513,6 +1573,8 @@ private static void EmitReadPackedArrayPayload( } else { + string elementBytesExpr = GraphElementBytesExpr(PackedArrayElementTypeName(codec.TypeId)); + sb.AppendLine($"{indent}context.ReserveGraphMemory(1L + (long){countVar} * {elementBytesExpr});"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new({countVar});"); } @@ -1552,6 +1614,7 @@ private static void EmitReadCollectionPayload( string sameTypeVar = $"__forySameType{id++}"; string declaredVar = $"__foryDeclared{id++}"; sb.AppendLine($"{indent}int {lengthVar} = checked((int)context.Reader.ReadVarUInt32());"); + sb.AppendLine($"{indent}context.ReserveGraphMemory(1L + (long){lengthVar} * {GraphElementBytesExpr(element)});"); sb.AppendLine($"{indent}if ({lengthVar} != 0)"); sb.AppendLine($"{indent}{{"); sb.AppendLine($"{indent} context.Reader.CheckBound({lengthVar});"); @@ -1657,6 +1720,7 @@ private static void EmitReadMapPayload( FieldCodecModel value = codec.Generics[1]; string totalVar = $"__foryTotal{id++}"; sb.AppendLine($"{indent}int {totalVar} = checked((int)context.Reader.ReadVarUInt32());"); + sb.AppendLine($"{indent}context.ReserveGraphMemory(1L + (long){totalVar} * {GraphMapElementBytesExpr(key, value)});"); sb.AppendLine($"{indent}if ({totalVar} != 0)"); sb.AppendLine($"{indent}{{"); sb.AppendLine($"{indent} context.Reader.CheckBound({totalVar});"); @@ -1796,6 +1860,8 @@ private static MemberModel NonNullableMember(MemberModel member) member.Group, member.IsCollection, member.UseDictionaryTypeInfoCache, + member.UsesReferenceStorage, + 
member.FixedValueBytes, member.IsRefType, member.NeedsFieldTypeInfo, member.DynamicAnyKind, @@ -1815,6 +1881,62 @@ private static string ElementTypeName(string arrayTypeName) : "object"; } + private static string GraphElementBytesExpr(FieldCodecModel codec) + { + return GraphElementBytesExpr( + codec.Nullable && !codec.NullableValueType + ? StripNullableForTypeOf(codec.TypeName) + : codec.TypeName); + } + + private static string GraphElementBytesExpr(string typeName) + { + return $"(typeof({typeName}).IsValueType ? global::System.Runtime.CompilerServices.Unsafe.SizeOf<{typeName}>() : 4)"; + } + + private static string GraphMapElementBytesExpr(FieldCodecModel key, FieldCodecModel value) + { + return $"((long){GraphElementBytesExpr(key)} + {GraphElementBytesExpr(value)})"; + } + + private static string ModelGraphMemoryExpr(TypeModel model) + { + System.Collections.Generic.List parts = new() { "1L" }; + foreach (MemberModel member in model.SortedMembers) + { + parts.Add(FieldGraphMemoryExpr(member)); + } + + return string.Join(" + ", parts); + } + + private static string FieldGraphMemoryExpr(MemberModel member) + { + if (member.Classification.IsPrimitive && member.Classification.PrimitiveSize > 0) + { + return $"{member.Classification.PrimitiveSize}L"; + } + + if (member.UsesReferenceStorage) + { + return "4L"; + } + + if (member.FixedValueBytes > 0) + { + return $"{member.FixedValueBytes}L"; + } + + string typeName = StripNullableForTypeOf(member.TypeName); + return $"global::System.Runtime.CompilerServices.Unsafe.SizeOf<{typeName}>()"; + } + + private static bool IsConstGraphMemoryExpr(string expression) + { + return expression.IndexOf("typeof(", StringComparison.Ordinal) < 0 && + expression.IndexOf("Unsafe.", StringComparison.Ordinal) < 0; + } + private static string PackedArrayElementTypeName(uint typeId) { return typeId switch @@ -3055,6 +3177,7 @@ private static ForyAttributeKind GetForyAttributeKind(INamedTypeSymbol typeSymbo } TypeClassification classification 
= resolution.Classification; + int fixedValueBytes = FixedGraphValueBytes(unwrappedType, classification); int group = classification.IsPrimitive ? (isOptional ? 2 : 1) : 3; @@ -3089,6 +3212,8 @@ memberType is INamedTypeSymbol nts && group, classification.IsCollection || classification.IsMap, classification.IsMap && !IsTypeSealed(unwrappedType), + !unwrappedType.IsValueType, + fixedValueBytes, !unwrappedType.IsValueType && classification.TypeId != 21, FieldNeedsTypeInfo(classification, dynamicAnyKind, unwrappedType), dynamicAnyKind == DynamicAnyKind.None ? DynamicAnyKind.None : dynamicAnyKind, @@ -3097,6 +3222,42 @@ memberType is INamedTypeSymbol nts && schemaType is not null); } + private static int FixedGraphValueBytes(ITypeSymbol type, TypeClassification classification) + { + if (classification.IsPrimitive && classification.PrimitiveSize > 0) + { + return classification.PrimitiveSize; + } + + if (type.TypeKind == TypeKind.Enum && + type is INamedTypeSymbol enumType && + enumType.EnumUnderlyingType is not null) + { + return SpecialTypeBytes(enumType.EnumUnderlyingType.SpecialType); + } + + return type.SpecialType == SpecialType.System_Decimal ? 
16 : 0; + } + + private static int SpecialTypeBytes(SpecialType specialType) + { + return specialType switch + { + SpecialType.System_Boolean or + SpecialType.System_SByte or + SpecialType.System_Byte => 1, + SpecialType.System_Int16 or + SpecialType.System_UInt16 => 2, + SpecialType.System_Int32 or + SpecialType.System_UInt32 or + SpecialType.System_Single => 4, + SpecialType.System_Int64 or + SpecialType.System_UInt64 or + SpecialType.System_Double => 8, + _ => 0, + }; + } + private static TypeMetaFieldTypeModel BuildTypeMetaFieldTypeModel( ITypeSymbol memberType, bool nullable, @@ -4363,6 +4524,8 @@ public MemberModel( int group, bool isCollection, bool useDictionaryTypeInfoCache, + bool usesReferenceStorage, + int fixedValueBytes, bool isRefType, bool needsFieldTypeInfo, DynamicAnyKind dynamicAnyKind, @@ -4382,6 +4545,8 @@ public MemberModel( Group = group; IsCollection = isCollection; UseDictionaryTypeInfoCache = useDictionaryTypeInfoCache; + UsesReferenceStorage = usesReferenceStorage; + FixedValueBytes = fixedValueBytes; IsRefType = isRefType; NeedsFieldTypeInfo = needsFieldTypeInfo; DynamicAnyKind = dynamicAnyKind; @@ -4402,6 +4567,8 @@ public MemberModel( public int Group { get; } public bool IsCollection { get; } public bool UseDictionaryTypeInfoCache { get; } + public bool UsesReferenceStorage { get; } + public int FixedValueBytes { get; } public bool IsRefType { get; } public bool NeedsFieldTypeInfo { get; } public DynamicAnyKind DynamicAnyKind { get; } diff --git a/csharp/src/Fory/AnySerializer.cs b/csharp/src/Fory/AnySerializer.cs index 1be0c3e20c..80fe56b478 100644 --- a/csharp/src/Fory/AnySerializer.cs +++ b/csharp/src/Fory/AnySerializer.cs @@ -95,16 +95,10 @@ public override void Write(WriteContext context, in object? value, RefMode refMo { uint reservedRefId = context.RefReader.ReserveRefId(); context.SetReservedRefId(reservedRefId); - try - { - object? 
value = ReadNonNullDynamicAny(context, readTypeInfo); - context.StoreRef(value); - return value; - } - finally - { - context.ClearReservedRefId(); - } + object? value = ReadNonNullDynamicAny(context, readTypeInfo); + context.StoreRef(value); + context.ClearReservedRefId(); + return value; } case RefFlag.NotNullValue: break; diff --git a/csharp/src/Fory/CollectionSerializers.cs b/csharp/src/Fory/CollectionSerializers.cs index c407153fd5..d66061a0cf 100644 --- a/csharp/src/Fory/CollectionSerializers.cs +++ b/csharp/src/Fory/CollectionSerializers.cs @@ -17,6 +17,7 @@ using System.Collections; using System.Collections.Immutable; +using System.Runtime.CompilerServices; namespace Apache.Fory; @@ -31,6 +32,18 @@ internal static class CollectionBits internal static class CollectionCodec { + private const int CollectionBytes = 1; + private const int ReferenceBytes = 4; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static int ElementBytes() => ElementStorage.Bytes; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static void ReserveElementStorage(ReadContext context, int count) + { + context.ReserveGraphMemory(CollectionBytes + (long)count * ElementBytes()); + } + private static bool NeedsCompatibleElementTypeMeta(TypeInfo typeInfo, WriteContext context) { return context.Compatible && @@ -195,16 +208,65 @@ public static void WriteCollectionData( } } - public static List ReadCollectionData(Serializer elementSerializer, ReadContext context) + private static class ElementStorage { - TypeInfo elementTypeInfo = context.TypeResolver.GetTypeInfo(); - int length = checked((int)context.Reader.ReadVarUInt32()); - if (length == 0) + internal static readonly int Bytes = typeof(T).IsValueType ? 
Unsafe.SizeOf() : ReferenceBytes; + } + + private interface IValueSink + { + void Add(T value); + } + + private readonly struct CollectionSink(TCollection values) : IValueSink + where TCollection : ICollection + { + public void Add(T value) => values.Add(value); + } + + private struct ArraySink(T[] values) : IValueSink + { + private int _index; + + public void Add(T value) { - return []; + values[_index] = value; + _index++; } + } + private readonly struct QueueSink(Queue values) : IValueSink + { + public void Add(T value) => values.Enqueue(value); + } + + private readonly struct StackSink(Stack values) : IValueSink + { + public void Add(T value) => values.Push(value); + } + + private static int ReadLength(ReadContext context) + { + int length = checked((int)context.Reader.ReadVarUInt32()); + ReserveElementStorage(context, length); + return length; + } + + private static byte ReadHeader(ReadContext context, int length) + { byte header = context.Reader.ReadUInt8(); + context.Reader.CheckBound(length); + return header; + } + + private static void ReadElements( + Serializer elementSerializer, + ReadContext context, + int length, + byte header, + TSink sink) + where TSink : struct, IValueSink + { // IMPORTANT: collection readers must obey the ref/null bits written on // the wire, not the local generic metadata that may imply a different // ref policy. 
Shared xlang tests intentionally deserialize one ref @@ -213,18 +275,17 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea bool hasNull = (header & CollectionBits.HasNull) != 0; bool declared = (header & CollectionBits.DeclaredElementType) != 0; bool sameType = (header & CollectionBits.SameType) != 0; - context.Reader.CheckBound(length); - List values = new(length); + if (!sameType) { if (trackRef) { for (int i = 0; i < length; i++) { - values.Add(elementSerializer.Read(context, RefMode.Tracking, true)); + sink.Add(elementSerializer.Read(context, RefMode.Tracking, true)); } - return values; + return; } if (hasNull) @@ -234,11 +295,11 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea sbyte refFlag = context.Reader.ReadInt8(); if (refFlag == (sbyte)RefFlag.Null) { - values.Add((T)elementSerializer.DefaultObject!); + sink.Add((T)elementSerializer.DefaultObject!); } else if (refFlag == (sbyte)RefFlag.NotNullValue) { - values.Add(elementSerializer.Read(context, RefMode.None, true)); + sink.Add(elementSerializer.Read(context, RefMode.None, true)); } else { @@ -250,11 +311,11 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea { for (int i = 0; i < length; i++) { - values.Add(elementSerializer.Read(context, RefMode.None, true)); + sink.Add(elementSerializer.Read(context, RefMode.None, true)); } } - return values; + return; } if (!declared) @@ -266,7 +327,7 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea { for (int i = 0; i < length; i++) { - values.Add(elementSerializer.Read(context, RefMode.Tracking, false)); + sink.Add(elementSerializer.Read(context, RefMode.Tracking, false)); } if (!declared) @@ -274,7 +335,7 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea context.ClearReadTypeInfo(typeof(T)); } - return values; + return; } if (hasNull) @@ -284,11 +345,11 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea sbyte refFlag = 
context.Reader.ReadInt8(); if (refFlag == (sbyte)RefFlag.Null) { - values.Add((T)elementSerializer.DefaultObject!); + sink.Add((T)elementSerializer.DefaultObject!); } else { - values.Add(elementSerializer.ReadData(context)); + sink.Add(elementSerializer.ReadData(context)); } } } @@ -296,7 +357,7 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea { for (int i = 0; i < length; i++) { - values.Add(elementSerializer.ReadData(context)); + sink.Add(elementSerializer.ReadData(context)); } } @@ -304,7 +365,140 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea { context.ClearReadTypeInfo(typeof(T)); } + } + + public static List ReadCollectionData(Serializer elementSerializer, ReadContext context) + { + int length = ReadLength(context); + if (length == 0) + { + List empty = []; + context.StoreRef(empty); + return empty; + } + byte header = ReadHeader(context, length); + List values = new(length); + context.StoreRef(values); + ReadElements(elementSerializer, context, length, header, new CollectionSink, T>(values)); + return values; + } + + internal static HashSet ReadHashSetData(Serializer elementSerializer, ReadContext context) + where T : notnull + { + int length = ReadLength(context); + if (length == 0) + { + HashSet empty = new(length); + context.StoreRef(empty); + return empty; + } + + byte header = ReadHeader(context, length); + HashSet values = new(length); + context.StoreRef(values); + ReadElements(elementSerializer, context, length, header, new CollectionSink, T>(values)); + return values; + } + + internal static SortedSet ReadSortedSetData(Serializer elementSerializer, ReadContext context) + where T : notnull + { + int length = ReadLength(context); + SortedSet values = new(); + context.StoreRef(values); + if (length == 0) + { + return values; + } + + byte header = ReadHeader(context, length); + ReadElements(elementSerializer, context, length, header, new CollectionSink, T>(values)); + return values; + } + + internal 
static ImmutableHashSet ReadImmutableHashSetData( + Serializer elementSerializer, + ReadContext context) + where T : notnull + { + int length = ReadLength(context); + ImmutableHashSet.Builder values = ImmutableHashSet.CreateBuilder(); + if (length == 0) + { + return values.ToImmutable(); + } + + byte header = ReadHeader(context, length); + ReadElements( + elementSerializer, + context, + length, + header, + new CollectionSink.Builder, T>(values)); + return values.ToImmutable(); + } + + internal static LinkedList ReadLinkedListData(Serializer elementSerializer, ReadContext context) + { + int length = ReadLength(context); + LinkedList values = new(); + context.StoreRef(values); + if (length == 0) + { + return values; + } + + byte header = ReadHeader(context, length); + ReadElements(elementSerializer, context, length, header, new CollectionSink, T>(values)); + return values; + } + + internal static Queue ReadQueueData(Serializer elementSerializer, ReadContext context) + { + int length = ReadLength(context); + Queue values = new(length); + context.StoreRef(values); + if (length == 0) + { + return values; + } + + byte header = ReadHeader(context, length); + ReadElements(elementSerializer, context, length, header, new QueueSink(values)); + return values; + } + + internal static Stack ReadStackData(Serializer elementSerializer, ReadContext context) + { + int length = ReadLength(context); + Stack values = new(length); + context.StoreRef(values); + if (length == 0) + { + return values; + } + + byte header = ReadHeader(context, length); + ReadElements(elementSerializer, context, length, header, new StackSink(values)); + return values; + } + + public static T[] ReadArrayData(Serializer elementSerializer, ReadContext context) + { + int length = ReadLength(context); + if (length == 0) + { + T[] empty = []; + context.StoreRef(empty); + return empty; + } + + byte header = ReadHeader(context, length); + T[] values = new T[length]; + context.StoreRef(values); + 
ReadElements(elementSerializer, context, length, header, new ArraySink(values)); return values; } } @@ -390,7 +584,17 @@ public static bool TryWritePayload(object value, WriteContext context, bool hasG public static object ReadMapPayload(ReadContext context) { - NullableKeyDictionary map = context.TypeResolver.GetSerializer>().ReadData(context); + Serializer> serializer = + context.TypeResolver.GetSerializer>(); + bool storeRef = context._reservedRefIds.Count != 0; + NullableKeyDictionary map = serializer.ReadData(context); + if (storeRef) + { + // Dynamic tracked maps publish the nullable-key owner itself so + // nested references resolve to the same returned object. + return map; + } + if (map.HasNullKey) { return map; @@ -521,8 +725,7 @@ public override void WriteData(WriteContext context, in T[] value, bool hasGener public override T[] ReadData(ReadContext context) { - List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); - return values.ToArray(); + return CollectionCodec.ReadArrayData(context.TypeResolver.GetSerializer(), context); } } @@ -554,7 +757,7 @@ public override void WriteData(WriteContext context, in HashSet value, bool h public override HashSet ReadData(ReadContext context) { - return [.. CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context)]; + return CollectionCodec.ReadHashSetData(context.TypeResolver.GetSerializer(), context); } } @@ -570,7 +773,7 @@ public override void WriteData(WriteContext context, in SortedSet value, bool public override SortedSet ReadData(ReadContext context) { - return [.. 
CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context)]; + return CollectionCodec.ReadSortedSetData(context.TypeResolver.GetSerializer(), context); } } @@ -586,7 +789,7 @@ public override void WriteData(WriteContext context, in ImmutableHashSet valu public override ImmutableHashSet ReadData(ReadContext context) { - return ImmutableHashSet.CreateRange(CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context)); + return CollectionCodec.ReadImmutableHashSetData(context.TypeResolver.GetSerializer(), context); } } @@ -602,7 +805,7 @@ public override void WriteData(WriteContext context, in LinkedList value, boo public override LinkedList ReadData(ReadContext context) { - return new LinkedList(CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context)); + return CollectionCodec.ReadLinkedListData(context.TypeResolver.GetSerializer(), context); } } @@ -618,14 +821,7 @@ public override void WriteData(WriteContext context, in Queue value, bool has public override Queue ReadData(ReadContext context) { - List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); - Queue queue = new(values.Count); - for (int i = 0; i < values.Count; i++) - { - queue.Enqueue(values[i]); - } - - return queue; + return CollectionCodec.ReadQueueData(context.TypeResolver.GetSerializer(), context); } } @@ -654,13 +850,6 @@ public override void WriteData(WriteContext context, in Stack value, bool has public override Stack ReadData(ReadContext context) { - List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); - Stack stack = new(values.Count); - for (int i = 0; i < values.Count; i++) - { - stack.Push(values[i]); - } - - return stack; + return CollectionCodec.ReadStackData(context.TypeResolver.GetSerializer(), context); } } diff --git a/csharp/src/Fory/Config.cs b/csharp/src/Fory/Config.cs index 438039d2c8..e84193af50 100644 --- 
a/csharp/src/Fory/Config.cs
+++ b/csharp/src/Fory/Config.cs
@@ -28,6 +28,7 @@ internal Config(
         bool compatible,
         bool checkStructVersion,
         int maxDepth,
+        long maxGraphMemoryBytes,
         int maxTypeFields,
         int maxTypeMetaBytes,
         int maxSchemaVersionsPerType,
@@ -53,11 +54,16 @@ internal Config(
         {
             throw new ArgumentOutOfRangeException(nameof(maxAverageSchemaVersionsPerType), "MaxAverageSchemaVersionsPerType must be greater than 0.");
         }
+        if (maxGraphMemoryBytes <= 0)
+        {
+            throw new ArgumentOutOfRangeException(nameof(maxGraphMemoryBytes), "MaxGraphMemoryBytes must be greater than 0.");
+        }

         TrackRef = trackRef;
         Compatible = compatible;
         CheckStructVersion = checkStructVersion;
         MaxDepth = maxDepth;
+        MaxGraphMemoryBytes = maxGraphMemoryBytes;
         MaxTypeFields = maxTypeFields;
         MaxTypeMetaBytes = maxTypeMetaBytes;
         MaxSchemaVersionsPerType = maxSchemaVersionsPerType;
@@ -84,6 +90,11 @@ internal Config(
     /// </summary>
     public int MaxDepth { get; }

+    /// <summary>
+    /// Gets the maximum estimated graph memory accepted during one root deserialization.
+    /// </summary>
+    public long MaxGraphMemoryBytes { get; }
+
     /// <summary>
     /// Gets the maximum accepted field count in one received struct TypeMeta.
     /// </summary>
@@ -114,6 +125,7 @@ public sealed class ForyBuilder
     private bool? _compatible;
     private bool _checkStructVersion;
     private int _maxDepth = 20;
+    private long _maxGraphMemoryBytes = 128L * 1024 * 1024;
     private int _maxTypeFields = 512;
     private int _maxTypeMetaBytes = 4096;
     private int _maxSchemaVersionsPerType = 10;
@@ -169,6 +181,20 @@ public ForyBuilder MaxDepth(int value)
         return this;
     }

+    /// <summary>
+    /// Sets the maximum estimated graph memory accepted during one root deserialization.
+    /// </summary>
+    public ForyBuilder MaxGraphMemoryBytes(long value)
+    {
+        if (value <= 0)
+        {
+            throw new ArgumentOutOfRangeException(nameof(value), "MaxGraphMemoryBytes must be greater than 0.");
+        }
+
+        _maxGraphMemoryBytes = value;
+        return this;
+    }
+
     /// <summary>
     /// Sets the maximum accepted field count in one received struct TypeMeta.
     /// </summary>
@@ -235,6 +261,7 @@ private Config BuildConfig()
             compatible: compatible,
             checkStructVersion: compatible ? false : _checkStructVersion,
             maxDepth: _maxDepth,
+            maxGraphMemoryBytes: _maxGraphMemoryBytes,
             maxTypeFields: _maxTypeFields,
             maxTypeMetaBytes: _maxTypeMetaBytes,
             maxSchemaVersionsPerType: _maxSchemaVersionsPerType,
diff --git a/csharp/src/Fory/DictionarySerializers.cs b/csharp/src/Fory/DictionarySerializers.cs
index 5aa49dfa75..8722be02bd 100644
--- a/csharp/src/Fory/DictionarySerializers.cs
+++ b/csharp/src/Fory/DictionarySerializers.cs
@@ -16,6 +16,7 @@
 // under the License.

 using System.Collections.Concurrent;
+using System.Runtime.CompilerServices;

 namespace Apache.Fory;

@@ -33,6 +34,19 @@ public abstract class DictionaryLikeSerializer<TDictionary, TKey, TValue> : Seri
     where TDictionary : class, IDictionary<TKey, TValue>
     where TKey : notnull
 {
+    private const int MapBytes = 1;
+    private const int ReferenceBytes = 4;
+    private static readonly long MapElementBytes = (long)ElementBytes<TKey>() + ElementBytes<TValue>();
+
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    private static int ElementBytes<T>() => typeof(T).IsValueType ?
Unsafe.SizeOf() : ReferenceBytes; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void ReserveMapStorage(ReadContext context, int count) + { + context.ReserveGraphMemory(MapBytes + count * MapElementBytes); + } + public override TDictionary DefaultValue => null!; protected abstract TDictionary CreateMap(int capacity); @@ -214,11 +228,18 @@ public override TDictionary ReadData(ReadContext context) int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { - return CreateMap(0); + ReserveMapStorage(context, totalLength); + TDictionary empty = CreateMap(0); + context.StoreRef(empty); + + return empty; } + ReserveMapStorage(context, totalLength); context.Reader.CheckBound(totalLength); TDictionary map = CreateMap(totalLength); + context.StoreRef(map); + bool keyDynamicType = keyTypeInfo.IsDynamicType; bool valueDynamicType = valueTypeInfo.IsDynamicType; int readCount = 0; diff --git a/csharp/src/Fory/FieldSkipper.cs b/csharp/src/Fory/FieldSkipper.cs index aaca46ec23..3120e16bf7 100644 --- a/csharp/src/Fory/FieldSkipper.cs +++ b/csharp/src/Fory/FieldSkipper.cs @@ -119,16 +119,10 @@ private static bool HasInlineTypeInfo(uint typeId) { uint reservedRefId = context.RefReader.ReserveRefId(); context.SetReservedRefId(reservedRefId); - try - { - object? value = ReadInlineTypedPayload(context); - context.StoreRef(value); - return value; - } - finally - { - context.ClearReservedRefId(); - } + object? value = ReadInlineTypedPayload(context); + context.StoreRef(value); + context.ClearReservedRefId(); + return value; } case RefFlag.NotNullValue: return ReadInlineTypedPayload(context); @@ -184,16 +178,10 @@ private static bool HasInlineTypeInfo(uint typeId) { uint reservedRefId = context.RefReader.ReserveRefId(); context.SetReservedRefId(reservedRefId); - try - { - object? 
value = context.TypeResolver.ReadAnyValue(typeInfo, context); - context.StoreRef(value); - return value; - } - finally - { - context.ClearReservedRefId(); - } + object? value = context.TypeResolver.ReadAnyValue(typeInfo, context); + context.StoreRef(value); + context.ClearReservedRefId(); + return value; } case RefFlag.NotNullValue: return context.TypeResolver.ReadAnyValue(typeInfo, context); diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs index 9bbafd1775..6c456e543a 100644 --- a/csharp/src/Fory/Fory.cs +++ b/csharp/src/Fory/Fory.cs @@ -28,12 +28,14 @@ namespace Apache.Fory; public sealed class Fory { private readonly TypeResolver _typeResolver; + private readonly bool _trackRef; private WriteContext _writeContext; private ReadContext _readContext; internal Fory(Config config) { Config = config; + _trackRef = config.TrackRef; _typeResolver = new TypeResolver(); _writeContext = new WriteContext( new ByteWriter(), @@ -275,22 +277,78 @@ private static void ThrowInvalidRootHeader(byte bitmap) => [MethodImpl(MethodImplOptions.AggressiveInlining)] private T DeserializeFromReader(ByteReader reader) { - ReadHead(reader); - Serializer serializer = _typeResolver.GetSerializer(); ReadContext readContext = _readContext; readContext.ResetFor(reader); - RefMode refMode = Config.TrackRef ? RefMode.Tracking : RefMode.NullOnly; - T value = serializer.Read(readContext, refMode, true); - readContext.RefReader.Reset(); - readContext._typeMetaType = null; - readContext._typeMeta = null; - readContext._typeMetaByType?.ClearKeys(); - readContext._readTypeInfoByType.ClearKeys(); - readContext._reservedRefIds.Clear(); - readContext._cachedTypeMetaType = null; - readContext._cachedTypeMeta = null; - readContext._currentDynamicReadDepth = 0; - return value; + readContext._remainingGraphMemoryBytes = Config.MaxGraphMemoryBytes; + try + { + ReadHead(reader); + Serializer serializer = _typeResolver.GetSerializer(); + return _trackRef + ? 
serializer.Read(readContext, RefMode.Tracking, true) + : ReadRootNoRef(serializer, readContext); + } + finally + { + if (_trackRef || readContext.RefReader.HasRefs) + { + readContext.RefReader.Reset(); + } + if (readContext._reservedRefIds.Count != 0) + { + readContext._reservedRefIds.Clear(); + } + readContext._typeMetaType = null; + readContext._typeMeta = null; + readContext._typeMetaByType?.ClearKeys(); + readContext._readTypeInfoByType.ClearKeys(); + readContext._cachedTypeMetaType = null; + readContext._cachedTypeMeta = null; + readContext._currentDynamicReadDepth = 0; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static T ReadRootNoRef(Serializer serializer, ReadContext context) + { + RefFlag flag = (RefFlag)context.Reader.ReadInt8(); + if (flag == RefFlag.NotNullValue) + { + return serializer.Read(context, RefMode.None, true); + } + + if (flag == RefFlag.Null) + { + return serializer.DefaultValue; + } + + return ReadRootRefFallback(serializer, context, flag); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static T ReadRootRefFallback(Serializer serializer, ReadContext context, RefFlag flag) + { + switch (flag) + { + case RefFlag.Ref: + { + uint refId = context.RefReader.ReadRefId(context.Reader); + return context.RefReader.GetRef(refId); + } + case RefFlag.RefValue: + { + uint reservedRefId = context.RefReader.ReserveRefId(); + context.SetReservedRefId(reservedRefId); + context.TypeResolver.ReadTypeInfo(serializer, context); + T value = serializer.ReadData(context); + context.StoreRef(value); + context.ClearReservedRefId(); + context.RefReader.Reset(); + return value; + } + default: + throw new RefException($"invalid ref flag {(sbyte)flag}"); + } } } diff --git a/csharp/src/Fory/GraphMemory.cs b/csharp/src/Fory/GraphMemory.cs new file mode 100644 index 0000000000..e1cd3820c9 --- /dev/null +++ b/csharp/src/Fory/GraphMemory.cs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more 
contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+using System.Runtime.CompilerServices;
+
+namespace Apache.Fory;
+
+internal static class GraphMemory
+{
+    [MethodImpl(MethodImplOptions.AggressiveInlining)]
+    internal static long ValueOwnerBytes<T>() => ValueOwner<T>.Bytes;
+
+    private static class ValueOwner<T>
+    {
+        internal static readonly long Bytes = Compute();
+
+        private static long Compute()
+        {
+            Type type = typeof(T);
+            if (!ShouldReserve(type))
+            {
+                return 0;
+            }
+
+            int bytes = Unsafe.SizeOf<T>();
+            return bytes == 0 ? 1 : bytes;
+        }
+
+        private static bool ShouldReserve(Type type)
+        {
+            if (!type.IsValueType ||
+                Nullable.GetUnderlyingType(type) is not null ||
+                type.IsEnum ||
+                type.IsPrimitive)
+            {
+                return false;
+            }
+
+            return type != typeof(decimal) &&
+                type != typeof(Half) &&
+                type != typeof(BFloat16) &&
+                type != typeof(DateOnly) &&
+                type != typeof(DateTime) &&
+                type != typeof(DateTimeOffset) &&
+                type != typeof(TimeSpan);
+        }
+    }
+}
diff --git a/csharp/src/Fory/NullableKeyDictionary.cs b/csharp/src/Fory/NullableKeyDictionary.cs
index d6c8caab47..a12bbf61d1 100644
--- a/csharp/src/Fory/NullableKeyDictionary.cs
+++ b/csharp/src/Fory/NullableKeyDictionary.cs
@@ -16,6 +16,7 @@
 // under the License.
using System.Collections; +using System.Runtime.CompilerServices; namespace Apache.Fory; @@ -390,6 +391,19 @@ IEnumerator IEnumerable.GetEnumerator() public sealed class NullableKeyDictionarySerializer : Serializer> { + private const int MapBytes = 1; + private const int ReferenceBytes = 4; + private static readonly long MapElementBytes = (long)ElementBytes() + ElementBytes(); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int ElementBytes() => typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void ReserveMapStorage(ReadContext context, int count) + { + context.ReserveGraphMemory(MapBytes + count * MapElementBytes); + } + public override NullableKeyDictionary DefaultValue => null!; public override void WriteData(WriteContext context, in NullableKeyDictionary value, bool hasGenerics) @@ -537,11 +551,18 @@ public override NullableKeyDictionary ReadData(ReadContext context int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { - return new NullableKeyDictionary(); + ReserveMapStorage(context, totalLength); + NullableKeyDictionary empty = new(); + context.StoreRef(empty); + + return empty; } + ReserveMapStorage(context, totalLength); context.Reader.CheckBound(totalLength); NullableKeyDictionary map = new(totalLength); + context.StoreRef(map); + bool keyDynamicType = keyTypeInfo.IsDynamicType; bool valueDynamicType = valueTypeInfo.IsDynamicType; int readCount = 0; diff --git a/csharp/src/Fory/PrimitiveDictionarySerializers.cs b/csharp/src/Fory/PrimitiveDictionarySerializers.cs index a136bd57bd..d9520b3525 100644 --- a/csharp/src/Fory/PrimitiveDictionarySerializers.cs +++ b/csharp/src/Fory/PrimitiveDictionarySerializers.cs @@ -17,6 +17,7 @@ using System.Collections; using System.Collections.Concurrent; +using System.Runtime.CompilerServices; namespace Apache.Fory; @@ -663,6 +664,23 @@ public static void WriteMap() => 
ElementStorage.Bytes; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void ReserveMapStorage(ReadContext context, int count) + { + context.ReserveGraphMemory(MapBytes + count * ((long)ElementBytes() + ElementBytes())); + } + + private static class ElementStorage + { + internal static readonly int Bytes = typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + } + public static TMap ReadMap(ReadContext context) where TKey : notnull where TKeyCodec : struct, IPrimitiveDictionaryCodec @@ -672,9 +690,11 @@ public static TMap ReadMap( int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { + ReserveMapStorage(context, totalLength); return TMapOps.Create(0); } + ReserveMapStorage(context, totalLength); context.Reader.CheckBound(totalLength); TMap map = TMapOps.Create(totalLength); TypeId keyTypeId = TKeyCodec.WireTypeId; diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index f83ac0e99e..13948ab1d7 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. +using System.Runtime.CompilerServices; + namespace Apache.Fory; public sealed class ReadContext { private const int MinRemoteTypeMetaLimit = 8192; + private const uint NoReservedRefId = uint.MaxValue; private readonly ReusableArray _typeMetaRefs = new(); private readonly UInt64Map _typeMetasByHeader = new(); @@ -29,6 +32,8 @@ public sealed class ReadContext private readonly List _readMetaStrings = []; internal readonly UInt64Map _readTypeInfoByType = new(); + // Consumed slots stay on the stack until their matching reader scope clears them. That lets + // nested child reads restore an outer owner that has not been materialized yet. internal readonly List _reservedRefIds = []; private readonly int _maxDynamicReadDepth; internal Type? 
_typeMetaType; @@ -40,6 +45,7 @@ public sealed class ReadContext private readonly Dictionary _remoteSchemaVersionsByType = []; private readonly Config _config; private int _totalAcceptedSchemaVersions; + internal long _remainingGraphMemoryBytes; public ReadContext( ByteReader reader, @@ -70,6 +76,38 @@ public ReadContext( internal RefReader RefReader { get; } + /// + /// Reserves estimated graph memory for the current root deserialization. + /// + /// + /// Serializer owners compute owner-specific formulas and pass raw bytes here. This + /// accounting does not replace byte-availability checks before backing allocation. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void ReserveGraphMemory(long bytes) + { + long remaining = _remainingGraphMemoryBytes; + if ((ulong)bytes > (ulong)remaining) + { + ReserveGraphMemorySlow(bytes, remaining); + return; + } + + _remainingGraphMemoryBytes = remaining - bytes; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void ReserveGraphMemorySlow(long bytes, long remaining) + { + if (bytes < 0) + { + throw new InvalidDataException("graph memory estimate overflows"); + } + + throw new InvalidDataException( + $"estimated graph memory request {bytes} bytes exceeds MaxGraphMemoryBytes remaining budget {remaining} bytes out of effective limit {_config.MaxGraphMemoryBytes} bytes"); + } + internal void ResetFor(ByteReader reader) { Reader = reader; @@ -404,14 +442,23 @@ internal void ClearReadTypeInfo(Type type) _readTypeInfoByType.Remove(TypeMapKey.Get(type)); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] public void StoreRef(object? 
value) { - if (_reservedRefIds.Count == 0) + int index = _reservedRefIds.Count - 1; + if (index < 0) + { + return; + } + + uint refId = _reservedRefIds[index]; + if (refId == NoReservedRefId) { return; } - RefReader.StoreRefAt(_reservedRefIds[^1], value); + RefReader.StoreRefAt(refId, value); + _reservedRefIds[index] = NoReservedRefId; } internal void SetReservedRefId(uint refId) @@ -421,9 +468,10 @@ internal void SetReservedRefId(uint refId) internal void ClearReservedRefId() { - if (_reservedRefIds.Count > 0) + int count = _reservedRefIds.Count; + if (count > 0) { - _reservedRefIds.RemoveAt(_reservedRefIds.Count - 1); + _reservedRefIds.RemoveAt(count - 1); } } diff --git a/csharp/src/Fory/RefResolver.cs b/csharp/src/Fory/RefResolver.cs index d4e6dbf574..98ccb81fff 100644 --- a/csharp/src/Fory/RefResolver.cs +++ b/csharp/src/Fory/RefResolver.cs @@ -60,6 +60,11 @@ public sealed class RefReader { private readonly List _refs = []; + internal bool HasRefs + { + get => _refs.Count != 0; + } + public RefFlag ReadRefFlag(ByteReader reader) { return (RefFlag)reader.ReadInt8(); @@ -107,11 +112,22 @@ public T GetRef(uint refId) throw new RefException($"ref_id out of range: {refId}"); } - return _refs[index]; + object? 
value = _refs[index]; + if (value is null) + { + throw new RefException($"ref_id {refId} has not been published"); + } + + return value; } public void Reset() { + if (_refs.Count == 0) + { + return; + } + _refs.Clear(); } } diff --git a/csharp/src/Fory/Serializer.cs b/csharp/src/Fory/Serializer.cs index d2bc727c20..40d027a49f 100644 --- a/csharp/src/Fory/Serializer.cs +++ b/csharp/src/Fory/Serializer.cs @@ -114,21 +114,15 @@ public virtual T Read(ReadContext context, RefMode refMode, bool readTypeInfo) { uint reservedRefId = context.RefReader.ReserveRefId(); context.SetReservedRefId(reservedRefId); - try + if (readTypeInfo) { - if (readTypeInfo) - { - context.TypeResolver.ReadTypeInfo(this, context); - } - - T value = ReadData(context); - context.StoreRef(value); - return value; - } - finally - { - context.ClearReservedRefId(); + context.TypeResolver.ReadTypeInfo(this, context); } + + T value = ReadData(context); + context.StoreRef(value); + context.ClearReservedRefId(); + return value; } case RefFlag.NotNullValue: break; @@ -142,6 +136,12 @@ public virtual T Read(ReadContext context, RefMode refMode, bool readTypeInfo) context.TypeResolver.ReadTypeInfo(this, context); } + long graphBytes = GraphMemory.ValueOwnerBytes(); + if (graphBytes != 0) + { + context.ReserveGraphMemory(graphBytes); + } + return ReadData(context); } } diff --git a/csharp/src/Fory/TypeInfo.cs b/csharp/src/Fory/TypeInfo.cs index 8341d418c1..d0c43e082c 100644 --- a/csharp/src/Fory/TypeInfo.cs +++ b/csharp/src/Fory/TypeInfo.cs @@ -102,6 +102,7 @@ internal static TypeInfo Create(Type type, Serializer serializer) bool evolving = ResolveStructEvolving(type, userTypeKind); bool isNullableType = !type.IsValueType || Nullable.GetUnderlyingType(type) is not null; bool isRefType = type != typeof(string) && !type.IsValueType; + long boxedValueBytes = GraphMemory.ValueOwnerBytes(); return new TypeInfo( type, serializer, @@ -118,7 +119,7 @@ internal static TypeInfo Create(Type type, Serializer serializer) 
namespaceName: null, typeName: null, (context, value, hasGenerics) => WriteDataObject(serializer, context, value, hasGenerics), - context => serializer.ReadData(context), + context => ReadDataObject(serializer, context, boxedValueBytes), (context, value, refMode, writeTypeInfo, hasGenerics) => WriteObject(serializer, context, value, refMode, writeTypeInfo, hasGenerics), (context, refMode, readTypeInfo) => serializer.Read(context, refMode, readTypeInfo), @@ -178,6 +179,16 @@ private static void WriteDataObject(Serializer serializer, WriteContext co serializer.WriteData(context, CoerceRuntimeValue(serializer, value), hasGenerics); } + private static object? ReadDataObject(Serializer serializer, ReadContext context, long boxedValueBytes) + { + if (boxedValueBytes != 0) + { + context.ReserveGraphMemory(boxedValueBytes); + } + + return serializer.ReadData(context); + } + private static void WriteObject( Serializer serializer, WriteContext context, diff --git a/csharp/src/Fory/TypeResolver.cs b/csharp/src/Fory/TypeResolver.cs index a5cb888c4b..9bb5c7ff37 100644 --- a/csharp/src/Fory/TypeResolver.cs +++ b/csharp/src/Fory/TypeResolver.cs @@ -982,10 +982,7 @@ TypeId.CompatibleStruct or return actualWireType == expected; } - private object? ReadRegisteredValue( - TypeInfo typeInfo, - ReadContext context, - TypeMeta? typeMeta) + private object? ReadRegisteredValue(TypeInfo typeInfo, ReadContext context, TypeMeta? typeMeta) { if (typeMeta is not null) { diff --git a/csharp/src/Fory/UnknownCaseSerializer.cs b/csharp/src/Fory/UnknownCaseSerializer.cs index f3866247c4..bd4c3c9597 100644 --- a/csharp/src/Fory/UnknownCaseSerializer.cs +++ b/csharp/src/Fory/UnknownCaseSerializer.cs @@ -51,16 +51,10 @@ public static UnknownCase ReadPayload(ReadContext context, int caseId) { uint reservedRefId = context.RefReader.ReserveRefId(); context.SetReservedRefId(reservedRefId); - try - { - (uint typeId, object? 
value) = ReadNonNullPayload(context); - context.StoreRef(value); - return UnknownCase.FromRuntime(caseId, typeId, value); - } - finally - { - context.ClearReservedRefId(); - } + (uint typeId, object? value) = ReadNonNullPayload(context); + context.StoreRef(value); + context.ClearReservedRefId(); + return UnknownCase.FromRuntime(caseId, typeId, value); } case RefFlag.NotNullValue: { diff --git a/csharp/tests/Fory.Tests/ForyRuntimeTests.cs b/csharp/tests/Fory.Tests/ForyRuntimeTests.cs index 6d93003948..650eccd3c4 100644 --- a/csharp/tests/Fory.Tests/ForyRuntimeTests.cs +++ b/csharp/tests/Fory.Tests/ForyRuntimeTests.cs @@ -62,6 +62,12 @@ public sealed class Node public Node? Next { get; set; } } +[ForyStruct] +public sealed class AnyNode +{ + public object? Next { get; set; } +} + [ForyStruct] public sealed class FieldOrder { @@ -606,6 +612,148 @@ public void ForyReusedContextsHandleSequentialCalls() } } + [Fact] + public void WireTrackingRefsWorkWithNoRefReader() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + writer.Register(953); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + reader.Register(953); + + Node source = new() { Value = 7 }; + source.Next = source; + + Node decoded = reader.Deserialize(writer.Serialize(source)); + + Assert.Equal(7, decoded.Value); + Assert.NotNull(decoded.Next); + Assert.Same(decoded, decoded.Next); + } + + [Fact] + public void WireTrackingRefsResetAfterNoRefRoot() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + writer.Register(954); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + reader.Register(954); + + Node source = new() { Value = 9 }; + source.Next = source; + _ = reader.Deserialize(writer.Serialize(source)); + + ByteWriter staleRefPayload = new(); + staleRefPayload.WriteUInt8(ForyHeaderFlag.IsXlang); + staleRefPayload.WriteInt8((sbyte)RefFlag.Ref); + staleRefPayload.WriteVarUInt32(0); + + Assert.Throws(() => 
reader.Deserialize(staleRefPayload.ToArray())); + } + + [Fact] + public void DynamicAnyPublishesTrackedRef() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + writer.Register(955); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + reader.Register(955); + + Node source = new() { Value = 11 }; + source.Next = source; + + object decodedObject = reader.Deserialize(writer.Serialize(source)); + Node decoded = Assert.IsType(decodedObject); + + Assert.Equal(11, decoded.Value); + Assert.NotNull(decoded.Next); + Assert.Same(decoded, decoded.Next); + } + + [Fact] + public void DynamicContainersPublishTrackedRefs() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + + List list = []; + list.Add(list); + object decodedListObject = reader.Deserialize(writer.Serialize(list)); + List decodedList = Assert.IsType>(decodedListObject); + Assert.Same(decodedList, decodedList[0]); + + HashSet set = []; + set.Add(set); + object decodedSetObject = reader.Deserialize(writer.Serialize(set)); + HashSet decodedSet = Assert.IsType>(decodedSetObject); + Assert.Contains(decodedSet, decodedSet); + + Dictionary map = []; + map["self"] = map; + object decodedMapObject = reader.Deserialize(writer.Serialize(map)); + NullableKeyDictionary decodedMap = + Assert.IsType>(decodedMapObject); + Assert.True(decodedMap.TryGetValue("self", out object? 
self)); + Assert.Same(decodedMap, self); + } + + [Fact] + public void ArraysAndMapsPublishRefs() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + + object?[] array = new object?[1]; + array[0] = array; + object?[] decodedArray = reader.Deserialize(writer.Serialize(array)); + Assert.Same(decodedArray, decodedArray[0]); + + Dictionary map = []; + map["self"] = map; + Dictionary decodedMap = + reader.Deserialize>(writer.Serialize(map)); + Assert.Same(decodedMap, decodedMap["self"]); + } + + [Fact] + public void MutableCollectionsPublishRefs() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + + LinkedList linkedList = []; + linkedList.AddLast(linkedList); + LinkedList decodedList = + reader.Deserialize>(writer.Serialize(linkedList)); + Assert.Same(decodedList, decodedList.First!.Value); + + Queue queue = []; + queue.Enqueue(queue); + Queue decodedQueue = reader.Deserialize>(writer.Serialize(queue)); + Assert.Same(decodedQueue, decodedQueue.Peek()); + + Stack stack = []; + stack.Push(stack); + Stack decodedStack = reader.Deserialize>(writer.Serialize(stack)); + Assert.Same(decodedStack, decodedStack.Peek()); + } + + [Fact] + public void UnionCycleRefFailsLoudly() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + writer.Register(956); + writer.Register(957); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + reader.Register(956); + reader.Register(957); + + AnyNode node = new(); + Union union = new(0, node); + node.Next = union; + + Assert.Throws(() => reader.Deserialize(writer.Serialize(union))); + } + [Fact] public void ThreadSafeForySupportsParallelPrimitiveRoundTrip() { @@ -2427,7 +2575,8 @@ public void GeneratedSerializerSupportsObjectKeyMap() }; DynamicAnyHolder decoded = fory.Deserialize(fory.Serialize(source)); - Dictionary 
dynamicMap = Assert.IsType>(decoded.AnyValue); + NullableKeyDictionary dynamicMap = + Assert.IsType>(decoded.AnyValue); Assert.Equal(9, dynamicMap["inner"]); Assert.Equal("ten", dynamicMap[10]); Assert.Equal(source.AnySet.Count, decoded.AnySet.Count); diff --git a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs new file mode 100644 index 0000000000..7f00f2515e --- /dev/null +++ b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs @@ -0,0 +1,343 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +using System.Buffers; +using System.Collections.Immutable; +using System.Runtime.CompilerServices; +using Apache.Fory; +using ForyRuntime = Apache.Fory.Fory; +using S = Apache.Fory.Schema.Types; + +namespace Apache.Fory.Tests; + +[ForyStruct] +public sealed class BudgetItem +{ + public int Id { get; set; } + public string Name { get; set; } = string.Empty; +} + +[ForyStruct] +public sealed class BudgetEmpty +{ +} + +[ForyStruct] +public sealed class BudgetSiblings +{ + public List Left { get; set; } = []; + public List Right { get; set; } = []; +} + +[ForyStruct] +public sealed class BudgetArrayHolder +{ + public BudgetItem[] Values { get; set; } = []; +} + +[ForyStruct] +public struct BudgetValue +{ + public int Id { get; set; } +} + +[ForyStruct] +public sealed class GeneratedSchemaListBudget +{ + [ForyField(Type = typeof(S.List))] + public List Values { get; set; } = []; +} + +[ForyStruct] +public sealed class GeneratedPackedListBudget +{ + [ForyField(Type = typeof(S.Array))] + public List Values { get; set; } = []; +} + +[ForyStruct] +public sealed class GeneratedSchemaMapBudget +{ + [ForyField(Type = typeof(S.Map))] + public Dictionary Values { get; set; } = []; +} + +[ForyStruct] +public sealed class CompatibleBudgetList +{ + [ForyField(Type = typeof(S.List))] + public List Values { get; set; } = []; +} + +[ForyStruct] +public sealed class CompatibleBudgetArray +{ + [ForyField(Type = typeof(S.Array))] + public int[] Values { get; set; } = []; +} + +public sealed class GraphMemoryBudgetTests +{ + private const int ReferenceBytes = 4; + private const int ObjectBytes = 1; + private const long BudgetEmptyBytes = ObjectBytes; + private const long BudgetItemBytes = ObjectBytes + 4 + ReferenceBytes; + private const long BudgetSiblingsBytes = ObjectBytes + ReferenceBytes + ReferenceBytes; + private const long BudgetArrayHolderBytes = ObjectBytes + ReferenceBytes; + private const long GeneratedGraphHolderBytes = ObjectBytes + ReferenceBytes; + private const long 
BudgetValueBytes = 4; + private const long DefaultGraphMemoryBytes = 128L * 1024 * 1024; + + private static int ElementBytes() => typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + + private static ForyRuntime NewFory(long maxGraphMemoryBytes = DefaultGraphMemoryBytes) + { + return ForyRuntime.Builder() + .Compatible(false) + .TrackRef(false) + .MaxGraphMemoryBytes(maxGraphMemoryBytes) + .Build() + .Register(1001) + .Register(1002) + .Register(1003) + .Register(1004) + .Register(1005) + .Register(1006) + .Register(1007) + .Register(1008); + } + + private static byte[] Serialize(T value) + { + return NewFory().Serialize(value); + } + + private static long ListBudget(int count) + { + return ObjectBytes + (long)count * ElementBytes(); + } + + private static long ArrayBudget(int count) + { + return ObjectBytes + (long)count * ElementBytes(); + } + + private static long MapBudget(int count) + { + return ObjectBytes + (long)count * (ElementBytes() + ElementBytes()); + } + + [Fact] + public void DefaultFixedBudgetAndValidation() + { + Assert.Equal(DefaultGraphMemoryBytes, NewFory().Config.MaxGraphMemoryBytes); + Assert.Throws(() => NewFory(0)); + Assert.Throws(() => NewFory(-2)); + + List> value = Enumerable.Range(0, 3).Select(_ => new List()).ToList(); + Assert.Equal(value.Count, NewFory().Deserialize>>(Serialize(value)).Count); + } + + [Fact] + public void ReadOnlySequenceUsesSameBudget() + { + const int count = 6; + List> value = Enumerable.Range(0, count).Select(_ => new List()).ToList(); + byte[] bytes = Serialize(value); + ReadOnlySequence sequence = new(bytes); + + Assert.Equal(count, NewFory().Deserialize>>(ref sequence).Count); + } + + [Fact] + public void ExplicitConfigOverridesDefault() + { + List value = Enumerable.Range(0, 8).Select(i => new BudgetItem { Id = i }).ToList(); + byte[] bytes = Serialize(value); + long required = ListBudget(value.Count) + value.Count * BudgetItemBytes; + + Assert.Throws(() => NewFory(required - 1).Deserialize>(bytes)); + 
List result = NewFory(required).Deserialize>(bytes); + Assert.Equal(value.Count, result.Count); + } + + [Fact] + public void EmptyObjectOwnerIsCharged() + { + List value = [new BudgetEmpty()]; + byte[] bytes = Serialize(value); + long required = ListBudget(value.Count) + value.Count * BudgetEmptyBytes; + + Assert.Throws(() => NewFory(required - 1).Deserialize>(bytes)); + Assert.Single(NewFory(required).Deserialize>(bytes)); + } + + [Fact] + public void SiblingContainersShareOneBudget() + { + BudgetSiblings value = new() + { + Left = Enumerable.Range(0, 16).Select(i => new BudgetItem { Id = i }).ToList(), + Right = Enumerable.Range(0, 16).Select(i => new BudgetItem { Id = i }).ToList(), + }; + byte[] bytes = Serialize(value); + long oneList = ListBudget(16) + 16 * BudgetItemBytes; + long required = BudgetSiblingsBytes + oneList * 2; + + Assert.Throws(() => NewFory(required - 1).Deserialize(bytes)); + BudgetSiblings result = NewFory(required).Deserialize(bytes); + Assert.Equal(16, result.Left.Count); + Assert.Equal(16, result.Right.Count); + } + + [Fact] + public void MapBudgetIsCharged() + { + Dictionary value = new() { ["a"] = 1, ["b"] = 2, ["c"] = 3 }; + byte[] bytes = Serialize(value); + long required = MapBudget(value.Count); + + Assert.Throws(() => NewFory(required - 1).Deserialize>(bytes)); + Dictionary result = NewFory(required).Deserialize>(bytes); + Assert.Equal(value, result); + } + + [Fact] + public void ArrayAndInlineListBudget() + { + BudgetArrayHolder holder = new() + { + Values = Enumerable.Range(0, 4).Select(i => new BudgetItem { Id = i }).ToArray(), + }; + byte[] holderBytes = Serialize(holder); + long holderRequired = + BudgetArrayHolderBytes + ArrayBudget(4) + holder.Values.Length * BudgetItemBytes; + Assert.Throws(() => NewFory(holderRequired - 1).Deserialize(holderBytes)); + Assert.Equal(4, NewFory(holderRequired).Deserialize(holderBytes).Values.Length); + + List ints = [1, 2, 3, 4]; + byte[] intBytes = Serialize(ints); + long listRequired = 
ListBudget(ints.Count); + Assert.Throws(() => NewFory(listRequired - 1).Deserialize>(intBytes)); + Assert.Equal(ints, NewFory(listRequired).Deserialize>(intBytes)); + } + + [Fact] + public void ValueStructOwnerIsChargedByHolder() + { + BudgetValue value = new() { Id = 7 }; + byte[] valueBytes = Serialize(value); + Assert.Throws(() => NewFory(BudgetValueBytes - 1).Deserialize(valueBytes)); + Assert.Equal(value.Id, NewFory(BudgetValueBytes).Deserialize(valueBytes).Id); + + List values = Enumerable.Range(0, 4).Select(i => new BudgetValue { Id = i }).ToList(); + byte[] listBytes = Serialize(values); + long listRequired = ListBudget(values.Count); + Assert.Throws(() => NewFory(listRequired - 1).Deserialize>(listBytes)); + Assert.Equal(values.Select(v => v.Id), NewFory(listRequired).Deserialize>(listBytes).Select(v => v.Id)); + } + + [Fact] + public void GeneratedSchemaContainersAreCharged() + { + GeneratedSchemaListBudget list = new() { Values = [1, 2, 3, 4, 5, 6] }; + byte[] listBytes = Serialize(list); + long listRequired = GeneratedGraphHolderBytes + ListBudget(list.Values.Count); + Assert.Throws(() => NewFory(listRequired - 1).Deserialize(listBytes)); + Assert.Equal(list.Values, NewFory(listRequired).Deserialize(listBytes).Values); + + GeneratedPackedListBudget packed = new() { Values = [1, 2, 3, 4, 5, 6] }; + byte[] packedBytes = Serialize(packed); + long packedRequired = GeneratedGraphHolderBytes + ListBudget(packed.Values.Count); + Assert.Throws(() => NewFory(packedRequired - 1).Deserialize(packedBytes)); + Assert.Equal(packed.Values, NewFory(packedRequired).Deserialize(packedBytes).Values); + + GeneratedSchemaMapBudget map = new() + { + Values = new Dictionary { [1] = 1, [2] = 2, [3] = 3 }, + }; + byte[] mapBytes = Serialize(map); + long mapRequired = GeneratedGraphHolderBytes + MapBudget(map.Values.Count); + Assert.Throws(() => NewFory(mapRequired - 1).Deserialize(mapBytes)); + Assert.Equal(map.Values, NewFory(mapRequired).Deserialize(mapBytes).Values); + } + + 
[Fact] + public void ConversionCollectionsAreChargedOnce() + { + long required = ListBudget(3); + + Check(new HashSet { 1, 2, 3 }, v => v.SetEquals([1, 2, 3])); + Check(new SortedSet { 1, 2, 3 }, v => v.SetEquals([1, 2, 3])); + Check(ImmutableHashSet.Create(1, 2, 3), v => v.SetEquals([1, 2, 3])); + Check(new LinkedList([1, 2, 3]), v => v.SequenceEqual([1, 2, 3])); + + Queue queue = new(); + queue.Enqueue(1); + queue.Enqueue(2); + queue.Enqueue(3); + Check(queue, v => v.SequenceEqual([1, 2, 3])); + + Stack stack = new(); + stack.Push(1); + stack.Push(2); + stack.Push(3); + Check(stack, v => v.SequenceEqual([3, 2, 1])); + + void Check(T value, Func assertValue) + { + byte[] bytes = Serialize(value); + Assert.Throws(() => NewFory(required - 1).Deserialize(bytes)); + Assert.True(assertValue(NewFory(required).Deserialize(bytes))); + } + } + + [Fact] + public void DenseLeafOwnersAreSkipped() + { + Assert.Equal("budget", NewFory(1).Deserialize(Serialize("budget"))); + Assert.Equal(new byte[] { 1, 2, 3 }, NewFory(1).Deserialize(Serialize(new byte[] { 1, 2, 3 }))); + Assert.Equal(new[] { 1, 2, 3 }, NewFory(1).Deserialize(Serialize(new[] { 1, 2, 3 }))); + } + + [Fact] + public void CompatibleListToDenseArrayIsSkipped() + { + ForyRuntime writer = ForyRuntime.Builder().Compatible(true).TrackRef(false).Build(); + writer.Register(1010); + byte[] bytes = writer.Serialize(new CompatibleBudgetList { Values = [1, 2, 3] }); + + ForyRuntime reader = ForyRuntime.Builder() + .Compatible(true) + .TrackRef(false) + .MaxGraphMemoryBytes(GeneratedGraphHolderBytes) + .Build(); + reader.Register(1010); + + Assert.Equal(new[] { 1, 2, 3 }, reader.Deserialize(bytes).Values); + } + + [Fact] + public void ByteChecksRejectLargeLength() + { + byte[] bytes = Serialize(new List()); + bytes[^1] = 64; + Array.Resize(ref bytes, bytes.Length + 1); + + Assert.Throws(() => NewFory().Deserialize>(bytes)); + } +} diff --git a/dart/packages/fory-test/test/generated_round_trip_test.dart 
b/dart/packages/fory-test/test/generated_round_trip_test.dart index 2d8b603511..6dd2026309 100644 --- a/dart/packages/fory-test/test/generated_round_trip_test.dart +++ b/dart/packages/fory-test/test/generated_round_trip_test.dart @@ -28,8 +28,8 @@ import 'package:test/test.dart'; void main() { group('configuration', () { test('defaults to compatible mode unless explicitly set', () { - expect(const Config().compatible, isTrue); - expect(const Config(compatible: false).compatible, isFalse); + expect(Config().compatible, isTrue); + expect(Config(compatible: false).compatible, isFalse); }); }); diff --git a/dart/packages/fory/lib/src/codegen/fory_generator.dart b/dart/packages/fory/lib/src/codegen/fory_generator.dart index cae611048a..5595fac063 100644 --- a/dart/packages/fory/lib/src/codegen/fory_generator.dart +++ b/dart/packages/fory/lib/src/codegen/fory_generator.dart @@ -628,9 +628,11 @@ final class ForyGenerator extends Generator { ..writeln() ..writeln(' @override') ..writeln(' ${structSpec.name} read(ReadContext context) {'); + final graphObjectBytes = _graphObjectBytes(structSpec); switch (structSpec.constructionModel.mode) { case _ConstructorMode.mutable: + output.writeln(' context.reserveGraphMemory($graphObjectBytes);'); output.writeln(' final value = ${structSpec.name}();'); if (_structNeedsEarlyReadReference(structSpec)) { output @@ -755,6 +757,7 @@ final class ForyGenerator extends Generator { } } final constructorInvocation = _constructorInvocation(structSpec); + output.writeln(' context.reserveGraphMemory($graphObjectBytes);'); output.writeln(' final value = $constructorInvocation;'); for (final fieldName in structSpec.constructionModel.postConstructionFieldNames) { @@ -809,6 +812,9 @@ final class ForyGenerator extends Generator { ); switch (structSpec.constructionModel.mode) { case _ConstructorMode.mutable: + output.writeln( + ' context.reserveGraphMemory(${_graphObjectBytes(structSpec)});', + ); output.writeln(' final value = ${structSpec.name}();'); 
if (_structNeedsEarlyReadReference(structSpec)) { output @@ -999,6 +1005,9 @@ final class ForyGenerator extends Generator { ..writeln(' }'); } final constructorInvocation = _constructorInvocation(structSpec); + output.writeln( + ' context.reserveGraphMemory(${_graphObjectBytes(structSpec)});', + ); output.writeln(' final value = $constructorInvocation;'); for (final fieldName in structSpec.constructionModel.postConstructionFieldNames) { @@ -1218,6 +1227,9 @@ final class ForyGenerator extends Generator { return '${structSpec.name}($arguments)'; } + int _graphObjectBytes(_GeneratedStructSpec structSpec) => + 1 + structSpec.fields.length * 4; + bool _isSkipped(FieldElement field) { final annotation = _fieldAnnotationOf(field); if (annotation == null) { diff --git a/dart/packages/fory/lib/src/config.dart b/dart/packages/fory/lib/src/config.dart index d5d248cd36..88c3f36a0d 100644 --- a/dart/packages/fory/lib/src/config.dart +++ b/dart/packages/fory/lib/src/config.dart @@ -28,6 +28,7 @@ final class Config { static const int defaultMaxTypeMetaBytes = 4096; static const int defaultMaxSchemaVersionsPerType = 10; static const int defaultMaxAverageSchemaVersionsPerType = 3; + static const int defaultMaxGraphMemoryBytes = 128 * 1024 * 1024; /// Enables compatible struct encoding and decoding. /// @@ -56,29 +57,54 @@ final class Config { /// types. final int maxAverageSchemaVersionsPerType; + /// Maximum estimated graph memory per root deserialization. + /// + /// Value must be a positive byte limit. + final int maxGraphMemoryBytes; + /// Creates an immutable configuration object. /// /// Invalid numeric limits fail fast. When [compatible] is `true`, /// [checkStructVersion] is normalized to `false`. 
- const Config({ + Config({ this.compatible = true, bool checkStructVersion = true, - this.maxDepth = defaultMaxDepth, - this.maxTypeFields = defaultMaxTypeFields, - this.maxTypeMetaBytes = defaultMaxTypeMetaBytes, - this.maxSchemaVersionsPerType = defaultMaxSchemaVersionsPerType, - this.maxAverageSchemaVersionsPerType = + int maxDepth = defaultMaxDepth, + int maxTypeFields = defaultMaxTypeFields, + int maxTypeMetaBytes = defaultMaxTypeMetaBytes, + int maxSchemaVersionsPerType = defaultMaxSchemaVersionsPerType, + int maxAverageSchemaVersionsPerType = defaultMaxAverageSchemaVersionsPerType, + int maxGraphMemoryBytes = defaultMaxGraphMemoryBytes, }) : checkStructVersion = compatible ? false : checkStructVersion, - assert(maxDepth > 0, 'maxDepth must be positive'), - assert(maxTypeFields > 0, 'maxTypeFields must be positive'), - assert(maxTypeMetaBytes > 0, 'maxTypeMetaBytes must be positive'), - assert( - maxSchemaVersionsPerType > 0, - 'maxSchemaVersionsPerType must be positive', + maxDepth = _positive(maxDepth, 'maxDepth'), + maxTypeFields = _positive(maxTypeFields, 'maxTypeFields'), + maxTypeMetaBytes = _positive(maxTypeMetaBytes, 'maxTypeMetaBytes'), + maxSchemaVersionsPerType = _positive( + maxSchemaVersionsPerType, + 'maxSchemaVersionsPerType', + ), + maxAverageSchemaVersionsPerType = _positive( + maxAverageSchemaVersionsPerType, + 'maxAverageSchemaVersionsPerType', ), - assert( - maxAverageSchemaVersionsPerType > 0, - 'maxAverageSchemaVersionsPerType must be positive', + maxGraphMemoryBytes = _positiveSafeInteger( + maxGraphMemoryBytes, + 'maxGraphMemoryBytes', ); + + static int _positive(int value, String name) { + if (value <= 0) { + throw ArgumentError.value(value, name, 'must be positive'); + } + return value; + } + + static int _positiveSafeInteger(int value, String name) { + const maxSafeInteger = 9007199254740991; + if (value <= 0 || value > maxSafeInteger) { + throw ArgumentError.value(value, name, 'must be a positive safe integer'); + } + return 
value; + } } diff --git a/dart/packages/fory/lib/src/context/read_context.dart b/dart/packages/fory/lib/src/context/read_context.dart index faa8191aba..f15022dcfd 100644 --- a/dart/packages/fory/lib/src/context/read_context.dart +++ b/dart/packages/fory/lib/src/context/read_context.dart @@ -54,6 +54,7 @@ final class ReadContext { late Buffer _buffer; final List _sharedTypes = []; int _depth = 0; + int _remainingGraphMemoryBytes = 0; @internal ReadContext( @@ -64,8 +65,10 @@ final class ReadContext { ); @internal + @pragma('vm:prefer-inline') void prepare(Buffer buffer) { _buffer = buffer; + _remainingGraphMemoryBytes = config.maxGraphMemoryBytes; } @internal @@ -74,6 +77,7 @@ final class ReadContext { _refReader.reset(); _metaStringReader.reset(); _depth = 0; + _remainingGraphMemoryBytes = 0; } /// The active input buffer for the current operation. @@ -85,6 +89,34 @@ final class ReadContext { @internal RefReader get refReader => _refReader; + @internal + @pragma('vm:prefer-inline') + void reserveGraphMemory(int bytes) { + if (bytes < 0) { + _throwGraphMemoryOverflow(bytes); + } + if (bytes > _remainingGraphMemoryBytes) { + _throwGraphMemoryExceeded(bytes); + } + _remainingGraphMemoryBytes -= bytes; + } + + @pragma('vm:never-inline') + Never _throwGraphMemoryOverflow(int bytes) { + throw StateError( + 'maxGraphMemoryBytes overflow: requested $bytes estimated graph bytes.', + ); + } + + @pragma('vm:never-inline') + Never _throwGraphMemoryExceeded(int bytes) { + throw StateError( + 'maxGraphMemoryBytes exceeded: requested $bytes estimated graph bytes, ' + '$_remainingGraphMemoryBytes remaining, effective limit ' + '${config.maxGraphMemoryBytes}.', + ); + } + @internal @pragma('vm:prefer-inline') TypeInfo readTypeMetaValue([TypeInfo? 
expectedNamedType]) => diff --git a/dart/packages/fory/lib/src/fory.dart b/dart/packages/fory/lib/src/fory.dart index adc6091a8d..bebc7f7896 100644 --- a/dart/packages/fory/lib/src/fory.dart +++ b/dart/packages/fory/lib/src/fory.dart @@ -62,6 +62,7 @@ final class Fory { int maxSchemaVersionsPerType = Config.defaultMaxSchemaVersionsPerType, int maxAverageSchemaVersionsPerType = Config.defaultMaxAverageSchemaVersionsPerType, + int maxGraphMemoryBytes = Config.defaultMaxGraphMemoryBytes, }) { final config = Config( compatible: compatible, @@ -71,6 +72,7 @@ final class Fory { maxTypeMetaBytes: maxTypeMetaBytes, maxSchemaVersionsPerType: maxSchemaVersionsPerType, maxAverageSchemaVersionsPerType: maxAverageSchemaVersionsPerType, + maxGraphMemoryBytes: maxGraphMemoryBytes, ); _readBuffer = Buffer(); _writeBuffer = Buffer(); diff --git a/dart/packages/fory/lib/src/serializer/collection_serializers.dart b/dart/packages/fory/lib/src/serializer/collection_serializers.dart index 4e2a8050c0..8245a3492e 100644 --- a/dart/packages/fory/lib/src/serializer/collection_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/collection_serializers.dart @@ -38,6 +38,9 @@ import 'package:fory/src/types/float16.dart'; import 'package:fory/src/types/int64.dart'; import 'package:fory/src/types/uint64.dart'; +const int _referenceBytes = 4; +const int _ownerBytes = 1; + @pragma('vm:prefer-inline') void _writeDirectTypeInfoValue( WriteContext context, @@ -378,13 +381,25 @@ final class SetSerializer extends Serializer { FieldType? 
elementFieldType, { bool hasPreservedRef = false, }) { - return Set.of( - ListSerializer.readPayload( - context, - elementFieldType, - hasPreservedRef: hasPreservedRef, - ), - ); + final state = _prepareListRead(context, elementFieldType); + context.buffer.checkReadableBytes(state.size); + final result = {}; + if (hasPreservedRef) { + context.reference(result); + } + if (state.size == 0) { + return result; + } + if (state.tracksDepth) { + context.increaseDepth(); + } + for (var index = 0; index < state.size; index += 1) { + result.add(_readPreparedListItem(context, state)); + } + if (state.tracksDepth) { + context.decreaseDepth(); + } + return result; } } @@ -429,7 +444,7 @@ Object? readCompatibleMatchedCollectionArrayField( ); } final raw = readCompatibleField(context, remoteField); - return _arrayToListValue(raw); + return _arrayToListValue(context, raw); } return readFieldValue(context, localField); } @@ -625,11 +640,13 @@ void _setArrayValue(Object target, int arrayTypeId, int index, Object? value) { } } -Object _arrayToListValue(Object? raw) { +Object _arrayToListValue(ReadContext context, Object? raw) { if (raw is BoolList) { + context.reserveGraphMemory(_ownerBytes + raw.length * _referenceBytes); return raw.toList(); } if (raw is Iterable) { + context.reserveGraphMemory(_ownerBytes + raw.length * _referenceBytes); return raw.toList(); } throw StateError('Expected compatible array payload.'); @@ -910,6 +927,7 @@ _PreparedListRead _prepareListRead( FieldType? 
elementFieldType, ) { final size = context.buffer.readVarUint32(); + context.reserveGraphMemory(_ownerBytes + size * _referenceBytes); if (size == 0) { return _PreparedListRead( size: 0, diff --git a/dart/packages/fory/lib/src/serializer/map_serializers.dart b/dart/packages/fory/lib/src/serializer/map_serializers.dart index 051454c3d6..f0aae14c72 100644 --- a/dart/packages/fory/lib/src/serializer/map_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/map_serializers.dart @@ -26,6 +26,8 @@ import 'package:fory/src/resolver/type_resolver.dart'; import 'package:fory/src/serializer/collection_serializers.dart'; import 'package:fory/src/serializer/serializer.dart'; +const int _referenceBytes = 4; + abstract final class MapFlags { static const int trackingKeyRef = 0x01; static const int keyHasNull = 0x02; @@ -257,6 +259,8 @@ Map readTypedMapPayload( bool hasPreservedRef = false, }) { var remaining = context.buffer.readVarUint32(); + context.reserveGraphMemory(1 + remaining * 2 * _referenceBytes); + context.buffer.checkReadableBytes(remaining); final declaredKeyTypeInfo = keyFieldType == null || keyFieldType.isDynamic ? 
null diff --git a/dart/packages/fory/test/collection_serializer_test.dart b/dart/packages/fory/test/collection_serializer_test.dart index a42e5d701c..9704462a99 100644 --- a/dart/packages/fory/test/collection_serializer_test.dart +++ b/dart/packages/fory/test/collection_serializer_test.dart @@ -257,6 +257,19 @@ void main() { expect(identical(roundTrip, roundTrip[0]), isTrue); }); + test('trackRef true preserves self-referential root sets', () { + final fory = Fory(); + final value = {}; + value.add(value); + + final roundTrip = + fory.deserialize(fory.serialize(value, trackRef: true)) + as Set; + + expect(roundTrip, hasLength(1)); + expect(identical(roundTrip, roundTrip.single), isTrue); + }); + test('trackRef true preserves self-referential root maps', () { final fory = Fory(); final value = {}; diff --git a/dart/packages/fory/test/graph_memory_budget_test.dart b/dart/packages/fory/test/graph_memory_budget_test.dart new file mode 100644 index 0000000000..1d3672c942 --- /dev/null +++ b/dart/packages/fory/test/graph_memory_budget_test.dart @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import 'dart:typed_data'; + +import 'package:fory/fory.dart'; +import 'package:fory/src/context/meta_string_reader.dart'; +import 'package:fory/src/context/ref_reader.dart'; +import 'package:fory/src/resolver/type_resolver.dart'; +import 'package:fory/src/serializer/collection_serializers.dart'; +import 'package:fory/src/serializer/map_serializers.dart'; +import 'package:test/test.dart'; + +part 'graph_memory_budget_test.fory.dart'; + +const Matcher _throwsGraphBudget = ThrowsGraphBudget(); +const int _defaultGraphMemoryBytes = 128 * 1024 * 1024; +const int _objectBytes = 1; +const int _referenceBytes = 4; + +int _objectGraphBytes(int fields) => _objectBytes + fields * _referenceBytes; +int _listGraphBytes(int count) => _objectBytes + count * _referenceBytes; +int _mapGraphBytes(int count) => _objectBytes + count * 2 * _referenceBytes; + +@ForyStruct() +class BudgetGeneratedEnvelope { + BudgetGeneratedEnvelope(); + + @ListField(element: Int32Type(encoding: Encoding.fixed)) + List ids = []; + + @SetField(element: StringType()) + Set tags = {}; + + @MapField(key: StringType(), value: Int32Type(encoding: Encoding.fixed)) + Map counts = {}; +} + +@ForyStruct() +class BudgetCompatibleListEnvelope { + BudgetCompatibleListEnvelope(); + + @ListField(element: Int32Type(encoding: Encoding.fixed)) + List values = []; +} + +@ForyStruct() +class BudgetCompatibleArrayEnvelope { + BudgetCompatibleArrayEnvelope(); + + @ArrayField(element: Int32Type()) + Int32List values = Int32List(0); +} + +final class ThrowsGraphBudget extends Matcher { + const ThrowsGraphBudget(); + + @override + Description describe(Description description) { + return description.add('throws a maxGraphMemoryBytes StateError'); + } + + @override + bool matches(Object? item, Map matchState) { + if (item is! 
Function) { + return false; + } + try { + item(); + } on StateError catch (error) { + return error.message.contains('maxGraphMemoryBytes'); + } + return false; + } +} + +void _registerGenerated(Fory fory) { + GraphMemoryBudgetTestForyModule.register( + fory, + BudgetGeneratedEnvelope, + name: 'test.BudgetGeneratedEnvelope', + ); +} + +void _registerCompatibleList(Fory fory) { + GraphMemoryBudgetTestForyModule.register( + fory, + BudgetCompatibleListEnvelope, + name: 'test.BudgetCompatibleEnvelope', + ); +} + +void _registerCompatibleArray(Fory fory) { + GraphMemoryBudgetTestForyModule.register( + fory, + BudgetCompatibleArrayEnvelope, + name: 'test.BudgetCompatibleEnvelope', + ); +} + +ReadContext _readContext( + Buffer buffer, { + int maxGraphMemoryBytes = _defaultGraphMemoryBytes, +}) { + final config = Config(maxGraphMemoryBytes: maxGraphMemoryBytes); + final resolver = TypeResolver(config); + return ReadContext(config, resolver, RefReader(), MetaStringReader(resolver)) + ..prepare(buffer); +} + +Uint8List _serialize(Object? value) => Fory().serialize(value); + +Object? _readWithBudget(Object? 
value, int budget) { + return Fory( + maxGraphMemoryBytes: budget, + ).deserialize(_serialize(value)); +} + +void main() { + group('graph memory budget', () { + test('fixed default applies to roots', () { + final buffer = Buffer.wrap(Uint8List(17)); + final context = _readContext(buffer); + + expect( + () => context.reserveGraphMemory(_defaultGraphMemoryBytes), + returnsNormally, + ); + expect(() => context.reserveGraphMemory(1), _throwsGraphBudget); + }); + + test('explicit config overrides default and invalid config fails', () { + final buffer = Buffer.wrap(Uint8List(4096)); + final context = _readContext(buffer, maxGraphMemoryBytes: 31); + + expect(() => context.reserveGraphMemory(31), returnsNormally); + expect(() => context.reserveGraphMemory(1), _throwsGraphBudget); + + expect(() => Fory(maxGraphMemoryBytes: 0), throwsA(isA())); + expect( + () => Fory(maxGraphMemoryBytes: -2), + throwsA(isA()), + ); + }); + + test('uses parent storage for nested empty containers', () { + final value = [[]]; + + expect( + () => + _readWithBudget(value, _listGraphBytes(1) + _listGraphBytes(0) - 1), + _throwsGraphBudget, + ); + expect( + _readWithBudget(value, _listGraphBytes(1) + _listGraphBytes(0)), + equals(value), + ); + }); + + test('reserves sibling containers cumulatively', () { + final value = [[], [], []]; + + expect( + () => _readWithBudget( + value, + _listGraphBytes(3) + 3 * _listGraphBytes(0) - 1, + ), + _throwsGraphBudget, + ); + expect( + _readWithBudget(value, _listGraphBytes(3) + 3 * _listGraphBytes(0)), + equals(value), + ); + }); + + test('reserves map entries', () { + final value = {'a': 1}; + + expect( + () => _readWithBudget(value, _mapGraphBytes(1) - 1), + _throwsGraphBudget, + ); + expect(_readWithBudget(value, _mapGraphBytes(1)), equals(value)); + }); + + test('reserves generic set owner once', () { + final value = {'x', 7}; + final required = _listGraphBytes(value.length); + + expect(() => _readWithBudget(value, required - 1), _throwsGraphBudget); + 
expect(_readWithBudget(value, required), equals(value)); + }); + + test('reserves generated list set and map reads', () { + final writer = Fory(); + _registerGenerated(writer); + final bytes = writer.serialize( + BudgetGeneratedEnvelope() + ..ids = [1] + ..tags = {'x'} + ..counts = {'one': 1}, + ); + + final required = + _objectGraphBytes(3) + + _listGraphBytes(1) + + _listGraphBytes(1) + + _mapGraphBytes(1); + final failingReader = Fory(maxGraphMemoryBytes: required - 1); + _registerGenerated(failingReader); + expect( + () => failingReader.deserialize(bytes), + _throwsGraphBudget, + ); + + final passingReader = Fory(maxGraphMemoryBytes: required); + _registerGenerated(passingReader); + final roundTrip = passingReader.deserialize( + bytes, + ); + expect(roundTrip.ids, equals([1])); + expect(roundTrip.tags, equals({'x'})); + expect(roundTrip.counts, equals({'one': 1})); + }); + + test('skips compatible list to typed array leaf', () { + final listWriter = Fory(); + _registerCompatibleList(listWriter); + final listBytes = listWriter.serialize( + BudgetCompatibleListEnvelope()..values = [1, 2, 3], + ); + + final arrayRequired = _objectGraphBytes(1); + final arrayFail = Fory(maxGraphMemoryBytes: arrayRequired - 1); + _registerCompatibleArray(arrayFail); + expect( + () => arrayFail.deserialize(listBytes), + _throwsGraphBudget, + ); + + final arrayPass = Fory(maxGraphMemoryBytes: arrayRequired); + _registerCompatibleArray(arrayPass); + expect( + arrayPass + .deserialize(listBytes) + .values + .toList(), + equals([1, 2, 3]), + ); + + final arrayWriter = Fory(); + _registerCompatibleArray(arrayWriter); + final arrayBytes = arrayWriter.serialize( + BudgetCompatibleArrayEnvelope() + ..values = Int32List.fromList([1, 2, 3]), + ); + + final listRequired = _objectGraphBytes(1) + _listGraphBytes(3); + final listFail = Fory(maxGraphMemoryBytes: listRequired - 1); + _registerCompatibleList(listFail); + expect( + () => listFail.deserialize(arrayBytes), + _throwsGraphBudget, + ); + + 
final listPass = Fory(maxGraphMemoryBytes: listRequired); + _registerCompatibleList(listPass); + expect( + listPass.deserialize(arrayBytes).values, + equals([1, 2, 3]), + ); + }); + + test('skips strings binary and dense typed arrays', () { + final fory = Fory(maxGraphMemoryBytes: 1); + final text = List.filled(128, 'x').join(); + + expect(fory.deserialize(Fory().serialize(text)), hasLength(128)); + expect( + fory.deserialize(Fory().serialize(Uint8List(128))).length, + equals(128), + ); + expect( + fory.deserialize(Fory().serialize(Int32List(32))).length, + equals(32), + ); + }); + + test('keeps byte availability checks before allocation', () { + final listBuffer = + Buffer() + ..writeVarUint32(64) + ..writeUint8(0); + final listContext = _readContext(listBuffer); + expect( + () => ListSerializer.readPayload(listContext, null), + throwsStateError, + ); + + final mapBuffer = Buffer()..writeVarUint32(64); + final mapContext = _readContext(mapBuffer); + expect( + () => MapSerializer.readPayload(mapContext, null, null), + throwsStateError, + ); + }); + }); +} diff --git a/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart b/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart index 7ef820c666..a3d64234e2 100644 --- a/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart +++ b/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart @@ -62,7 +62,7 @@ FieldInfo _compatibleScalarField({ } ReadContext _compatibleReadContext(Buffer buffer) { - const config = Config(); + final config = Config(); final resolver = TypeResolver(config); return ReadContext(config, resolver, RefReader(), MetaStringReader(resolver)) ..prepare(buffer); diff --git a/dart/packages/fory/test/xlang_protocol_test.dart b/dart/packages/fory/test/xlang_protocol_test.dart index 21e2ada7d3..74b2bcd140 100644 --- a/dart/packages/fory/test/xlang_protocol_test.dart +++ b/dart/packages/fory/test/xlang_protocol_test.dart @@ -151,7 +151,7 @@ void 
_rememberLateHolder() { } Uint8List _lateHolderTypeDefBytes({required bool registerExtFirst}) { - final resolver = TypeResolver(const Config()); + final resolver = TypeResolver(Config()); _rememberLateHolder(); if (registerExtFirst) { resolver.registerSerializer( @@ -183,7 +183,7 @@ Uint8List _typeMetaBytes( String name, List fields, ) { - final resolver = TypeResolver(const Config()); + final resolver = TypeResolver(Config()); _rememberSchema(type, fields); final parts = name.split('.'); resolver.registerGenerated( @@ -203,7 +203,7 @@ Uint8List _typeMetaBytes( } Uint8List _enumTypeMetaBytes(Type type, String name) { - final resolver = TypeResolver(const Config()); + final resolver = TypeResolver(Config()); _rememberEnum(type); final parts = name.split('.'); resolver.registerGenerated( @@ -372,7 +372,7 @@ void main() { test('remote schema limit rejects extra versions', () { const name = 'example.Unknown'; - final reader = TypeResolver(const Config(maxSchemaVersionsPerType: 1)); + final reader = TypeResolver(Config(maxSchemaVersionsPerType: 1)); _rememberSchema(_SchemaLocal, []); reader.registerGenerated( _SchemaLocal, @@ -393,7 +393,7 @@ void main() { test('named enum TypeDef uses metadata byte limit', () { const name = 'example.RemoteEnum'; - final reader = TypeResolver(const Config(maxTypeMetaBytes: 1)); + final reader = TypeResolver(Config(maxTypeMetaBytes: 1)); _rememberEnum(_SchemaLocal); final bytes = _enumTypeMetaBytes(_SchemaRemoteA, name); @@ -402,7 +402,7 @@ void main() { test('registered named enum TypeDef uses metadata byte limit', () { const name = 'example.RemoteEnum'; - final reader = TypeResolver(const Config(maxTypeMetaBytes: 1)); + final reader = TypeResolver(Config(maxTypeMetaBytes: 1)); _rememberEnum(_SchemaLocal); reader.registerGenerated( _SchemaLocal, @@ -416,7 +416,7 @@ void main() { test('exact local named enum TypeDef is accepted', () { const name = 'example.SharedEnum'; - final reader = TypeResolver(const Config(maxSchemaVersionsPerType: 
1)); + final reader = TypeResolver(Config(maxSchemaVersionsPerType: 1)); _rememberEnum(_SchemaLocal); reader.registerGenerated( _SchemaLocal, @@ -429,7 +429,7 @@ void main() { }); test('type meta field limit rejects large struct', () { - final reader = TypeResolver(const Config(maxTypeFields: 1)); + final reader = TypeResolver(Config(maxTypeFields: 1)); final bytes = _typeMetaBytes( _SchemaRemoteA, 'example.TooManyFields', @@ -443,7 +443,7 @@ void main() { }); test('type meta body limit rejects large metadata', () { - final reader = TypeResolver(const Config(maxTypeMetaBytes: 1)); + final reader = TypeResolver(Config(maxTypeMetaBytes: 1)); final bytes = _typeMetaBytes( _SchemaRemoteA, 'example.LargeTypeMeta', @@ -454,7 +454,7 @@ void main() { }); test('remote schema limit keeps unknown types separate', () { - final reader = TypeResolver(const Config(maxSchemaVersionsPerType: 1)); + final reader = TypeResolver(Config(maxSchemaVersionsPerType: 1)); _rememberSchema(_SchemaLocal, []); reader.registerGenerated( _SchemaLocal, @@ -484,7 +484,7 @@ void main() { test('failed remote schema does not consume schema limit', () { const name = 'example.Accepted'; - final reader = TypeResolver(const Config(maxSchemaVersionsPerType: 1)); + final reader = TypeResolver(Config(maxSchemaVersionsPerType: 1)); _rememberSchema(_SchemaLocal, [ _generatedField('value'), ]); diff --git a/docs/guide/cpp/configuration.md b/docs/guide/cpp/configuration.md index d617450041..906796431e 100644 --- a/docs/guide/cpp/configuration.md +++ b/docs/guide/cpp/configuration.md @@ -96,6 +96,32 @@ When enabled, avoids duplicating shared objects and handles cycles. **Default:** `true` +### max_graph_memory_bytes(int64_t) + +Set the maximum estimated shallow graph memory accepted during one root +deserialization. + +```cpp +auto fory = Fory::builder() + .max_graph_memory_bytes(64 * 1024 * 1024) + .build(); +``` + +The default limit is a fixed `128 MiB` for byte-array, `Buffer`, and stream +roots. 
Positive values override the default. Explicit non-positive values are +rejected when the runtime is created. + +This budget is a portable lower-bound estimate for shallow materialized graph +owners such as dynamic collection backing storage, map key/value storage, +object/reference array slots, and struct or object field storage. It is not an +exact process heap limit and does not include STL implementation details such as +debug nodes, table buckets, or allocator headers. Dedicated string, binary, and +primitive dense-array payloads +continue to rely on their byte-availability checks instead. `std::vector<bool>` +is counted as packed standard-container storage. + +**Default:** `128 MiB` + ### max_dyn_depth(uint32_t) Set maximum allowed nesting depth for dynamically-typed objects. @@ -200,17 +226,18 @@ auto fory = Fory::builder().build_thread_safe(); // Returns ThreadSafeFory ## Configuration Summary -| Option | Description | Default | -| ------------------------------------------------ | ------------------------------------------------- | ------- | -| `xlang(bool)` | Use xlang mode | `true` | -| `compatible(bool)` | Enable schema evolution | `true` | -| `track_ref(bool)` | Enable reference tracking | `true` | -| `max_dyn_depth(uint32_t)` | Maximum nesting depth for dynamic types | `5` | -| `max_type_fields(uint32_t)` | Max fields in one received struct metadata body | `512` | -| `max_type_meta_bytes(uint32_t)` | Max encoded bytes in one received metadata body | `4096` | -| `max_schema_versions_per_type(uint32_t)` | Max remote metadata versions for one logical type | `10` | -| `max_average_schema_versions_per_type(uint32_t)` | Average remote metadata versions across types | `3` | -| `check_struct_version(bool)` | Enable struct version checking | `false` | +| Option | Description | Default | +| ------------------------------------------------ | ------------------------------------------------- | --------- | +| `xlang(bool)` | Use xlang mode | `true` | +|
`compatible(bool)` | Enable schema evolution | `true` | +| `track_ref(bool)` | Enable reference tracking | `true` | +| `max_graph_memory_bytes(int64_t)` | Max estimated graph memory per root read | `128 MiB` | +| `max_dyn_depth(uint32_t)` | Maximum nesting depth for dynamic types | `5` | +| `max_type_fields(uint32_t)` | Max fields in one received struct metadata body | `512` | +| `max_type_meta_bytes(uint32_t)` | Max encoded bytes in one received metadata body | `4096` | +| `max_schema_versions_per_type(uint32_t)` | Max remote metadata versions for one logical type | `10` | +| `max_average_schema_versions_per_type(uint32_t)` | Average remote metadata versions across types | `3` | +| `check_struct_version(bool)` | Enable struct version checking | `false` | ## Security @@ -218,6 +245,8 @@ Security-related configuration: - Register all structs and polymorphic implementations before deserializing untrusted payloads. - Use `check_struct_version(true)` with `compatible(false)` for intentional same-schema payloads. +- Keep `max_graph_memory_bytes(...)` at the fixed `128 MiB` default for most inputs, or set a + positive value for a trusted workload that needs a different envelope. - Keep `max_dyn_depth(...)` as low as your model permits to reject unexpectedly deep polymorphic graphs. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a diff --git a/docs/guide/csharp/basic-serialization.md b/docs/guide/csharp/basic-serialization.md index d112990ccd..5c88675bc8 100644 --- a/docs/guide/csharp/basic-serialization.md +++ b/docs/guide/csharp/basic-serialization.md @@ -104,6 +104,11 @@ byte[] payload = fory.Serialize(value); object? decoded = fory.Deserialize(payload); ``` +Dynamic maps normally decode as `Dictionary` when they have no +null key. If the payload uses reference tracking for the dynamic map itself, C# +returns `NullableKeyDictionary` so nested references and null +keys point to the decoded map owner. 
+ ## Buffer Writer API Serialize directly into `IBufferWriter` targets. diff --git a/docs/guide/csharp/configuration.md b/docs/guide/csharp/configuration.md index e7c0c24d42..841964d955 100644 --- a/docs/guide/csharp/configuration.md +++ b/docs/guide/csharp/configuration.md @@ -35,16 +35,17 @@ ThreadSafeFory threadSafe = Fory.Builder().BuildThreadSafe(); `Fory.Builder().Build()` uses: -| Option | Default | Description | -| --------------------------------- | ------- | ------------------------------------------------- | -| `TrackRef` | `false` | Reference tracking disabled | -| `Compatible` | `true` | Compatible schema-evolution metadata enabled | -| `CheckStructVersion` | `false` | Struct schema hash checks disabled | -| `MaxDepth` | `20` | Max dynamic nesting depth | -| `MaxTypeFields` | `512` | Max fields in one received struct metadata body | -| `MaxTypeMetaBytes` | `4096` | Max encoded bytes in one received metadata body | -| `MaxSchemaVersionsPerType` | `10` | Max remote metadata versions for one logical type | -| `MaxAverageSchemaVersionsPerType` | `3` | Average remote metadata versions across types | +| Option | Default | Description | +| --------------------------------- | ----------- | ------------------------------------------------- | +| `TrackRef` | `false` | Reference tracking disabled | +| `Compatible` | `true` | Compatible schema-evolution metadata enabled | +| `CheckStructVersion` | `false` | Struct schema hash checks disabled | +| `MaxDepth` | `20` | Max dynamic nesting depth | +| `MaxGraphMemoryBytes` | `134217728` | Max estimated graph memory per root read | +| `MaxTypeFields` | `512` | Max fields in one received struct metadata body | +| `MaxTypeMetaBytes` | `4096` | Max encoded bytes in one received metadata body | +| `MaxSchemaVersionsPerType` | `10` | Max remote metadata versions for one logical type | +| `MaxAverageSchemaVersionsPerType` | `3` | Average remote metadata versions across types | ## Builder Options @@ -96,6 +97,19 @@ Fory fory = Fory.Builder() 
`value` must be greater than `0`. +### `MaxGraphMemoryBytes(long value)` + +Sets the maximum estimated shallow graph memory accepted during one root deserialization. + +```csharp +Fory fory = Fory.Builder() + .MaxGraphMemoryBytes(64L * 1024 * 1024) + .Build(); +``` + +The default limit is a fixed `128 MiB` for all root input forms. A positive value overrides the +default. Explicit non-positive values are rejected when the runtime is created. + ### `MaxTypeFields(int value)` Sets the maximum fields accepted in one received remote struct metadata body. @@ -173,6 +187,8 @@ Security-related configuration: - Register only the expected types before deserializing untrusted payloads. - Use `CheckStructVersion(true)` with `Compatible(false)` for intentional same-schema payloads. - Set `MaxDepth(...)` to reject unexpectedly deep dynamic object graphs. +- Set `MaxGraphMemoryBytes(...)` to cap estimated shallow graph memory during one root + deserialization. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer generated or registered concrete models over broad dynamic fields for untrusted input. diff --git a/docs/guide/csharp/references.md b/docs/guide/csharp/references.md index 772b2535ca..6b6b2363c0 100644 --- a/docs/guide/csharp/references.md +++ b/docs/guide/csharp/references.md @@ -65,6 +65,10 @@ System.Diagnostics.Debug.Assert(object.ReferenceEquals(decoded, decoded.Next)); `TrackRef(false)` can be faster for tree-like, acyclic data where reference identity does not matter. +C# union wrappers are immutable and are created after their case payload is read. +Reference cycles from a union case payload back to the containing union are not +supported; Fory rejects unresolved refs instead of returning a partial union. 
+ ## Related Topics - [Configuration](configuration.md) diff --git a/docs/guide/dart/configuration.md b/docs/guide/dart/configuration.md index 6a4c640f6a..4e3c005cf9 100644 --- a/docs/guide/dart/configuration.md +++ b/docs/guide/dart/configuration.md @@ -38,6 +38,7 @@ final fory = Fory( maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, maxAverageSchemaVersionsPerType: 3, + maxGraphMemoryBytes: 64 * 1024 * 1024, ); ``` @@ -107,17 +108,36 @@ final fory = Fory( - `maxAverageSchemaVersionsPerType` limits the average across accepted remote types. The effective global floor is `8192` schemas. +### `maxGraphMemoryBytes` + +Limits estimated shallow graph memory for one root deserialization. The budget covers materialized +Dart lists, sets, maps, object/reference arrays, structs, objects, and compatible list/array +materialization. It does not count strings, binary values, or dense typed-array payloads, which are +protected by byte-availability checks. + +The default is a fixed `128 MiB` and is not derived from input size. + +Set a positive value when a trusted workload legitimately contains compact, graph-heavy object or +collection payloads: + +```dart +final fory = Fory(maxGraphMemoryBytes: 256 * 1024 * 1024); +``` + +Explicit non-positive values are rejected when the runtime is created. 
+ ## Defaults -| Option | Default | -| --------------------------------- | ------- | -| `compatible` | `true` | -| `checkStructVersion` | `false` | -| `maxDepth` | 256 | -| `maxTypeFields` | 512 | -| `maxTypeMetaBytes` | 4096 | -| `maxSchemaVersionsPerType` | 10 | -| `maxAverageSchemaVersionsPerType` | 3 | +| Option | Default | +| --------------------------------- | --------- | +| `compatible` | `true` | +| `checkStructVersion` | `false` | +| `maxDepth` | 256 | +| `maxTypeFields` | 512 | +| `maxTypeMetaBytes` | 4096 | +| `maxSchemaVersionsPerType` | 10 | +| `maxAverageSchemaVersionsPerType` | 3 | +| `maxGraphMemoryBytes` | 134217728 | ## Xlang Notes @@ -134,6 +154,8 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkStructVersion: true` with `compatible: false` for intentional same-schema payloads. - Set `maxDepth` to reject unexpectedly deep payload shapes. +- Keep `maxGraphMemoryBytes` at the default for most inputs, or set an explicit positive byte limit + for known trusted graph-heavy payloads. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer generated schemas and explicit field metadata over broad dynamic fields for untrusted input. 
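Note on the accounting above: the Dart tests in this change charge one owner byte plus four bytes per reference slot (`_objectBytes = 1`, `_referenceBytes = 4`), with maps charged two slots per entry. The following is a minimal, standalone sketch of that arithmetic in Go, written only to illustrate how a remaining-budget counter behaves. The names `graphBudget`, `newGraphBudget`, `reserve`, and the `*GraphBytes` helpers are hypothetical and are not part of any Fory API; the constants are assumed to match the Dart test helpers.

```go
package main

import (
	"errors"
	"fmt"
)

// Illustrative constants mirroring the Dart test helpers in this change;
// they are assumptions for this sketch, not a Fory API.
const (
	ownerBytes     = 1
	referenceBytes = 4
)

// graphBudget is a hypothetical stand-in for a per-root remaining budget.
type graphBudget struct {
	limit     int
	remaining int
}

// newGraphBudget rejects non-positive limits at creation time.
func newGraphBudget(limit int) (*graphBudget, error) {
	if limit <= 0 {
		return nil, errors.New("limit must be positive")
	}
	return &graphBudget{limit: limit, remaining: limit}, nil
}

// reserve subtracts bytes from the remaining budget or reports why it cannot.
func (b *graphBudget) reserve(bytes int) error {
	if bytes < 0 {
		return fmt.Errorf("graph memory overflow: requested %d", bytes)
	}
	if bytes > b.remaining {
		return fmt.Errorf("graph memory exceeded: requested %d, %d remaining, limit %d",
			bytes, b.remaining, b.limit)
	}
	b.remaining -= bytes
	return nil
}

// Shallow estimates: one owner byte plus reference slots.
func listGraphBytes(count int) int { return ownerBytes + count*referenceBytes }

func mapGraphBytes(count int) int { return ownerBytes + count*2*referenceBytes }

func objectGraphBytes(fields int) int { return ownerBytes + fields*referenceBytes }

func main() {
	// A list holding one empty list needs the outer owner and slot plus the inner owner.
	required := listGraphBytes(1) + listGraphBytes(0)
	b, _ := newGraphBudget(required - 1)
	fmt.Println(b.reserve(listGraphBytes(1)) == nil) // outer list fits
	fmt.Println(b.reserve(listGraphBytes(0)) == nil) // inner list exceeds the budget
}
```

With a budget one byte short of `required`, the outer reservation succeeds and the inner one fails, which is the same boundary the nested-container test checks.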
diff --git a/docs/guide/go/configuration.md b/docs/guide/go/configuration.md index 20d9012aee..e29202acc7 100644 --- a/docs/guide/go/configuration.md +++ b/docs/guide/go/configuration.md @@ -33,16 +33,17 @@ f := fory.New(fory.WithXlang(true)) Default settings: -| Option | Default | Description | -| ------------------------------- | ------- | ------------------------------------------------- | -| TrackRef | false | Reference tracking disabled | -| MaxDepth | 20 | Maximum nesting depth | -| IsXlang | true | Xlang mode enabled | -| Compatible | true | Compatible schema-evolution metadata enabled | -| MaxTypeFields | 512 | Max fields in one received struct metadata body | -| MaxTypeMetaBytes | 4096 | Max encoded bytes in one received metadata body | -| MaxSchemaVersionsPerType | 10 | Max remote metadata versions for one logical type | -| MaxAverageSchemaVersionsPerType | 3 | Average remote metadata versions across types | +| Option | Default | Description | +| ------------------------------- | --------- | ------------------------------------------------- | +| TrackRef | false | Reference tracking disabled | +| MaxDepth | 20 | Maximum nesting depth | +| IsXlang | true | Xlang mode enabled | +| Compatible | true | Compatible schema-evolution metadata enabled | +| MaxGraphMemoryBytes | 134217728 | Graph memory limit per root read | +| MaxTypeFields | 512 | Max fields in one received struct metadata body | +| MaxTypeMetaBytes | 4096 | Max encoded bytes in one received metadata body | +| MaxSchemaVersionsPerType | 10 | Max remote metadata versions for one logical type | +| MaxAverageSchemaVersionsPerType | 3 | Average remote metadata versions across types | ### With Options @@ -51,6 +52,7 @@ f := fory.New( fory.WithXlang(true), fory.WithTrackRef(true), fory.WithMaxDepth(10), + fory.WithMaxGraphMemoryBytes(128 * 1024 * 1024), fory.WithMaxTypeFields(512), fory.WithMaxTypeMetaBytes(4096), fory.WithMaxSchemaVersionsPerType(10), @@ -127,6 +129,22 @@ f := 
fory.New(fory.WithMaxDepth(30)) - Protects against deeply nested, recursive structures or malicious data - Serialization fails with error when exceeded +### WithMaxGraphMemoryBytes + +Limit estimated shallow graph memory accepted during one root deserialization: + +```go +f := fory.New(fory.WithMaxGraphMemoryBytes(256 * 1024 * 1024)) +``` + +The default limit is a fixed `128 MiB` for byte-slice and stream roots. A +positive value overrides the default. Explicit non-positive values are rejected +when the runtime is created. The budget covers lower-bound slice +backing storage, map key/value storage, sets, generated object reads, and +materialized struct field storage. Strings, binary blobs, and primitive dense +array owners keep their byte-availability checks and are not reserved against +this budget. + ### WithMaxTypeFields Set the maximum fields accepted in one received remote struct metadata body: diff --git a/docs/guide/java/configuration.md b/docs/guide/java/configuration.md index 7b3ce60bf6..f947ac6ed7 100644 --- a/docs/guide/java/configuration.md +++ b/docs/guide/java/configuration.md @@ -38,6 +38,7 @@ This page documents all configuration options available through `ForyBuilder`. | `registerGuavaTypes` | Whether to pre-register Guava types such as `RegularImmutableMap`/`RegularImmutableList`. These types are not public API, but seem pretty stable. | `true` | | `requireClassRegistration` | Disabling may allow unknown classes to be deserialized, potentially causing security risks. | `true` | | `maxDepth` | Set max depth for deserialization, when depth exceeds, an exception will be thrown. This can be used to refuse deserialization DDOS attack. | `50` | +| `maxGraphMemoryBytes` | Maximum estimated shallow graph memory accepted during one root deserialization. The default is a fixed `128 MiB`; positive values set an explicit byte limit. Explicit non-positive values are rejected when the runtime is created. 
| `134217728` | | `maxTypeFields` | Maximum fields accepted in one received remote struct metadata body. | `512` | | `maxTypeMetaBytes` | Maximum encoded body bytes accepted for one received TypeDef or TypeMeta body, excluding the 8-byte header and any extended-size varint. | `4096` | | `maxSchemaVersionsPerType` | Maximum accepted remote metadata versions for one logical type. | `10` | @@ -90,6 +91,7 @@ Keep class registration enabled for production and any untrusted payload source: Fory fory = Fory.builder() .requireClassRegistration(true) .withMaxDepth(50) + .withMaxGraphMemoryBytes(128L * 1024 * 1024) .build(); ``` @@ -97,6 +99,10 @@ Security-related options: - `requireClassRegistration(true)` restricts deserialization to registered classes. - `withMaxDepth(...)` rejects unexpectedly deep object graphs. +- `withMaxGraphMemoryBytes(...)` bounds estimated shallow graph memory during one root + deserialization. The default is a fixed `128 MiB`; set a positive byte limit when trusted + workloads need a larger or smaller limit. Explicit non-positive values are rejected when the + runtime is created. - `withMaxTypeFields(...)` and `withMaxTypeMetaBytes(...)` bound the field count and encoded body size of one received remote metadata body. 
- `withMaxSchemaVersionsPerType(...)` and diff --git a/docs/guide/javascript/configuration.md b/docs/guide/javascript/configuration.md index 058bccf4b3..6b96329c92 100644 --- a/docs/guide/javascript/configuration.md +++ b/docs/guide/javascript/configuration.md @@ -43,6 +43,7 @@ const fory = new Fory({ ref: true, compatible: true, maxDepth: 100, + maxGraphMemoryBytes: 128 * 1024 * 1024, maxTypeFields: 512, maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, @@ -51,18 +52,19 @@ const fory = new Fory({ }); ``` -| Option | Default | Description | -| --------------------------------- | ------- | ------------------------------------------------------------------------------------- | -| `ref` | `false` | Enable reference tracking for shared or circular object graphs | -| `compatible` | `true` | Allow field additions/removals without breaking existing messages | -| `maxDepth` | `50` | Maximum nesting depth. Must be `>= 2`. Increase for deeply nested structures | -| `maxTypeFields` | `512` | Maximum fields accepted in one received remote struct metadata body | -| `maxTypeMetaBytes` | `4096` | Maximum encoded body bytes accepted for one received TypeMeta body | -| `maxSchemaVersionsPerType` | `10` | Maximum accepted remote metadata versions for one logical type | -| `maxAverageSchemaVersionsPerType` | `3` | Average accepted remote metadata versions across accepted remote types | -| `useSliceString` | `false` | Optional string-reading optimization for Node.js. 
Leave at default unless benchmarked | -| `hps` | unset | Optional fast string helper from `@apache-fory/hps` (Node.js 20+) | -| `hooks.afterCodeGenerated` | unset | Callback to inspect the generated serializer code, useful for debugging | +| Option | Default | Description | +| --------------------------------- | --------- | ------------------------------------------------------------------------------------- | +| `ref` | `false` | Enable reference tracking for shared or circular object graphs | +| `compatible` | `true` | Allow field additions/removals without breaking existing messages | +| `maxDepth` | `50` | Maximum nesting depth. Must be `>= 2`. Increase for deeply nested structures | +| `maxGraphMemoryBytes` | `128 MiB` | Maximum estimated shallow graph memory accepted during one root deserialization | +| `maxTypeFields` | `512` | Maximum fields accepted in one received remote struct metadata body | +| `maxTypeMetaBytes` | `4096` | Maximum encoded body bytes accepted for one received TypeMeta body | +| `maxSchemaVersionsPerType` | `10` | Maximum accepted remote metadata versions for one logical type | +| `maxAverageSchemaVersionsPerType` | `3` | Average accepted remote metadata versions across accepted remote types | +| `useSliceString` | `false` | Optional string-reading optimization for Node.js. Leave at default unless benchmarked | +| `hps` | unset | Optional fast string helper from `@apache-fory/hps` (Node.js 20+) | +| `hooks.afterCodeGenerated` | unset | Callback to inspect the generated serializer code, useful for debugging | ## Reference Tracking @@ -92,6 +94,27 @@ to that struct. For cross-language payloads, set `compatible: false` only after verifying that every language uses the same schema, or when native types are generated from Fory schema IDL. See [Schema Evolution](schema-evolution.md). +## Graph Memory Budget + +`maxGraphMemoryBytes` limits estimated shallow graph memory accepted during one +root deserialization. 
The budget covers materialized arrays, sets, object +arrays, maps, structs, and objects; it is not an exact JavaScript heap limit. +The default is a fixed `128 MiB` and is not derived from input size. + +Use a positive byte value to set an explicit lower or higher limit: + +```ts +const fory = new Fory({ + maxGraphMemoryBytes: 32 * 1024 * 1024, +}); +``` + +Explicit non-positive values are rejected when the runtime is created. + +String, binary, and dedicated dense primitive array payloads keep their normal +byte-size checks and do not consume this graph budget. Raise the limit only for +trusted workloads that legitimately contain very compact object graphs. + ## Optional HPS String Path `@apache-fory/hps` provides an optional Node.js string fast path: @@ -110,6 +133,8 @@ Security-related configuration: - Register only the expected schemas before deserializing untrusted payloads. - Set `maxDepth` for the maximum nesting depth your service accepts. +- Set `maxGraphMemoryBytes` for the maximum graph memory your service + accepts from one root payload. - Keep `maxTypeFields` and `maxTypeMetaBytes` at their defaults unless the data is not malicious and a trusted peer sends larger remote metadata. 
- Keep `maxSchemaVersionsPerType` and diff --git a/docs/guide/python/configuration.md b/docs/guide/python/configuration.md index fdd6459fea..c798b4cea2 100644 --- a/docs/guide/python/configuration.md +++ b/docs/guide/python/configuration.md @@ -40,6 +40,7 @@ class Fory: max_type_meta_bytes: int = 4096, max_schema_versions_per_type: int = 10, max_average_schema_versions_per_type: int = 3, + max_graph_memory_bytes: int = 128 * 1024 * 1024, policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, @@ -59,21 +60,22 @@ class ThreadSafeFory: ## Parameters -| Parameter | Type | Default | Description | -| -------------------------------------- | ------------------------------- | ------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `xlang` | `bool` | `True` | Use xlang mode. Set `False` for Python native mode. | -| `ref` | `bool` | `False` | Enable reference tracking for shared/circular references. Disable for better performance if your data has no shared references. | -| `strict` | `bool` | `True` | Require type registration for security. Keep this enabled for production unless a policy owns trust decisions. | -| `compatible` | `bool \| None` | `None` | Schema evolution mode. `None` enables compatible mode in both xlang and native mode. Set `False` only when every reader and writer uses the same schema. | -| `max_depth` | `int` | `50` | Maximum deserialization depth for security, preventing stack overflow attacks. | -| `max_type_fields` | `int` | `512` | Maximum fields accepted in one received remote struct metadata body. | -| `max_type_meta_bytes` | `int` | `4096` | Maximum encoded body bytes accepted for one received TypeDef body, excluding the 8-byte header and any extended-size varint. | -| `max_schema_versions_per_type` | `int` | `10` | Maximum accepted remote metadata versions for one logical type. 
| -| `max_average_schema_versions_per_type` | `int` | `3` | Average accepted remote metadata versions across accepted remote types. The effective global floor is `8192` schemas. | -| `policy` | `DeserializationPolicy \| None` | `None` | Deserialization policy used for security checks. Strongly recommended when `strict=False`. | -| `field_nullable` | `bool` | `False` | Treat dataclass fields as nullable by default. | -| `meta_compressor` | `Any` | `None` | Optional metadata compressor used for compatible-mode metadata encoding. | -| `fory_factory` | `Callable \| None` | `None` | `ThreadSafeFory` factory hook. When set, `ThreadSafeFory` creates instances via this callback; otherwise it forwards `**kwargs` to `Fory` construction. | +| Parameter | Type | Default | Description | +| -------------------------------------- | ------------------------------- | ----------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `xlang` | `bool` | `True` | Use xlang mode. Set `False` for Python native mode. | +| `ref` | `bool` | `False` | Enable reference tracking for shared/circular references. Disable for better performance if your data has no shared references. | +| `strict` | `bool` | `True` | Require type registration for security. Keep this enabled for production unless a policy owns trust decisions. | +| `compatible` | `bool \| None` | `None` | Schema evolution mode. `None` enables compatible mode in both xlang and native mode. Set `False` only when every reader and writer uses the same schema. | +| `max_depth` | `int` | `50` | Maximum deserialization depth for security, preventing stack overflow attacks. | +| `max_type_fields` | `int` | `512` | Maximum fields accepted in one received remote struct metadata body. 
| +| `max_type_meta_bytes` | `int` | `4096` | Maximum encoded body bytes accepted for one received TypeDef body, excluding the 8-byte header and any extended-size varint. | +| `max_schema_versions_per_type` | `int` | `10` | Maximum accepted remote metadata versions for one logical type. | +| `max_average_schema_versions_per_type` | `int` | `3` | Average accepted remote metadata versions across accepted remote types. The effective global floor is `8192` schemas. | +| `max_graph_memory_bytes` | `int` | `134217728` | Maximum estimated shallow graph memory for one root deserialization. Explicit non-positive values are rejected. | +| `policy` | `DeserializationPolicy \| None` | `None` | Deserialization policy used for security checks. Strongly recommended when `strict=False`. | +| `field_nullable` | `bool` | `False` | Treat dataclass fields as nullable by default. | +| `meta_compressor` | `Any` | `None` | Optional metadata compressor used for compatible-mode metadata encoding. | +| `fory_factory` | `Callable \| None` | `None` | `ThreadSafeFory` factory hook. When set, `ThreadSafeFory` creates instances via this callback; otherwise it forwards `**kwargs` to `Fory` construction. | ## Key Methods @@ -197,6 +199,7 @@ fory = pyfory.Fory( max_type_meta_bytes=4096, max_schema_versions_per_type=10, max_average_schema_versions_per_type=3, + max_graph_memory_bytes=128 * 1024 * 1024, ) fory.register(UserModel, name="example.User") @@ -222,6 +225,11 @@ Received remote metadata is also limited: - `max_type_meta_bytes` limits the encoded body bytes accepted for one received TypeDef body. - `max_schema_versions_per_type` limits accepted remote metadata versions for one logical type. - `max_average_schema_versions_per_type` limits the average across accepted remote types. +- `max_graph_memory_bytes` limits estimated shallow graph memory created during one root + deserialization, including materialized lists, tuples, sets, dicts, object arrays, structs, and + Python objects. 
The default is a fixed `128 MiB` for all root input forms. Set a positive byte + value for trusted payloads that legitimately contain larger or smaller object graphs. Explicit + non-positive values are rejected when the runtime is created. These limits do not change `strict`, `policy`, dynamic loading, unknown-class handling, or schema-evolution semantics. @@ -278,6 +286,8 @@ unchanged. - Register all expected application types before deserialization. - Use `DeserializationPolicy` when `strict=False` is necessary. - Keep `max_depth` low enough to reject unexpectedly deep payloads. +- Keep `max_graph_memory_bytes` at the fixed `128 MiB` default for most inputs, or set a positive + explicit limit for trusted workloads with different legitimate object-graph sizes. - Do not treat xlang/native mode choice as a security control. ## Related Topics diff --git a/docs/guide/rust/configuration.md b/docs/guide/rust/configuration.md index 58bd070567..37455840e5 100644 --- a/docs/guide/rust/configuration.md +++ b/docs/guide/rust/configuration.md @@ -110,6 +110,22 @@ let fory = Fory::builder() - `max_average_schema_versions_per_type` defaults to `3` and limits the average across accepted remote types. The effective global floor is `8192` schemas. +### Graph Memory Budget + +`max_graph_memory_bytes(...)` limits estimated shallow graph memory accepted during one root read. +The budget covers `Vec`/collection element storage, map key/value storage, and materialized struct +or object field storage; it is not an exact process heap limit. The default is a fixed `128 MiB` for +all root input forms. Set a positive byte value when trusted payloads need a larger or smaller +limit: + +```rust +let fory = Fory::builder() + .max_graph_memory_bytes(256 * 1024 * 1024) + .build(); +``` + +Explicit non-positive values are rejected when the runtime is created. 
+ ### Explicit Xlang Examples Set `.xlang(true)` explicitly for xlang serialization examples: @@ -135,6 +151,11 @@ let fory = Fory::builder().xlang(false).compatible(false).build(); // Custom depth limit let fory = Fory::builder().max_dyn_depth(10).build(); +// Custom graph memory budget +let fory = Fory::builder() + .max_graph_memory_bytes(256 * 1024 * 1024) + .build(); + // Combined configuration let fory = Fory::builder() .xlang(false) @@ -144,15 +165,16 @@ let fory = Fory::builder() ## Configuration Summary -| Option | Description | Default | -| --------------------------------------------- | ------------------------------------------------- | ------- | -| `compatible(bool)` | Enable schema evolution | `true` | -| `xlang(bool)` | Use xlang mode | `true` | -| `max_dyn_depth(u32)` | Maximum nesting depth for dynamic types | `5` | -| `max_type_fields(usize)` | Max fields in one received struct metadata body | `512` | -| `max_type_meta_bytes(usize)` | Max encoded bytes in one received metadata body | `4096` | -| `max_schema_versions_per_type(usize)` | Max remote metadata versions for one logical type | `10` | -| `max_average_schema_versions_per_type(usize)` | Average remote metadata versions across types | `3` | +| Option | Description | Default | +| --------------------------------------------- | ------------------------------------------------- | --------- | +| `compatible(bool)` | Enable schema evolution | `true` | +| `xlang(bool)` | Use xlang mode | `true` | +| `max_dyn_depth(u32)` | Maximum nesting depth for dynamic types | `5` | +| `max_graph_memory_bytes(i64)` | Estimated graph memory per root read | `128 MiB` | +| `max_type_fields(usize)` | Max fields in one received struct metadata body | `512` | +| `max_type_meta_bytes(usize)` | Max encoded bytes in one received metadata body | `4096` | +| `max_schema_versions_per_type(usize)` | Max remote metadata versions for one logical type | `10` | +| `max_average_schema_versions_per_type(usize)` | Average remote 
metadata versions across types | `3` | ## Compatible Mode @@ -169,6 +191,8 @@ Security-related configuration: - Register application structs and trait-object implementations before deserializing untrusted payloads. - Use `max_dyn_depth(...)` to reject unexpectedly deep dynamic object graphs. +- Keep `max_graph_memory_bytes(...)` at the fixed `128 MiB` default for most inputs, or set a + positive byte limit for trusted workloads with different legitimate object-graph sizes. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer concrete typed fields over `dyn Any` or broad trait-object fields for untrusted input. diff --git a/docs/guide/swift/configuration.md b/docs/guide/swift/configuration.md index e3a0478817..dcfb169f60 100644 --- a/docs/guide/swift/configuration.md +++ b/docs/guide/swift/configuration.md @@ -31,6 +31,7 @@ public struct Config { public let compatible: Bool public let checkClassVersion: Bool public let maxDepth: Int + public let maxGraphMemoryBytes: Int64 public let maxTypeFields: Int public let maxTypeMetaBytes: Int public let maxSchemaVersionsPerType: Int @@ -90,8 +91,13 @@ let fory = Fory(compatible: false, checkClassVersion: true) ### Size and Depth Limits -`maxDepth` bounds decoded payload nesting depth. Compatible-mode remote metadata -is also limited: +`maxDepth` bounds decoded payload nesting depth. + +`maxGraphMemoryBytes` bounds estimated shallow graph memory accepted during one root +deserialization. The default limit is a fixed `128 MiB` for all root input forms. A positive value +overrides the default. Explicit non-positive values are rejected when the runtime is created. + +Compatible-mode remote metadata is also limited: - `maxTypeFields` defaults to `512` and limits fields in one received struct metadata body. 
- `maxTypeMetaBytes` defaults to `4096` and limits encoded body bytes in one received TypeMeta body, @@ -104,6 +110,7 @@ is also limited: ```swift let fory = Fory( maxDepth: 5, + maxGraphMemoryBytes: 128 * 1024 * 1024, maxTypeFields: 512, maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, @@ -140,5 +147,6 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkClassVersion` with `compatible: false` for intentional same-schema payloads. - Set `maxDepth` for the largest nesting depth your service accepts. +- Set `maxGraphMemoryBytes` to cap estimated shallow graph memory during one root deserialization. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. diff --git a/docs/security/deserialization.md b/docs/security/deserialization.md index 33dddd2886..0f615c9712 100644 --- a/docs/security/deserialization.md +++ b/docs/security/deserialization.md @@ -149,11 +149,13 @@ For buffer-backed input: comparison. - Multi-byte element arrays should compute the required byte size with overflow checks before allocation. -- Container readers that allocate, reserve, or size-hint from a declared - logical element count should first call the byte owner's readability check for - that count. This is not a full container-body validation; it is the allocation - proof that the sender has supplied at least proportional input bytes before - the reader preallocates from the count. +- Container readers that allocate backing storage or size-hint from a declared + logical element count should call the byte owner's readability check for that + count before that backing allocation or capacity reservation. This is not a + full container-body validation; it is the allocation proof that the sender has + supplied at least proportional input bytes before the reader preallocates from + the count. 
Estimated memory-budget accounting may reserve budget before this + byte check because it does not allocate backing storage. For stream-backed input: @@ -203,6 +205,48 @@ validation can cause a no-progress loop, unbounded resource growth, retained state, or success across a Fory policy boundary. Protocol-allowed chunk segmentation is normal input and is not a security issue by itself. +## Root Graph Memory Budget + +Runtimes should enforce a root-deserialization budget for estimated shallow memory created by one +materialized graph. This is cumulative accounting for graph owners created by one root read; it is +not exact heap measurement and it is not a raw element-slot limit. + +The public configuration is `maxGraphMemoryBytes`. The default is a fixed `128 MiB` for all root +input forms; positive user configuration overrides the default. Explicit non-positive configuration +is invalid and should be rejected when the runtime is created. The budget is not derived from root +input size, and stream budgeting should not depend on dynamic bytes-read accounting. 
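The configuration rule above (fixed default, positive override, non-positive rejected at creation) can be sketched as follows. This is a minimal illustration only: `graphConfig`, `newGraphConfig`, and the use of a nil pointer to mean "unset" are hypothetical names and conventions chosen for the example, not the API of any Fory runtime.

```go
package main

import "errors"

// defaultMaxGraphMemoryBytes is the fixed 128 MiB default described in the text.
const defaultMaxGraphMemoryBytes int64 = 128 * 1024 * 1024

// errNonPositiveLimit is a hypothetical error for explicit non-positive limits.
var errNonPositiveLimit = errors.New("maxGraphMemoryBytes must be positive")

// graphConfig is a hypothetical config holder, not a real Fory type.
type graphConfig struct {
	maxGraphMemoryBytes int64
}

// newGraphConfig resolves the limit once, at runtime creation.
// A nil pointer models "option not set" (an assumption of this sketch),
// so the fixed default applies. The limit never depends on input size.
func newGraphConfig(explicit *int64) (graphConfig, error) {
	if explicit == nil {
		return graphConfig{maxGraphMemoryBytes: defaultMaxGraphMemoryBytes}, nil
	}
	if *explicit <= 0 {
		// Explicit zero or negative values are invalid configuration.
		return graphConfig{}, errNonPositiveLimit
	}
	return graphConfig{maxGraphMemoryBytes: *explicit}, nil
}
```

Rejecting the value at creation keeps the per-read path free of configuration checks: a root read only needs to copy the already-validated limit into its remaining budget.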
+ +Graph budget accounting should: + +- happen in root-operation read state, with cleanup owned by the root deserialization `finally`; +- keep read context/read state limited to raw byte reservation; counted arithmetic and collection, + map, array, struct, and object storage formulas belong in the concrete serializer or generated + serializer owner; +- reject arithmetic overflow before comparing budget or allocating; +- estimate lower-bound shallow owner storage: independently materialized collections, maps, sets, + and reference arrays reserve nonzero shallow self cost plus backing/reference/inline storage, and + struct/record/POJO/tuple, compatible, generated, and dynamic object owners reserve a nonzero + shallow self cost plus shallow field storage; +- use a 4-byte reference slot when the actual reference slot size is not cheap or reliable to query, + and use primitive/value field widths for inline storage; +- preserve existing byte-availability checks before backing allocation or capacity reservation; +- skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive + array, and primitive dense-array leaf owners. + +Each runtime must inspect the concrete owner path before choosing formulas. Reserve self storage +exactly once at the owner that stores or allocates the value. Root facades may reset the budget, but +must not pre-reserve root type or root self bytes. Reference-backed paths reserve parent owner self +cost plus reference storage, while each referenced heap owner reserves its own shallow self cost when +materialized. Inline/value paths reserve inline element, field, or root value storage in the +holder/allocation owner; nested value serializers must not charge their own self storage again. +Parents must not recursively include child object, collection, map, string, binary, or primitive +dense-array contents; the child owner reserves its own shallow memory when it is materialized. 
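The split between raw byte reservation in read state and owner-side formulas can be sketched as below. This is a hedged illustration, not the runtime's implementation: `readState`, `reserveBytes`, `reserveRefSlice`, and the `selfBytes` argument are hypothetical names, and the 4-byte reference slot is the fallback width mentioned in the rules above.

```go
package main

import (
	"errors"
	"math"
)

// referenceSlotBytes is the fallback reference slot width from the rules above.
const referenceSlotBytes int64 = 4

var (
	errGraphEstimate = errors.New("invalid or overflowing graph memory estimate")
	errGraphBudget   = errors.New("graph memory budget exceeded")
)

// readState is a hypothetical root read state. It exposes only raw byte
// reservation and keeps a single mutable remaining budget.
type readState struct {
	remaining int64
}

// reset restores the budget at the start of one root read.
func (s *readState) reset(maxBytes int64) { s.remaining = maxBytes }

// reserveBytes subtracts n raw bytes, or fails and leaves the budget unchanged.
func (s *readState) reserveBytes(n int64) error {
	if n < 0 {
		return errGraphEstimate
	}
	if n > s.remaining {
		return errGraphBudget
	}
	s.remaining -= n
	return nil
}

// reserveRefSlice is owner-side arithmetic for a reference-backed collection:
// a nonzero shallow self cost plus one reference slot per element.
// Overflow is rejected before the budget comparison and before any allocation.
func reserveRefSlice(s *readState, count, selfBytes int64) error {
	if count < 0 || selfBytes <= 0 {
		return errGraphEstimate
	}
	if count > (math.MaxInt64-selfBytes)/referenceSlotBytes {
		return errGraphEstimate
	}
	return s.reserveBytes(selfBytes + count*referenceSlotBytes)
}
```

Only the collection's own slots are charged here. Each referenced element would reserve its own shallow cost when its serializer materializes it, which is what keeps the accounting non-recursive.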
+ +Runtimes should not guess object headers, array headers, allocator headers, debug-mode fields, hash +buckets, tree links, hash-chain links, node headers, map-entry objects, spare blocks, or runtime +table layouts unless the owner path has a cheap, stable, explicit lower-bound storage signal and +documents the formula. C++ STL node, allocator, and debug-mode overheads should not be guessed. + ## Skip Semantics Skipping unknown or incompatible data is classified by concrete impact, not by diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index d8fb205012..5fb5e299aa 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -388,9 +388,9 @@ chunk, nullability, reference, and type-dispatch semantics. It is still the right allocation proof for count-based preallocation: after validating a non-empty count and reading any serializer-owned header or type metadata that precedes allocation, call `checkReadableBytes(logicalCount)` before allocating, -reserving, or size-hinting from that count. The byte owner handles buffer versus -stream readiness; the container serializer then allocates with the declared -count and reads elements through its normal owner path. +reserving backing capacity, or size-hinting from that count. The byte owner +handles buffer versus stream readiness; the container serializer then allocates +with the declared count and reads elements through its normal owner path. This check is not a full container-body validation. It only prevents a small or truncated input from causing a large count-based preallocation. Chunk sizes, @@ -398,6 +398,52 @@ duplicate keys, element value semantics, and protocol strictness remain owned by the container/map serializer and should be validated only when they protect a real owner invariant. 
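The allocation proof described above can be illustrated with a toy reader. This sketch assumes a simplified wire format (a one-byte count followed by one byte per element) that is not the Fory encoding; `byteReader` and `readByteList` are hypothetical names, and only `checkReadableBytes` mirrors a name used in the text.

```go
package main

import "errors"

var errTruncated = errors.New("declared count exceeds readable bytes")

// byteReader is a hypothetical buffer-backed byte owner.
type byteReader struct {
	data []byte
	pos  int
}

// checkReadableBytes reports whether at least n more bytes are available.
func (r *byteReader) checkReadableBytes(n int) error {
	if n < 0 || n > len(r.data)-r.pos {
		return errTruncated
	}
	return nil
}

// readByteList reads the toy format: a one-byte count, then count elements.
// The readability check runs before the count is used as a capacity hint,
// so a truncated input cannot trigger a large preallocation.
func readByteList(r *byteReader) ([]int8, error) {
	if err := r.checkReadableBytes(1); err != nil {
		return nil, err
	}
	count := int(r.data[r.pos])
	r.pos++
	if err := r.checkReadableBytes(count); err != nil {
		return nil, err
	}
	out := make([]int8, 0, count)
	for i := 0; i < count; i++ {
		out = append(out, int8(r.data[r.pos]))
		r.pos++
	}
	return out, nil
}
```

The check proves only that proportional input bytes exist. Element semantics, chunk sizes, and duplicate handling remain the serializer's responsibility, as the text notes.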
+Materializing readers should also reserve a root-operation estimated graph +memory budget before allocation or size hinting. The budget state belongs to +`ReadContext` or the equivalent root read state, not to ambient thread-local +state. Root facades set or reset the per-operation budget only; they must not +pre-reserve root type or root self bytes. `maxGraphMemoryBytes` defaults to a +fixed `128 MiB`; positive configuration overrides the default; explicit +non-positive configuration is invalid and must be rejected when the runtime is +created. Do not derive this budget from root input size, and do not add dynamic +stream bytes-read accounting for this budget. +Because the budget is fixed per root, read state should not mirror the +configured maximum into a second active-limit field. Use the existing +configuration, or one configured maximum field when the config is not otherwise +available, plus the mutable remaining budget. + +Read context or equivalent read state owns only raw byte reservation. It must +not expose counted arithmetic helpers or collection, map, array, struct, or +object semantic reservation APIs. Concrete serializers and generated serializer +owners compute the storage constants and formulas for the owner path they +allocate, including counted-byte overflow checks. +Read state must not grow non-memory-budget APIs for this feature, including +ref-publication controls, temporary-owner controls, serializer-owner controls, +conversion helpers, or APIs that encode the kind of value being materialized. +Concrete serializers and generated serializers own those decisions. + +The budget estimates lower-bound shallow memory for materialized graph owners, +not exact heap bytes. Reserve self storage exactly once at the owner that stores +or allocates the value. 
Reference-backed containers, maps, sets, and +object/reference arrays reserve nonzero owner self cost plus reference slots; +each referenced heap owner then reserves its own shallow self cost when +materialized. Inline/value containers reserve element storage; inline/value maps +reserve key plus value storage; root materialization, product, and box owners +reserve value self storage; and nested value serializers reserve only additional +dynamic storage they allocate. Struct/record/POJO/tuple, compatible, generated, and dynamic +object owners reserve a nonzero shallow self cost plus shallow field storage. +Parents must not recursively include child object, collection, map, string, +binary, or primitive dense-array contents. Skip enum/union as separate owners and +skip dedicated string, binary, primitive scalar, primitive array, and primitive +dense-array leaf owners, but do not skip general inline-value containers such as +vectors or lists of value objects. If reference slot size is not cheap or +reliable to query, use a 4-byte reference slot. Native runtimes may use +conservative lower-bound estimates instead of guessing non-portable object, +container, allocator, table, node, entry, or debug-layout details. Reject +arithmetic overflow before budget comparison or allocation, and keep the +existing `checkReadableBytes` proof before backing +allocation or capacity reservation. + For TypeDef or TypeMeta bodies, first prove that the encoded metadata body bytes are readable through the byte owner. 
Field-list allocation should happen after that body readability check and should not use a separate small initial-capacity diff --git a/go/fory/README.md b/go/fory/README.md index 8b5ec3392c..5a18ddba06 100644 --- a/go/fory/README.md +++ b/go/fory/README.md @@ -93,11 +93,16 @@ f := fory.New(fory.WithXlang(false), fory.WithCompatible(false)) // Set maximum nesting depth f := fory.New(fory.WithMaxDepth(20)) +// Set maximum estimated graph memory for one root read +f := fory.New(fory.WithMaxGraphMemoryBytes(256 * 1024 * 1024)) +// The value must be positive. + // Combine multiple options f := fory.New( fory.WithXlang(true), fory.WithTrackRef(true), fory.WithMaxDepth(20), + fory.WithMaxGraphMemoryBytes(256 * 1024 * 1024), ) ``` diff --git a/go/fory/array.go b/go/fory/array.go index f99f6ff39f..dc99f50c1f 100644 --- a/go/fory/array.go +++ b/go/fory/array.go @@ -298,7 +298,7 @@ func newArrayDynSerializer(elemType reflect.Type) (arrayDynSerializer, error) { if err != nil { return arrayDynSerializer{}, err } - return arrayDynSerializer{sliceSerializer: sliceSer}, nil + return arrayDynSerializer{sliceSerializer: *sliceSer}, nil } func (s arrayDynSerializer) WriteData(ctx *WriteContext, value reflect.Value) { @@ -318,6 +318,14 @@ func (s arrayDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType func (s arrayDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { // Create a temp slice to read into, then copy back to array sliceType := reflect.SliceOf(value.Type().Elem()) + elemBytes := int64(value.Type().Elem().Size()) + if int64(value.Len()) > maxGraphCount(elemBytes) { + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", value.Len(), elemBytes)) + return + } + if !ctx.ReserveGraphMemory(int64(value.Len()) * elemBytes) { + return + } tempSlice := reflect.MakeSlice(sliceType, value.Len(), value.Len()) s.sliceSerializer.readData(ctx, tempSlice, value.Len()) if ctx.HasError() { diff --git 
a/go/fory/codegen/decoder.go b/go/fory/codegen/decoder.go index 0e57021343..931e36b614 100644 --- a/go/fory/codegen/decoder.go +++ b/go/fory/codegen/decoder.go @@ -25,6 +25,28 @@ import ( "github.com/apache/fory/go/fory" ) +func writeGraphReservation(buf *bytes.Buffer, indent, countExpr, elemBytesExpr string) { + fmt.Fprintf(buf, "%s{\n", indent) + fmt.Fprintf(buf, "%s\tgraphCount := %s\n", indent, countExpr) + fmt.Fprintf(buf, "%s\tgraphElemBytes := int64(%s)\n", indent, elemBytesExpr) + fmt.Fprintf(buf, "%s\tif graphCount < 0 {\n", indent) + fmt.Fprintf(buf, "%s\t\tctx.SetError(fory.DeserializationErrorf(\"negative graph element count: %%d\", graphCount))\n", indent) + fmt.Fprintf(buf, "%s\t\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s\tif graphElemBytes < 0 {\n", indent) + fmt.Fprintf(buf, "%s\t\tctx.SetError(fory.DeserializationErrorf(\"negative graph element size: %%d\", graphElemBytes))\n", indent) + fmt.Fprintf(buf, "%s\t\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s\tif graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes {\n", indent) + fmt.Fprintf(buf, "%s\t\tctx.SetError(fory.DeserializationErrorf(\"graph memory estimate overflows: length=%%d elementBytes=%%d\", graphCount, graphElemBytes))\n", indent) + fmt.Fprintf(buf, "%s\t\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s\tif !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) {\n", indent) + fmt.Fprintf(buf, "%s\t\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) +} + // generateReadTyped generates the strongly-typed ReadData method func generateReadTyped(buf *bytes.Buffer, s *StructInfo) error { fmt.Fprintf(buf, "// ReadTyped provides strongly-typed deserialization with no reflection overhead\n") @@ -69,7 +91,13 @@ func generateReadInterface(buf 
*bytes.Buffer, s *StructInfo) error { fmt.Fprintf(buf, "\tvar v *%s\n", s.Name) fmt.Fprintf(buf, "\tif value.Kind() == reflect.Ptr {\n") fmt.Fprintf(buf, "\t\tif value.IsNil() {\n") - fmt.Fprintf(buf, "\t\t\t// For pointer types, allocate using value.Type().Elem()\n") + fmt.Fprintf(buf, "\t\t\tgraphBytes := int64(unsafe.Sizeof(%s{}))\n", s.Name) + fmt.Fprintf(buf, "\t\t\tif graphBytes == 0 {\n") + fmt.Fprintf(buf, "\t\t\t\tgraphBytes = 1\n") + fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveGraphMemory(graphBytes) {\n") + fmt.Fprintf(buf, "\t\t\t\treturn\n") + fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tvalue.Set(reflect.New(value.Type().Elem()))\n") fmt.Fprintf(buf, "\t\t}\n") fmt.Fprintf(buf, "\t\tv = value.Interface().(*%s)\n", s.Name) @@ -172,6 +200,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") + writeGraphReservation(buf, "\t\t\t", "sliceLen", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make([]any, 0)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -200,6 +229,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") + writeGraphReservation(buf, "\t\t\t\t", "sliceLen", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make([]any, 0)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -501,6 +531,10 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\tif isXlang {\n") fmt.Fprintf(buf, "\t\t\t// xlang mode: slices are not nullable, read directly without null flag\n") fmt.Fprintf(buf, "\t\t\tsliceLen := ctx.ReadCollectionLength()\n") 
+ fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") + fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t}\n") + writeGraphReservation(buf, "\t\t\t", "sliceLen", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make(%s, 0)\n", fieldAccess, sliceType.String()) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -519,6 +553,7 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") + writeGraphReservation(buf, "\t\t\t\t", "sliceLen", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s, 0)\n", fieldAccess, sliceType.String()) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -545,6 +580,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) + writeGraphReservation(buf, indent, "sliceLen", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make([]any, 0)\n", indent, fieldAccess) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -568,6 +604,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) + writeGraphReservation(buf, indent, "sliceLen", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make(%s, 0)\n", indent, fieldAccess, sliceType.String()) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -831,6 +868,10 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\tif isXlang {\n") fmt.Fprintf(buf, 
"\t\t\t// xlang mode: maps are not nullable, read directly without null flag\n") fmt.Fprintf(buf, "\t\t\tmapLen := ctx.ReadCollectionLength()\n") + fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") + fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t}\n") + writeGraphReservation(buf, "\t\t\t", "mapLen", unsafeSizeExpr(keyType)+" + "+unsafeSizeExpr(valueType)) fmt.Fprintf(buf, "\t\t\tif mapLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make(%s)\n", fieldAccess, mapType.String()) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -849,6 +890,7 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") + writeGraphReservation(buf, "\t\t\t\t", "mapLen", unsafeSizeExpr(keyType)+" + "+unsafeSizeExpr(valueType)) fmt.Fprintf(buf, "\t\t\t\tif mapLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s)\n", fieldAccess, mapType.String()) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -884,6 +926,7 @@ func generateMapReadInlineNoNull(buf *bytes.Buffer, mapType *types.Map, fieldAcc fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) + writeGraphReservation(buf, indent, "mapLen", unsafeSizeExpr(keyType)+" + "+unsafeSizeExpr(valueType)) fmt.Fprintf(buf, "%sif mapLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make(%s)\n", indent, fieldAccess, mapType.String()) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -978,6 +1021,10 @@ func getGoTypeString(t types.Type) string { return t.String() } +func unsafeSizeExpr(t types.Type) string { + return fmt.Sprintf("int64(unsafe.Sizeof(*new(%s)))", getGoTypeString(t)) +} + // generateMapKeyRead generates code to read a map key // Uses error-aware methods for deferred error checking func generateMapKeyRead(buf *bytes.Buffer, keyType types.Type, varName string) error { diff --git 
a/go/fory/codegen/generator.go b/go/fory/codegen/generator.go index dbe7842da4..56f615163d 100644 --- a/go/fory/codegen/generator.go +++ b/go/fory/codegen/generator.go @@ -284,8 +284,9 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil // Determine which imports are needed needsTime := false - needsReflect := false + needsReflect := len(structs) > 0 needsOptional := false + needsUnsafe := len(structs) > 0 for _, s := range structs { for _, field := range s.Fields { @@ -296,8 +297,6 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil if field.IsOptional { needsOptional = true } - // We need reflect for the interface compatibility methods - needsReflect = true } } @@ -310,6 +309,9 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil if needsTime { fmt.Fprintf(&buf, "\t\"time\"\n") } + if needsUnsafe { + fmt.Fprintf(&buf, "\t\"unsafe\"\n") + } fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory\"\n") if needsOptional { fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory/optional\"\n") @@ -549,8 +551,9 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { // Determine which imports are needed needsTime := false - needsReflect := false + needsReflect := len(structs) > 0 needsOptional := false + needsUnsafe := len(structs) > 0 for _, s := range structs { for _, field := range s.Fields { @@ -561,8 +564,6 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { if field.IsOptional { needsOptional = true } - // We need reflect for the interface compatibility methods - needsReflect = true } } @@ -575,6 +576,9 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { if needsTime { fmt.Fprintf(&buf, "\t\"time\"\n") } + if needsUnsafe { + fmt.Fprintf(&buf, "\t\"unsafe\"\n") + } fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory\"\n") if needsOptional { fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory/optional\"\n") 
diff --git a/go/fory/field_serializer.go b/go/fory/field_serializer.go index d91b6ec5e1..0a3125ff21 100644 --- a/go/fory/field_serializer.go +++ b/go/fory/field_serializer.go @@ -42,7 +42,7 @@ func serializerNeedsGenericDispatch(serializer Serializer) bool { switch serializer.(type) { case *sliceSerializer, primitiveListSerializer, - sliceDynSerializer, + *sliceDynSerializer, setSerializer, mapSerializer, stringSliceSerializer, @@ -68,10 +68,13 @@ func newDeclaredSliceSerializer(type_ reflect.Type, elemSerializer Serializer, r if elem.Kind() == reflect.Ptr && elem.Elem().Kind() == reflect.Interface { return nil, fmt.Errorf("slice serializer does not support pointer to interface element type: %v", type_) } + elemBytes := int64(elem.Size()) return &sliceSerializer{ type_: type_, elemSerializer: elemSerializer, referencable: referencable, + elemBytes: elemBytes, + maxLength: maxGraphCount(elemBytes), }, nil } diff --git a/go/fory/fory.go b/go/fory/fory.go index 412fc46449..55b2a18b21 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -69,6 +69,7 @@ type Config struct { MaxDepth int IsXlang bool Compatible bool // Schema evolution compatibility mode + MaxGraphMemoryBytes int64 MaxTypeFields int MaxTypeMetaBytes int MaxSchemaVersionsPerType int @@ -82,6 +83,7 @@ func defaultConfig() Config { MaxDepth: 20, IsXlang: true, MaxTypeFields: 512, + MaxGraphMemoryBytes: 128 * 1024 * 1024, MaxTypeMetaBytes: 4096, MaxSchemaVersionsPerType: 10, MaxAverageSchemaVersionsPerType: 3, @@ -110,6 +112,16 @@ func WithMaxDepth(depth int) Option { } } +// WithMaxGraphMemoryBytes sets the maximum estimated graph memory accepted during one root deserialization. 
+func WithMaxGraphMemoryBytes(size int64) Option { + if size <= 0 { + panic("MaxGraphMemoryBytes must be positive") + } + return func(f *Fory) { + f.config.MaxGraphMemoryBytes = size + } +} + // WithXlang sets cross-language serialization mode func WithXlang(enabled bool) Option { return func(f *Fory) { @@ -556,15 +568,15 @@ func (f *Fory) Serialize(value any) ([]byte, error) { func (f *Fory) Deserialize(data []byte, v any) error { defer f.resetReadState() f.readCtx.SetData(data) + target := reflect.ValueOf(v).Elem() + f.readCtx.remainingGraphMemoryBytes = f.config.MaxGraphMemoryBytes readHeader(f.readCtx) if f.readCtx.HasError() { return f.readCtx.TakeError() } - // Deserialize the value - TypeMeta is read inline using streaming protocol - target := reflect.ValueOf(v).Elem() - f.readCtx.ReadValue(target, RefModeTracking, true) + f.readRootValue(target) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -648,6 +660,8 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { // Temporarily swap buffer origBuffer := f.readCtx.buffer f.readCtx.buffer = buf + target := reflect.ValueOf(v).Elem() + f.readCtx.remainingGraphMemoryBytes = f.config.MaxGraphMemoryBytes readHeader(f.readCtx) if f.readCtx.HasError() { @@ -656,8 +670,7 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { } // Deserialize the value - TypeMeta is read inline using streaming protocol - target := reflect.ValueOf(v).Elem() - f.readCtx.ReadValue(target, RefModeTracking, true) + f.readRootValue(target) if f.readCtx.HasError() { f.readCtx.buffer = origBuffer return f.readCtx.TakeError() @@ -748,12 +761,6 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers f.readCtx.outOfBandBuffers = buffers } - // ReadData and validate header - readHeader(f.readCtx) - if f.readCtx.HasError() { - return f.readCtx.TakeError() - } - // v must be a pointer so we can deserialize into it if v == nil { return fmt.Errorf("v cannot be nil") @@ -766,8 +773,17 @@ func 
(f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers return fmt.Errorf("v must be a non-nil pointer") } + target := rv.Elem() + f.readCtx.remainingGraphMemoryBytes = f.config.MaxGraphMemoryBytes + + // ReadData and validate header + readHeader(f.readCtx) + if f.readCtx.HasError() { + return f.readCtx.TakeError() + } + // Deserialize the value - TypeMeta is read inline using streaming protocol - f.readCtx.ReadValue(rv.Elem(), RefModeTracking, true) + f.readRootValue(target) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -1016,6 +1032,17 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { // Reuse context, reset and set new data f.readCtx.Reset() f.readCtx.SetData(data) + f.readCtx.remainingGraphMemoryBytes = f.config.MaxGraphMemoryBytes + + var targetVal reflect.Value + var targetType reflect.Type + switch any(target).(type) { + case *bool, *int8, *int16, *int32, *int64, *int, *float32, *float64, *string, + *[]byte, *[]int8, *[]int16, *[]int32, *[]int64, *[]int, *[]float32, *[]float64, *[]bool: + default: + targetVal = reflect.ValueOf(target).Elem() + targetType = targetVal.Type() + } // ReadData and validate header readHeader(f.readCtx) @@ -1154,17 +1181,38 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { return f.readCtx.CheckError() default: // Slow path: use serializer-based deserialization - targetVal := reflect.ValueOf(target).Elem() - targetType := targetVal.Type() - - // Get serializer for the target type + if !targetVal.IsValid() { + targetVal = reflect.ValueOf(target).Elem() + targetType = targetVal.Type() + } + if targetType.Kind() == reflect.Struct { + if typeInfo := f.readCtx.getTypeInfoByType(targetType); typeInfo != nil { + if structSer, ok := typeInfo.Serializer.(*structSerializer); ok { + structSer.readRoot(f.readCtx, targetVal) + return f.readCtx.CheckError() + } + } + } serializer, err := f.typeResolver.getSerializerByType(targetType, false) if err != nil { return 
fmt.Errorf("failed to get serializer for type %v: %w", targetType, err) } - - // Use Read to deserialize directly into target serializer.Read(f.readCtx, RefModeTracking, true, false, targetVal) return f.readCtx.CheckError() } } + +func (f *Fory) readRootValue(target reflect.Value) { + targetType := target.Type() + if targetType.Kind() == reflect.Struct { + if typeInfo := f.readCtx.getTypeInfoByType(targetType); typeInfo != nil { + if structSer, ok := typeInfo.Serializer.(*structSerializer); ok { + structSer.readRoot(f.readCtx, target) + } else { + typeInfo.Serializer.Read(f.readCtx, RefModeTracking, true, false, target) + } + return + } + } + f.readCtx.ReadValue(target, RefModeTracking, true) +} diff --git a/go/fory/graph_memory_budget_test.go b/go/fory/graph_memory_budget_test.go new file mode 100644 index 0000000000..81250c9693 --- /dev/null +++ b/go/fory/graph_memory_budget_test.go @@ -0,0 +1,282 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package fory + +import ( + "bytes" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type budgetItem struct { + A int32 +} + +type budgetSiblings struct { + A []string + B []string +} + +func graphOwnerSizeOf[T any]() int64 { + bytes := graphSizeOf[T]() + if bytes == 0 { + return 1 + } + return bytes +} + +func TestGraphMemoryBudgetConfig(t *testing.T) { + require.Equal(t, int64(128*1024*1024), New().config.MaxGraphMemoryBytes) + require.Equal(t, int64(123), New(WithMaxGraphMemoryBytes(123)).config.MaxGraphMemoryBytes) + require.Panics(t, func() { WithMaxGraphMemoryBytes(0) }) + require.Panics(t, func() { WithMaxGraphMemoryBytes(-2) }) +} + +func TestGraphMemoryBudgetFixedDefault(t *testing.T) { + writer := New(WithCompatible(false)) + value := []any{[]any{}, []any{}, []any{}} + data, err := writer.Serialize(value) + require.NoError(t, err) + + var out []any + err = New(WithCompatible(false)).Deserialize(data, &out) + require.NoError(t, err) + require.Len(t, out, len(value)) +} + +func TestGraphBudgetRootKinds(t *testing.T) { + writer := New(WithCompatible(false)) + values := make([]any, 12000) + for i := range values { + values[i] = []any{} + } + data, err := writer.Serialize(values) + require.NoError(t, err) + + var fromBytes []any + err = New(WithCompatible(false)).Deserialize(data, &fromBytes) + require.NoError(t, err) + require.Len(t, fromBytes, len(values)) + + var fromStream []any + err = New(WithCompatible(false)).DeserializeFromReader(bytes.NewReader(data), &fromStream) + require.NoError(t, err) + require.Len(t, fromStream, len(values)) +} + +func TestGraphMemoryBudgetBufferRoots(t *testing.T) { + writer := New(WithCompatible(false)) + value := []string{"a", "b"} + data, err := writer.Serialize(value) + require.NoError(t, err) + + reader := New(WithCompatible(false)) + var fromCallback []string + err = reader.DeserializeWithCallbackBuffers(NewByteBuffer(data), &fromCallback, nil) + require.NoError(t, err) + 
require.Equal(t, value, fromCallback) + + var fromBuffer []string + err = reader.DeserializeFrom(NewByteBuffer(data), &fromBuffer) + require.NoError(t, err) + require.Equal(t, value, fromBuffer) +} + +func TestGraphBudgetOverride(t *testing.T) { + writer := New(WithCompatible(false)) + values := make([]any, 12000) + for i := range values { + values[i] = []any{} + } + data, err := writer.Serialize(values) + require.NoError(t, err) + + var out []any + err = New(WithCompatible(false), WithMaxGraphMemoryBytes(4*1024*1024)).Deserialize(data, &out) + require.NoError(t, err) + require.Len(t, out, len(values)) +} + +func TestGraphBudgetCumulative(t *testing.T) { + data, err := New(WithCompatible(false)).Serialize([]any{}) + require.NoError(t, err) + var empty []any + err = New(WithCompatible(false), WithMaxGraphMemoryBytes(1)).Deserialize(data, &empty) + require.NoError(t, err) + require.Empty(t, empty) + + writer := New(WithCompatible(false)) + require.NoError(t, writer.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) + data, err = writer.Serialize(&budgetSiblings{A: []string{"a"}, B: []string{"b"}}) + require.NoError(t, err) + reader := New(WithCompatible(false), WithMaxGraphMemoryBytes(stringElementBytes)) + require.NoError(t, reader.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) + var out budgetSiblings + err = reader.Deserialize(data, &out) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + required := structGraphBytes(reflect.TypeOf(budgetSiblings{})) + 2*stringElementBytes + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required)) + require.NoError(t, reader.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) + require.NoError(t, reader.Deserialize(data, &out)) + require.Equal(t, []string{"a"}, out.A) + require.Equal(t, []string{"b"}, out.B) +} + +func TestGraphMemoryBudgetMapAndOverflow(t *testing.T) { + data, err := New().Serialize(map[string]string{"k": "v"}) + 
require.NoError(t, err) + var out map[string]string + oneEntryBudget := graphSizeOf[string]() + graphSizeOf[string]() + err = New(WithMaxGraphMemoryBytes(oneEntryBudget-1)).Deserialize(data, &out) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + + ctx := NewReadContext(false) + require.False(t, ctx.ReserveGraphMemory(-1)) + require.Contains(t, ctx.CheckError().Error(), "non-negative") +} + +func TestGraphBudgetSlices(t *testing.T) { + data, err := New().Serialize([]string{"a"}) + require.NoError(t, err) + var stringsOut []string + err = New(WithMaxGraphMemoryBytes(graphSizeOf[string]()-1)).Deserialize(data, &stringsOut) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + + writer := New() + require.NoError(t, writer.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + data, err = writer.Serialize([]budgetItem{{A: 1}}) + require.NoError(t, err) + reader := New(WithMaxGraphMemoryBytes(graphSizeOf[budgetItem]() - 1)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + var items []budgetItem + err = reader.Deserialize(data, &items) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + reader = New(WithMaxGraphMemoryBytes(graphSizeOf[budgetItem]())) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + require.NoError(t, reader.Deserialize(data, &items)) + require.Equal(t, []budgetItem{{A: 1}}, items) +} + +func TestGraphMemoryBudgetStructOwners(t *testing.T) { + writer := New(WithCompatible(false)) + require.NoError(t, writer.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + data, err := writer.Serialize(&budgetItem{A: 7}) + require.NoError(t, err) + + required := graphOwnerSizeOf[budgetItem]() + reader := New(WithCompatible(false), WithMaxGraphMemoryBytes(required-1)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + var out *budgetItem + err = 
reader.Deserialize(data, &out) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + require.NoError(t, reader.Deserialize(data, &out)) + require.Equal(t, int32(7), out.A) + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required-1)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + var outValue budgetItem + err = reader.Deserialize(data, &outValue) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + require.NoError(t, reader.Deserialize(data, &outValue)) + require.Equal(t, int32(7), outValue.A) + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required-1)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + err = reader.DeserializeFromReader(bytes.NewReader(data), &outValue) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + require.NoError(t, reader.DeserializeFromReader(bytes.NewReader(data), &outValue)) + require.Equal(t, int32(7), outValue.A) + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required-1)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + err = reader.DeserializeFromStream(NewInputStream(bytes.NewReader(data)), &outValue) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, 
"test.BudgetItem")) + require.NoError(t, reader.DeserializeFromStream(NewInputStream(bytes.NewReader(data)), &outValue)) + require.Equal(t, int32(7), outValue.A) +} + +func TestGraphBudgetSkipsDense(t *testing.T) { + f := New(WithMaxGraphMemoryBytes(1)) + + stringData, err := New().Serialize(strings.Repeat("x", 128)) + require.NoError(t, err) + var s string + require.NoError(t, f.Deserialize(stringData, &s)) + require.Len(t, s, 128) + + bytesData, err := New().Serialize([]byte{1, 2, 3, 4}) + require.NoError(t, err) + var b []byte + require.NoError(t, f.Deserialize(bytesData, &b)) + require.Equal(t, []byte{1, 2, 3, 4}, b) + + intsData, err := New().Serialize([]int32{1, 2, 3, 4}) + require.NoError(t, err) + var ints []int32 + require.NoError(t, f.Deserialize(intsData, &ints)) + require.Equal(t, []int32{1, 2, 3, 4}, ints) +} + +func TestGraphBudgetByteChecks(t *testing.T) { + buf := NewByteBuffer(nil) + buf.WriteByte_(XLangFlag) + buf.WriteInt8(NotNullValueFlag) + buf.WriteUint8(uint8(LIST)) + buf.WriteLength(1024) + buf.WriteInt8(int8(CollectionIsSameType)) + buf.WriteUint8(uint8(STRING)) + + var stringsOut []string + err := New(WithMaxGraphMemoryBytes(8*1024*1024)).Deserialize(buf.Bytes(), &stringsOut) + require.Error(t, err) + require.Contains(t, err.Error(), "buffer out of bound") + + buf = NewByteBuffer(nil) + buf.WriteByte_(XLangFlag) + buf.WriteInt8(NotNullValueFlag) + buf.WriteUint8(uint8(INT32_ARRAY)) + buf.WriteLength(4096) + + var ints []int32 + err = New(WithMaxGraphMemoryBytes(8*1024*1024)).Deserialize(buf.Bytes(), &ints) + require.Error(t, err) + require.Contains(t, err.Error(), "buffer out of bound") +} diff --git a/go/fory/map.go b/go/fory/map.go index 8b1d82cc95..9761efc2a8 100644 --- a/go/fory/map.go +++ b/go/fory/map.go @@ -303,6 +303,24 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { iface := reflect.TypeOf((*any)(nil)).Elem() mapType = reflect.MapOf(iface, iface) } + keyBytes := int64(mapType.Key().Size()) + valueBytes 
:= int64(mapType.Elem().Size()) + elemBytes := keyBytes + valueBytes + if elemBytes < keyBytes { + ctx.SetError(DeserializationErrorf("map entry size overflows: key=%d value=%d", keyBytes, valueBytes)) + return + } + if size < 0 { + ctx.SetError(DeserializationErrorf("negative graph element count: %d", size)) + return + } + if int64(size) > maxGraphCount(elemBytes) { + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", size, elemBytes)) + return + } + if !ctx.ReserveGraphMemory(int64(size) * elemBytes) { + return + } if size == 0 { if value.IsNil() { value.Set(reflect.MakeMap(mapType)) diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go index d520e97bd5..ee98f310b6 100644 --- a/go/fory/map_primitive.go +++ b/go/fory/map_primitive.go @@ -25,6 +25,27 @@ import ( // Optimized map serializers for common types // ============================================================================ +var ( + stringStringMapElemBytes = stringElementBytes + stringElementBytes + stringStringMapMaxLength = maxGraphCount(stringStringMapElemBytes) + stringInt64MapElemBytes = stringElementBytes + graphSizeOf[int64]() + stringInt64MapMaxLength = maxGraphCount(stringInt64MapElemBytes) + stringInt32MapElemBytes = stringElementBytes + graphSizeOf[int32]() + stringInt32MapMaxLength = maxGraphCount(stringInt32MapElemBytes) + stringIntMapElemBytes = stringElementBytes + graphSizeOf[int]() + stringIntMapMaxLength = maxGraphCount(stringIntMapElemBytes) + stringFloat64MapElemBytes = stringElementBytes + graphSizeOf[float64]() + stringFloat64MapMaxLength = maxGraphCount(stringFloat64MapElemBytes) + stringBoolMapElemBytes = stringElementBytes + graphSizeOf[bool]() + stringBoolMapMaxLength = maxGraphCount(stringBoolMapElemBytes) + int32Int32MapElemBytes = graphSizeOf[int32]() + graphSizeOf[int32]() + int32Int32MapMaxLength = maxGraphCount(int32Int32MapElemBytes) + int64Int64MapElemBytes = graphSizeOf[int64]() + graphSizeOf[int64]() + 
int64Int64MapMaxLength = maxGraphCount(int64Int64MapElemBytes) + intIntMapElemBytes = graphSizeOf[int]() + graphSizeOf[int]() + intIntMapMaxLength = maxGraphCount(intIntMapElemBytes) +) + // writeMapStringString writes map[string]string using chunk protocol // When hasGenerics=true, element types are known so we set DECL_TYPE flags and skip type info func writeMapStringString(buf *ByteBuffer, m map[string]string, hasGenerics bool) { @@ -68,10 +89,24 @@ func writeMapStringString(buf *ByteBuffer, m map[string]string, hasGenerics bool } } -func readTypedMapSize(ctx *ReadContext) (int, bool) { +func readTypedMapSize(ctx *ReadContext, elemBytes int64, maxLength int64) (int, bool) { size := ctx.ReadCollectionLength() - if size == 0 || ctx.HasError() { - return size, false + if ctx.HasError() { + return 0, false + } + if size < 0 { + ctx.SetError(DeserializationErrorf("negative graph element count: %d", size)) + return 0, false + } + if int64(size) > maxLength { + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", size, elemBytes)) + return 0, false + } + if !ctx.ReserveGraphMemory(int64(size) * elemBytes) { + return 0, false + } + if size == 0 { + return size, true } if !ctx.Buffer().CheckReadable(size, ctx.Err()) { return 0, false @@ -83,12 +118,11 @@ func readTypedMapSize(ctx *ReadContext) (int, bool) { func readMapStringString(ctx *ReadContext) map[string]string { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]string) + size, ok := readTypedMapSize(ctx, stringStringMapElemBytes, stringStringMapMaxLength) if !ok { - return result + return nil } - result = make(map[string]string, size) + result := make(map[string]string, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -171,12 +205,11 @@ func writeMapStringInt64(buf *ByteBuffer, m map[string]int64, hasGenerics bool) func readMapStringInt64(ctx *ReadContext) map[string]int64 { err := ctx.Err() buf := ctx.Buffer() 
- size, ok := readTypedMapSize(ctx) - result := make(map[string]int64) + size, ok := readTypedMapSize(ctx, stringInt64MapElemBytes, stringInt64MapMaxLength) if !ok { - return result + return nil } - result = make(map[string]int64, size) + result := make(map[string]int64, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -256,12 +289,11 @@ func writeMapStringInt32(buf *ByteBuffer, m map[string]int32, hasGenerics bool) func readMapStringInt32(ctx *ReadContext) map[string]int32 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]int32) + size, ok := readTypedMapSize(ctx, stringInt32MapElemBytes, stringInt32MapMaxLength) if !ok { - return result + return nil } - result = make(map[string]int32, size) + result := make(map[string]int32, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -341,12 +373,11 @@ func writeMapStringInt(buf *ByteBuffer, m map[string]int, hasGenerics bool) { func readMapStringInt(ctx *ReadContext) map[string]int { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]int) + size, ok := readTypedMapSize(ctx, stringIntMapElemBytes, stringIntMapMaxLength) if !ok { - return result + return nil } - result = make(map[string]int, size) + result := make(map[string]int, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -426,12 +457,11 @@ func writeMapStringFloat64(buf *ByteBuffer, m map[string]float64, hasGenerics bo func readMapStringFloat64(ctx *ReadContext) map[string]float64 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]float64) + size, ok := readTypedMapSize(ctx, stringFloat64MapElemBytes, stringFloat64MapMaxLength) if !ok { - return result + return nil } - result = make(map[string]float64, size) + result := make(map[string]float64, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -511,12 +541,11 @@ func writeMapStringBool(buf *ByteBuffer, m map[string]bool, 
hasGenerics bool) { func readMapStringBool(ctx *ReadContext) map[string]bool { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]bool) + size, ok := readTypedMapSize(ctx, stringBoolMapElemBytes, stringBoolMapMaxLength) if !ok { - return result + return nil } - result = make(map[string]bool, size) + result := make(map[string]bool, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -600,12 +629,11 @@ func writeMapInt32Int32(buf *ByteBuffer, m map[int32]int32, hasGenerics bool) { func readMapInt32Int32(ctx *ReadContext) map[int32]int32 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[int32]int32) + size, ok := readTypedMapSize(ctx, int32Int32MapElemBytes, int32Int32MapMaxLength) if !ok { - return result + return nil } - result = make(map[int32]int32, size) + result := make(map[int32]int32, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -685,12 +713,11 @@ func writeMapInt64Int64(buf *ByteBuffer, m map[int64]int64, hasGenerics bool) { func readMapInt64Int64(ctx *ReadContext) map[int64]int64 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[int64]int64) + size, ok := readTypedMapSize(ctx, int64Int64MapElemBytes, int64Int64MapMaxLength) if !ok { - return result + return nil } - result = make(map[int64]int64, size) + result := make(map[int64]int64, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -770,12 +797,11 @@ func writeMapIntInt(buf *ByteBuffer, m map[int]int, hasGenerics bool) { func readMapIntInt(ctx *ReadContext) map[int]int { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[int]int) + size, ok := readTypedMapSize(ctx, intIntMapElemBytes, intIntMapMaxLength) if !ok { - return result + return nil } - result = make(map[int]int, size) + result := make(map[int]int, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -831,12 +857,12 @@ func (s 
stringStringMapSerializer) Write(ctx *WriteContext, refMode RefMode, wri } func (s stringStringMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringString(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringStringMapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -866,12 +892,12 @@ func (s stringInt64MapSerializer) Write(ctx *WriteContext, refMode RefMode, writ } func (s stringInt64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringInt64(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringInt64MapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -901,12 +927,12 @@ func (s stringIntMapSerializer) Write(ctx *WriteContext, refMode RefMode, writeT } func (s stringIntMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringInt(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringIntMapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -936,12 +962,12 @@ func (s stringFloat64MapSerializer) Write(ctx *WriteContext, refMode RefMode, wr } func (s stringFloat64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringFloat64(ctx) 
+ if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringFloat64MapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -971,12 +997,12 @@ func (s stringBoolMapSerializer) Write(ctx *WriteContext, refMode RefMode, write } func (s stringBoolMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringBool(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringBoolMapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -1006,12 +1032,12 @@ func (s int32Int32MapSerializer) Write(ctx *WriteContext, refMode RefMode, write } func (s int32Int32MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapInt32Int32(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s int32Int32MapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -1041,12 +1067,12 @@ func (s int64Int64MapSerializer) Write(ctx *WriteContext, refMode RefMode, write } func (s int64Int64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapInt64Int64(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s int64Int64MapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -1076,12 +1102,12 @@ func (s intIntMapSerializer) Write(ctx 
*WriteContext, refMode RefMode, writeType } func (s intIntMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapIntInt(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s intIntMapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { diff --git a/go/fory/pointer.go b/go/fory/pointer.go index b30a15f53d..199ea54e0f 100644 --- a/go/fory/pointer.go +++ b/go/fory/pointer.go @@ -140,6 +140,9 @@ func (s *ptrToValueSerializer) ReadData(ctx *ReadContext, value reflect.Value) { var newVal reflect.Value if value.IsNil() { // Allocate new value + if !ctx.ReserveGraphMemory(structGraphBytes(value.Type().Elem())) { + return + } newVal = reflect.New(value.Type().Elem()) value.Set(newVal) } else { @@ -195,6 +198,9 @@ func (s *ptrToValueSerializer) Read(ctx *ReadContext, refMode RefMode, readType if structSer, ok := typeInfo.Serializer.(*structSerializer); ok && len(structSer.fieldDefs) > 0 { // Allocate the pointer value if needed if value.IsNil() { + if !ctx.ReserveGraphMemory(structGraphBytes(value.Type().Elem())) { + return + } value.Set(reflect.New(value.Type().Elem())) } ctx.RefResolver().Reference(value) diff --git a/go/fory/reader.go b/go/fory/reader.go index 3985bb4e2b..c7a87aa4d1 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -29,21 +29,49 @@ import ( // ReadContext holds all state needed during deserialization. 
type ReadContext struct { - buffer *ByteBuffer - refReader *RefReader - trackRef bool // Cached flag to avoid indirection - xlang bool // Cross-language serialization mode - rootHeader byte - compatible bool // Schema evolution compatibility mode - typeResolver *TypeResolver // For complex type deserialization - refResolver *RefResolver // For reference tracking in native-mode paths - outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization - outOfBandIndex int // Current index into out-of-band buffers - depth int // Current nesting depth for cycle detection - maxDepth int // Maximum allowed nesting depth - err Error // Accumulated error state for deferred checking - lastTypePtr uintptr - lastTypeInfo *TypeInfo + buffer *ByteBuffer + refReader *RefReader + trackRef bool // Cached flag to avoid indirection + xlang bool // Cross-language serialization mode + rootHeader byte + compatible bool // Schema evolution compatibility mode + typeResolver *TypeResolver // For complex type deserialization + refResolver *RefResolver // For reference tracking in native-mode paths + outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization + outOfBandIndex int // Current index into out-of-band buffers + depth int // Current nesting depth for cycle detection + maxDepth int // Maximum allowed nesting depth + err Error // Accumulated error state for deferred checking + lastTypePtr uintptr + lastTypeInfo *TypeInfo + remainingGraphMemoryBytes int64 +} + +var referenceSlotBytes = int64(unsafe.Sizeof(uintptr(0))) +var stringElementBytes = graphSizeOf[string]() +var stringMaxLength = maxGraphCount(stringElementBytes) + +func graphSizeOf[T any]() int64 { + var v T + return int64(unsafe.Sizeof(v)) +} + +func maxGraphCount(elemBytes int64) int64 { + if elemBytes == 0 { + return MaxInt64 + } + return MaxInt64 / elemBytes +} + +func structGraphBytes(type_ reflect.Type) int64 { + if type_.Kind() != reflect.Struct { + return 0 + } + bytes := int64(type_.Size()) + 
if bytes == 0 { + return 1 + } + return bytes } // IsXlang returns whether cross-language serialization mode is enabled @@ -67,6 +95,8 @@ func (c *ReadContext) Reset() { c.outOfBandBuffers = nil c.outOfBandIndex = 0 c.err = Error{} // Clear error state + // Graph budget state is overwritten by each root read before deserialization. + // Avoid extra reset stores on the successful root hot path. if c.refResolver != nil { c.refResolver.resetRead() } @@ -75,6 +105,27 @@ func (c *ReadContext) Reset() { } } +// ReserveGraphMemory reserves raw estimated graph-owner bytes. +func (c *ReadContext) ReserveGraphMemory(bytes int64) bool { + if uint64(bytes) <= uint64(c.remainingGraphMemoryBytes) { + c.remainingGraphMemoryBytes -= bytes + return true + } + return c.rejectGraphMemoryReservation(bytes) +} + +//go:noinline +func (c *ReadContext) rejectGraphMemoryReservation(bytes int64) bool { + if bytes < 0 { + c.SetError(DeserializationErrorf("estimated graph memory must be non-negative, got %d bytes", bytes)) + return false + } + c.SetError(DeserializationErrorf( + "estimated graph memory request %d bytes exceeds maxGraphMemoryBytes remaining budget %d bytes", + bytes, c.remainingGraphMemoryBytes)) + return false +} + // SetData sets new input data (for buffer reuse) // Reuses existing buffer to avoid allocation func (c *ReadContext) SetData(data []byte) { @@ -536,7 +587,45 @@ func (c *ReadContext) ReadStringSlice(refMode RefMode, readType bool) []string { if readType { _ = c.buffer.ReadUint8(err) } - return ReadStringSlice(c.buffer, err) + buf := c.buffer + length := buf.ReadLength(err) + if c.HasError() { + return nil + } + if length < 0 { + c.SetError(DeserializationErrorf("negative graph element count: %d", length)) + return nil + } + if int64(length) > stringMaxLength { + c.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, stringElementBytes)) + return nil + } + if !c.ReserveGraphMemory(int64(length) * stringElementBytes) { 
+ return nil + } + if length == 0 { + return make([]string, 0) + } + collectFlag := buf.ReadInt8(err) + if (collectFlag&CollectionIsSameType) != 0 && (collectFlag&CollectionIsDeclElementType) == 0 { + _ = buf.ReadUint8(err) + } + if c.HasError() || !buf.CheckReadable(length, err) { + return nil + } + result := make([]string, length) + trackRefs := (collectFlag & CollectionTrackingRef) != 0 + hasNull := (collectFlag & CollectionHasNull) != 0 + for i := 0; i < length; i++ { + if trackRefs || hasNull { + rf := buf.ReadInt8(err) + if rf == NullFlag { + continue + } + } + result[i] = readString(buf, err) + } + return result } // ReadStringStringMap reads map[string]string with optional ref/type info @@ -720,15 +809,15 @@ func (c *ReadContext) ReadValue(value reflect.Value, refMode RefMode, readType b c.SetError(DeserializationError("invalid reflect.Value")) return } - + valueType := value.Type() // Handle array targets (arrays are serialized as slices) - if value.Type().Kind() == reflect.Array { + if valueType.Kind() == reflect.Array { c.ReadArrayValue(value, refMode, readType) return } // For any types, we need to read the actual type from the buffer first - if value.Type().Kind() == reflect.Interface { + if valueType.Kind() == reflect.Interface { // Handle ref tracking based on refMode var refID int32 = int32(NotNullValueFlag) if refMode == RefModeTracking { @@ -795,9 +884,15 @@ func (c *ReadContext) ReadValue(value reflect.Value, refMode RefMode, readType b } else if isNamedStruct { // For named struct types, create a pointer to support circular references // Create *A instead of A + if !c.ReserveGraphMemory(structGraphBytes(actualType)) { + return + } newValue = reflect.New(actualType) valueToSet = newValue } else { + if !c.ReserveGraphMemory(structGraphBytes(actualType)) { + return + } newValue = reflect.New(actualType).Elem() valueToSet = newValue } @@ -831,14 +926,13 @@ func (c *ReadContext) ReadValue(value reflect.Value, refMode RefMode, readType b return } - if 
typeInfo := c.getTypeInfoByType(value.Type()); typeInfo != nil && typeInfo.Serializer != nil { + if typeInfo := c.getTypeInfoByType(valueType); typeInfo != nil && typeInfo.Serializer != nil { typeInfo.Serializer.Read(c, refMode, readType, false, value) return } // For struct types, use optimized ReadStruct path when using full ref tracking and type info. // Unions use a custom serializer and must bypass ReadStruct. - valueType := value.Type() if refMode == RefModeTracking && readType && !c.typeResolver.IsUnionType(valueType) { if valueType.Kind() == reflect.Struct { c.ReadStruct(value) @@ -930,12 +1024,18 @@ func (c *ReadContext) ReadStruct(value reflect.Value) { var readTarget reflect.Value if isPtr { if value.IsNil() { + if !c.ReserveGraphMemory(structGraphBytes(structType)) { + return + } value.Set(reflect.New(structType)) } readTarget = value.Elem() // Register reference before reading (for circular references) refResolver.SetReadObject(refID, value) } else { + if !c.ReserveGraphMemory(structGraphBytes(structType)) { + return + } readTarget = value // For non-pointer structs, register a pointer to enable circular ref resolution if value.CanAddr() { diff --git a/go/fory/set.go b/go/fory/set.go index 1a42739547..7ff903fcf2 100644 --- a/go/fory/set.go +++ b/go/fory/set.go @@ -318,6 +318,13 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { return } if length == 0 { + keyBytes := int64(type_.Key().Size()) + valueBytes := int64(type_.Elem().Size()) + elemBytes := keyBytes + valueBytes + if elemBytes < keyBytes { + ctx.SetError(DeserializationErrorf("map entry size overflows: key=%d value=%d", keyBytes, valueBytes)) + return + } // Initialize empty set if length is 0 value.Set(reflect.MakeMap(type_)) return @@ -356,6 +363,24 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { if !buf.CheckReadable(length, err) { return } + keyBytes := int64(type_.Key().Size()) + valueBytes := int64(type_.Elem().Size()) + elemBytes := 
keyBytes + valueBytes + if elemBytes < keyBytes { + ctx.SetError(DeserializationErrorf("map entry size overflows: key=%d value=%d", keyBytes, valueBytes)) + return + } + if length < 0 { + ctx.SetError(DeserializationErrorf("negative graph element count: %d", length)) + return + } + if int64(length) > maxGraphCount(elemBytes) { + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, elemBytes)) + return + } + if !ctx.ReserveGraphMemory(int64(length) * elemBytes) { + return + } // Initialize set if nil if value.IsNil() { diff --git a/go/fory/slice.go b/go/fory/slice.go index 6d941b3bf6..6a59df49a1 100644 --- a/go/fory/slice.go +++ b/go/fory/slice.go @@ -124,6 +124,8 @@ type sliceSerializer struct { type_ reflect.Type elemSerializer Serializer referencable bool + elemBytes int64 + maxLength int64 } // newSliceSerializer creates a sliceSerializer for slices with concrete element types. @@ -144,10 +146,13 @@ func newSliceSerializer(type_ reflect.Type, elemSerializer Serializer, xlang boo reflect.Uint8, reflect.Float32, reflect.Float64: return nil, fmt.Errorf("sliceSerializer does not support primitive element type %v: use dedicated primitive slice serializer", type_) } + elemBytes := int64(elem.Size()) return &sliceSerializer{ type_: type_, elemSerializer: elemSerializer, referencable: isRefType(elem, xlang), + elemBytes: elemBytes, + maxLength: maxGraphCount(elemBytes), }, nil } @@ -314,6 +319,19 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } return } + if !isArrayType { + if length < 0 { + ctx.SetError(DeserializationErrorf("negative graph element count: %d", length)) + return + } + if int64(length) > s.maxLength { + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes)) + return + } + if !ctx.ReserveGraphMemory(int64(length) * s.elemBytes) { + return + } + } // ReadData collection flags collectFlag := 
buf.ReadInt8(ctxErr) diff --git a/go/fory/slice_dyn.go b/go/fory/slice_dyn.go index 907fcddd4f..076ada5071 100644 --- a/go/fory/slice_dyn.go +++ b/go/fory/slice_dyn.go @@ -31,35 +31,43 @@ type sliceDynSerializer struct { elemType reflect.Type isInterfaceElem bool isPointerElem bool + elemBytes int64 + maxLength int64 } // newSliceDynSerializer creates a new sliceDynSerializer. // This serializer is ONLY for slices with interface or pointer to interface element types. // For other slice types, use sliceSerializer instead. -func newSliceDynSerializer(elemType reflect.Type) (sliceDynSerializer, error) { +func newSliceDynSerializer(elemType reflect.Type) (*sliceDynSerializer, error) { // Nil element type is allowed for fully dynamic slices (e.g., []any) if elemType == nil { - return sliceDynSerializer{ + elemBytes := graphSizeOf[any]() + return &sliceDynSerializer{ isInterfaceElem: true, + elemBytes: elemBytes, + maxLength: maxGraphCount(elemBytes), }, nil } // Validate element type is interface or pointer to interface isInterface := elemType.Kind() == reflect.Interface isPointerToInterface := elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Interface if !isInterface && !isPointerToInterface { - return sliceDynSerializer{}, fmt.Errorf( + return nil, fmt.Errorf( "sliceDynSerializer only supports interface or pointer to interface element types, got %v; use sliceSerializer for other types", elemType) } - return sliceDynSerializer{ + elemBytes := int64(elemType.Size()) + return &sliceDynSerializer{ elemType: elemType, isInterfaceElem: isInterface, isPointerElem: isPointerToInterface, + elemBytes: elemBytes, + maxLength: maxGraphCount(elemBytes), }, nil } // mustNewSliceDynSerializer is like newSliceDynSerializer but panics on error. // Used for initialization code where the element type is known to be valid. 
-func mustNewSliceDynSerializer(elemType reflect.Type) sliceDynSerializer {
+func mustNewSliceDynSerializer(elemType reflect.Type) *sliceDynSerializer {
 	s, err := newSliceDynSerializer(elemType)
 	if err != nil {
 		panic(err)
@@ -67,7 +75,7 @@ func mustNewSliceDynSerializer(elemType reflect.Type) sliceDynSerializer {
 	return s
 }
 
-func (s sliceDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) {
+func (s *sliceDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) {
 	done := writeSliceRefAndType(ctx, refMode, writeType, value, LIST)
 	if done || ctx.HasError() {
 		return
@@ -75,7 +83,7 @@ func (s sliceDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType
 	s.WriteData(ctx, value)
 }
 
-func (s sliceDynSerializer) WriteData(ctx *WriteContext, value reflect.Value) {
+func (s *sliceDynSerializer) WriteData(ctx *WriteContext, value reflect.Value) {
 	buf := ctx.Buffer()
 	// Get slice length and handle empty slice case
 	length := value.Len()
@@ -103,7 +111,7 @@ func (s sliceDynSerializer) WriteData(ctx *WriteContext, value reflect.Value) {
 // - Type consistency flags
 // - Element type information (if homogeneous)
 // Returns pointer to TypeInfo to avoid copy overhead.
-func (s sliceDynSerializer) writeHeader(ctx *WriteContext, buf *ByteBuffer, value reflect.Value) (byte, *TypeInfo) { +func (s *sliceDynSerializer) writeHeader(ctx *WriteContext, buf *ByteBuffer, value reflect.Value) (byte, *TypeInfo) { collectFlag := CollectionDefaultFlag var elemTypeInfo *TypeInfo hasNull := false @@ -161,7 +169,7 @@ func (s sliceDynSerializer) writeHeader(ctx *WriteContext, buf *ByteBuffer, valu } // writeSameType efficiently serializes a slice where all elements share the same type -func (s sliceDynSerializer) writeSameType( +func (s *sliceDynSerializer) writeSameType( ctx *WriteContext, buf *ByteBuffer, value reflect.Value, typeInfo *TypeInfo, flag byte) { if typeInfo == nil { return @@ -194,7 +202,7 @@ func (s sliceDynSerializer) writeSameType( } // writeDifferentTypes handles serialization of slices with mixed element types -func (s sliceDynSerializer) writeDifferentTypes(ctx *WriteContext, buf *ByteBuffer, value reflect.Value, flag byte) { +func (s *sliceDynSerializer) writeDifferentTypes(ctx *WriteContext, buf *ByteBuffer, value reflect.Value, flag byte) { trackRefs := (flag & CollectionTrackingRef) != 0 hasNull := (flag & CollectionHasNull) != 0 @@ -246,7 +254,7 @@ func (s sliceDynSerializer) writeDifferentTypes(ctx *WriteContext, buf *ByteBuff } } -func (s sliceDynSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { +func (s *sliceDynSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { done, typeId := readSliceRefAndType(ctx, refMode, readType, value) if done || ctx.HasError() { return @@ -258,11 +266,11 @@ func (s sliceDynSerializer) Read(ctx *ReadContext, refMode RefMode, readType boo s.ReadData(ctx, value) } -func (s sliceDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { +func (s *sliceDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { s.readData(ctx, value, -1) } -func (s sliceDynSerializer) 
readData(ctx *ReadContext, value reflect.Value, expectedLength int) { +func (s *sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expectedLength int) { buf := ctx.Buffer() ctxErr := ctx.Err() length := ctx.ReadCollectionLength() @@ -274,10 +282,24 @@ func (s sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expe ctx.SetError(DeserializationErrorf("array length %d does not match serialized length %d", expectedLength, length)) return } + allocatedByCaller := expectedLength >= 0 if length == 0 { value.Set(reflect.MakeSlice(sliceType, 0, 0)) return } + if !allocatedByCaller { + if length < 0 { + ctx.SetError(DeserializationErrorf("negative graph element count: %d", length)) + return + } + if int64(length) > s.maxLength { + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes)) + return + } + if !ctx.ReserveGraphMemory(int64(length) * s.elemBytes) { + return + } + } collectFlag := buf.ReadInt8(ctxErr) if ctx.HasError() { @@ -305,7 +327,9 @@ func (s sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expe if !buf.CheckReadable(length, ctxErr) { return } - value.Set(reflect.MakeSlice(sliceType, length, length)) + if !allocatedByCaller { + value.Set(reflect.MakeSlice(sliceType, length, length)) + } ctx.RefResolver().Reference(value) s.readSameType(ctx, buf, value, elemType, elemSerializer, collectFlag, length) return @@ -313,18 +337,20 @@ func (s sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expe if !buf.CheckReadable(length, ctxErr) { return } - value.Set(reflect.MakeSlice(sliceType, length, length)) + if !allocatedByCaller { + value.Set(reflect.MakeSlice(sliceType, length, length)) + } ctx.RefResolver().Reference(value) s.readDifferentTypes(ctx, buf, value, collectFlag, length) } -func (s sliceDynSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { +func (s *sliceDynSerializer) 
ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { // typeInfo is already read, don't read it again s.Read(ctx, refMode, false, false, value) } // readSameType handles deserialization of slices where all elements share the same type -func (s sliceDynSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, value reflect.Value, elemType reflect.Type, serializer Serializer, flag int8, length int) { +func (s *sliceDynSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, value reflect.Value, elemType reflect.Type, serializer Serializer, flag int8, length int) { trackRefs := (flag & CollectionTrackingRef) != 0 hasNull := (flag & CollectionHasNull) != 0 ctxErr := ctx.Err() @@ -402,7 +428,7 @@ func (s sliceDynSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, valu } // readDifferentTypes handles deserialization of slices with mixed element types -func (s sliceDynSerializer) readDifferentTypes( +func (s *sliceDynSerializer) readDifferentTypes( ctx *ReadContext, buf *ByteBuffer, value reflect.Value, flag int8, length int) { trackRefs := (flag & CollectionTrackingRef) != 0 hasNull := (flag & CollectionHasNull) != 0 @@ -464,7 +490,7 @@ func (s sliceDynSerializer) readDifferentTypes( // 1. Slice element type is pointer-to-interface and the deserialized type is not a pointer, OR // 2. 
Slice element type is interface and the deserialized type doesn't directly implement it // but the pointer type does (common case where interface has pointer receivers) -func (s sliceDynSerializer) wrapSerializerIfNeeded(elemType reflect.Type, serializer Serializer) (reflect.Type, Serializer) { +func (s *sliceDynSerializer) wrapSerializerIfNeeded(elemType reflect.Type, serializer Serializer) (reflect.Type, Serializer) { if elemType.Kind() == reflect.Ptr { return elemType, serializer } diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index 9b92691ac8..6fc371075a 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -652,6 +652,17 @@ func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { return } ptr := (*[]string)(value.Addr().UnsafePointer()) + if length < 0 { + ctx.SetError(DeserializationErrorf("negative graph element count: %d", length)) + return + } + if int64(length) > stringMaxLength { + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, stringElementBytes)) + return + } + if !ctx.ReserveGraphMemory(int64(length) * stringElementBytes) { + return + } if length == 0 { *ptr = make([]string, 0) return diff --git a/go/fory/slice_primitive_list.go b/go/fory/slice_primitive_list.go index 0335b2a08e..55bb5d5f23 100644 --- a/go/fory/slice_primitive_list.go +++ b/go/fory/slice_primitive_list.go @@ -25,6 +25,18 @@ import ( type primitiveListSerializer struct { type_ reflect.Type elemTypeID TypeId + elemBytes int64 + maxLength int64 +} + +func newPrimitiveList(type_ reflect.Type, elemTypeID TypeId, elemType reflect.Type) primitiveListSerializer { + elemBytes := int64(elemType.Size()) + return primitiveListSerializer{ + type_: type_, + elemTypeID: elemTypeID, + elemBytes: elemBytes, + maxLength: maxGraphCount(elemBytes), + } } type compatiblePrimitiveListToArraySerializer struct { @@ -39,43 +51,43 @@ func newPrimitiveListSerializer(type_ reflect.Type, 
elemTypeID TypeId) (Serializ elemType := type_.Elem() switch elemType.Kind() { case reflect.Bool: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == BOOL + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == BOOL case reflect.Int8: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT8 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT8 case reflect.Uint8: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT8 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT8 case reflect.Int16: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT16 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT16 case reflect.Uint16: if elemType == float16Type { - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == FLOAT16 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == FLOAT16 } if elemType == bfloat16Type { - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == BFLOAT16 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == BFLOAT16 } - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT16 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT16 case reflect.Int32: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT32 || elemTypeID == VARINT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT32 || elemTypeID == VARINT32 case reflect.Uint32: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT32 || elemTypeID == VAR_UINT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT32 || elemTypeID == VAR_UINT32 case reflect.Int64: - return primitiveListSerializer{type_: type_, elemTypeID: 
elemTypeID}, elemTypeID == INT64 || elemTypeID == VARINT64 || elemTypeID == TAGGED_INT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT64 || elemTypeID == VARINT64 || elemTypeID == TAGGED_INT64 case reflect.Uint64: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT64 || elemTypeID == VAR_UINT64 || elemTypeID == TAGGED_UINT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT64 || elemTypeID == VAR_UINT64 || elemTypeID == TAGGED_UINT64 case reflect.Int: if reflect.TypeOf(int(0)).Size() == 8 { - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT64 || elemTypeID == VARINT64 || elemTypeID == TAGGED_INT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT64 || elemTypeID == VARINT64 || elemTypeID == TAGGED_INT64 } - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT32 || elemTypeID == VARINT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT32 || elemTypeID == VARINT32 case reflect.Uint: if reflect.TypeOf(uint(0)).Size() == 8 { - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT64 || elemTypeID == VAR_UINT64 || elemTypeID == TAGGED_UINT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT64 || elemTypeID == VAR_UINT64 || elemTypeID == TAGGED_UINT64 } - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT32 || elemTypeID == VAR_UINT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT32 || elemTypeID == VAR_UINT32 case reflect.Float32: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == FLOAT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == FLOAT32 case reflect.Float64: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == FLOAT64 + return 
newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == FLOAT64 default: return nil, false } @@ -167,6 +179,17 @@ func (s primitiveListSerializer) ReadData(ctx *ReadContext, value reflect.Value) if ctx.HasError() { return } + if length < 0 { + ctx.SetError(DeserializationErrorf("negative graph element count: %d", length)) + return + } + if int64(length) > s.maxLength { + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes)) + return + } + if !ctx.ReserveGraphMemory(int64(length) * s.elemBytes) { + return + } if length == 0 { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) return @@ -266,6 +289,17 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val return } if value.Kind() == reflect.Slice { + if length < 0 { + ctx.SetError(DeserializationErrorf("negative graph element count: %d", length)) + return + } + if int64(length) > s.listReader.maxLength { + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, s.listReader.elemBytes)) + return + } + if !ctx.ReserveGraphMemory(int64(length) * s.listReader.elemBytes) { + return + } temp := reflect.New(value.Type()).Elem() s.listReader.readValues(buf, err, temp, length, false) if ctx.HasError() { diff --git a/go/fory/stream.go b/go/fory/stream.go index bb86689598..6c6c91e306 100644 --- a/go/fory/stream.go +++ b/go/fory/stream.go @@ -96,6 +96,8 @@ func (is *InputStream) Shrink() { func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = is.buffer + target := reflect.ValueOf(v).Elem() + f.readCtx.remainingGraphMemoryBytes = f.config.MaxGraphMemoryBytes defer func() { f.readCtx.buffer = origBuffer f.resetReadState() @@ -106,8 +108,7 @@ func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { return f.readCtx.TakeError() } - target := reflect.ValueOf(v).Elem() - f.readCtx.ReadValue(target, RefModeTracking, 
true) + f.readRootValue(target) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -123,14 +124,15 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { defer f.resetReadState() // Always reset to enforce stateless semantics. f.readCtx.buffer.ResetWithReader(r, 0) + target := reflect.ValueOf(v).Elem() + f.readCtx.remainingGraphMemoryBytes = f.config.MaxGraphMemoryBytes readHeader(f.readCtx) if f.readCtx.HasError() { return f.readCtx.TakeError() } - target := reflect.ValueOf(v).Elem() - f.readCtx.ReadValue(target, RefModeTracking, true) + f.readRootValue(target) if f.readCtx.HasError() { return f.readCtx.TakeError() } diff --git a/go/fory/struct.go b/go/fory/struct.go index 890ca21082..39fc7ce43e 100644 --- a/go/fory/struct.go +++ b/go/fory/struct.go @@ -32,6 +32,7 @@ type structSerializer struct { structHash int32 typeID uint32 userTypeID uint32 + graphBytes int64 // Pre-sorted and categorized fields (embedded for cache locality) fieldGroup FieldGroup @@ -62,6 +63,7 @@ func newStructSerializerFromTypeDef(type_ reflect.Type, name string, fieldDefs [ type_: type_, name: name, userTypeID: invalidUserTypeID, + graphBytes: structGraphBytes(type_), fieldDefs: fieldDefs, } } @@ -77,6 +79,7 @@ func newStructSerializer(type_ reflect.Type, name string) *structSerializer { type_: type_, name: name, userTypeID: invalidUserTypeID, + graphBytes: structGraphBytes(type_), } } @@ -1367,6 +1370,53 @@ func (s *structSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool s.ReadData(ctx, value) } +func (s *structSerializer) readRoot(ctx *ReadContext, value reflect.Value) { + buf := ctx.buffer + ctxErr := ctx.Err() + if ctx.refResolver.refTracking { + refID, refErr := ctx.refResolver.TryPreserveRefId(buf) + if refErr != nil { + ctx.SetError(FromError(refErr)) + return + } + if refID < int32(NotNullValueFlag) { + obj := ctx.refResolver.GetReadObject(refID) + if obj.IsValid() { + value.Set(obj) + } + return + } + } else { + // No-ref roots only need the 
marker byte; avoid the tracking helper on this hot path. + refFlag := buf.ReadInt8(ctxErr) + if refFlag == NullFlag { + return + } + if refFlag == RefFlag { + buf.ReadVarUint32(ctxErr) + return + } + } + if !ctx.ReserveGraphMemory(s.graphBytes) { + return + } + if s.type_ != nil { + serializer := ctx.typeResolver.ReadTypeInfoForType(buf, s.type_, ctxErr) + if ctxErr.HasError() { + return + } + if serializer == nil { + ctx.SetError(DeserializationError("unexpected type id for struct")) + return + } + if structSer, ok := serializer.(*structSerializer); ok && len(structSer.fieldDefs) > 0 { + structSer.ReadData(ctx, value) + return + } + } + s.ReadData(ctx, value) +} + func (s *structSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { // typeInfo is already read, don't read it again s.Read(ctx, refMode, false, false, value) diff --git a/go/fory/struct_test.go b/go/fory/struct_test.go index ac5d8ad6f0..262a404135 100644 --- a/go/fory/struct_test.go +++ b/go/fory/struct_test.go @@ -607,13 +607,13 @@ func TestSkipAnyValueReadsSharedTypeMeta(t *testing.T) { readHeader(f.readCtx) SkipAnyValue(f.readCtx, true) require.NoError(t, f.readCtx.CheckError()) + nextRootIndex := f.readCtx.Buffer().ReaderIndex() - f.resetReadState() - readHeader(f.readCtx) + rootBuffer := NewByteBuffer(buf.Bytes()) + rootBuffer.SetReaderIndex(nextRootIndex) var out any - f.readCtx.ReadValue(reflect.ValueOf(&out).Elem(), RefModeTracking, true) - require.NoError(t, f.readCtx.CheckError()) + require.NoError(t, f.DeserializeFrom(rootBuffer, &out)) result, ok := out.(*Second) require.True(t, ok) diff --git a/go/fory/tests/structs_fory_gen.go b/go/fory/tests/structs_fory_gen.go index 742135a8ba..2fb5ad9623 100644 --- a/go/fory/tests/structs_fory_gen.go +++ b/go/fory/tests/structs_fory_gen.go @@ -1,12 +1,13 @@ // Code generated by forygen. DO NOT EDIT. 
// source: structs.go -// generated at: 2026-06-12T06:41:26+08:00 +// generated at: 2026-07-01T02:06:03+08:00 package fory import ( "github.com/apache/fory/go/fory" "reflect" + "unsafe" ) func init() { @@ -189,6 +190,25 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if ctx.HasError() { return ctx.TakeError() } + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(any)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if sliceLen == 0 { v.DynamicSlice = make([]any, 0) } else { @@ -217,6 +237,25 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if ctx.HasError() { return ctx.TakeError() } + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(any)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { 
+ return ctx.TakeError() + } + } if sliceLen == 0 { v.DynamicSlice = make([]any, 0) } else { @@ -253,7 +292,13 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadData(ctx *fory.ReadContext, val var v *DynamicSliceDemo if value.Kind() == reflect.Ptr { if value.IsNil() { - // For pointer types, allocate using value.Type().Elem() + graphBytes := int64(unsafe.Sizeof(DynamicSliceDemo{})) + if graphBytes == 0 { + graphBytes = 1 + } + if !ctx.ReserveGraphMemory(graphBytes) { + return + } value.Set(reflect.New(value.Type().Elem())) } v = value.Interface().(*DynamicSliceDemo) @@ -662,6 +707,28 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if isXlang { // xlang mode: maps are not nullable, read directly without null flag mapLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + { + graphCount := mapLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(int))) + int64(unsafe.Sizeof(*new(int)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if mapLen == 0 { v.IntMap = make(map[int]int) } else { @@ -709,6 +776,25 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } + { + graphCount := mapLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(int))) + int64(unsafe.Sizeof(*new(int)))) + if graphCount < 0 { + 
ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if mapLen == 0 { v.IntMap = make(map[int]int) } else { @@ -755,6 +841,28 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if isXlang { // xlang mode: maps are not nullable, read directly without null flag mapLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + { + graphCount := mapLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(string))) + int64(unsafe.Sizeof(*new(int)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if mapLen == 0 { v.MixedMap = make(map[string]int) } else { @@ -802,6 +910,25 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } + { + graphCount := mapLen + graphElemBytes := 
int64(int64(unsafe.Sizeof(*new(string))) + int64(unsafe.Sizeof(*new(int)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if mapLen == 0 { v.MixedMap = make(map[string]int) } else { @@ -848,6 +975,28 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if isXlang { // xlang mode: maps are not nullable, read directly without null flag mapLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + { + graphCount := mapLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(string))) + int64(unsafe.Sizeof(*new(string)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if mapLen == 0 { v.StringMap = make(map[string]string) } else { @@ -895,6 +1044,25 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v 
*MapDemo) if ctx.HasError() { return ctx.TakeError() } + { + graphCount := mapLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(string))) + int64(unsafe.Sizeof(*new(string)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if mapLen == 0 { v.StringMap = make(map[string]string) } else { @@ -950,7 +1118,13 @@ func (g *MapDemo_ForyGenSerializer) ReadData(ctx *fory.ReadContext, value reflec var v *MapDemo if value.Kind() == reflect.Ptr { if value.IsNil() { - // For pointer types, allocate using value.Type().Elem() + graphBytes := int64(unsafe.Sizeof(MapDemo{})) + if graphBytes == 0 { + graphBytes = 1 + } + if !ctx.ReserveGraphMemory(graphBytes) { + return + } value.Set(reflect.New(value.Type().Elem())) } v = value.Interface().(*MapDemo) @@ -1250,6 +1424,28 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if isXlang { // xlang mode: slices are not nullable, read directly without null flag sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(bool)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return 
ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if sliceLen == 0 { v.BoolSlice = make([]bool, 0) } else { @@ -1289,6 +1485,25 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(bool)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if sliceLen == 0 { v.BoolSlice = make([]bool, 0) } else { @@ -1327,6 +1542,28 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if isXlang { // xlang mode: slices are not nullable, read directly without null flag sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(float64)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: 
%d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if sliceLen == 0 { v.FloatSlice = make([]float64, 0) } else { @@ -1366,6 +1603,25 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(float64)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if sliceLen == 0 { v.FloatSlice = make([]float64, 0) } else { @@ -1404,6 +1660,28 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if isXlang { // xlang mode: slices are not nullable, read directly without null flag sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(int32)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + 
ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if sliceLen == 0 { v.IntSlice = make([]int32, 0) } else { @@ -1443,6 +1721,25 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(int32)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if sliceLen == 0 { v.IntSlice = make([]int32, 0) } else { @@ -1481,6 +1778,28 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if isXlang { // xlang mode: slices are not nullable, read directly without null flag sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(string)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } 
+ if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if sliceLen == 0 { v.StringSlice = make([]string, 0) } else { @@ -1528,6 +1847,25 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(string)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } + } if sliceLen == 0 { v.StringSlice = make([]string, 0) } else { @@ -1583,7 +1921,13 @@ func (g *SliceDemo_ForyGenSerializer) ReadData(ctx *fory.ReadContext, value refl var v *SliceDemo if value.Kind() == reflect.Ptr { if value.IsNil() { - // For pointer types, allocate using value.Type().Elem() + graphBytes := int64(unsafe.Sizeof(SliceDemo{})) + if graphBytes == 0 { + graphBytes = 1 + } + if !ctx.ReserveGraphMemory(graphBytes) { + return + } value.Set(reflect.New(value.Type().Elem())) } v = value.Interface().(*SliceDemo) @@ -1760,7 
+2104,13 @@ func (g *ValidationDemo_ForyGenSerializer) ReadData(ctx *fory.ReadContext, value var v *ValidationDemo if value.Kind() == reflect.Ptr { if value.IsNil() { - // For pointer types, allocate using value.Type().Elem() + graphBytes := int64(unsafe.Sizeof(ValidationDemo{})) + if graphBytes == 0 { + graphBytes = 1 + } + if !ctx.ReserveGraphMemory(graphBytes) { + return + } value.Set(reflect.New(value.Type().Elem())) } v = value.Interface().(*ValidationDemo) diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index f0966b0596..ba4d3e5dc9 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -415,7 +415,7 @@ func (r *TypeResolver) initialize() { {stringPtrType, STRING, ptrToStringSerializer{}}, // Register interface types first so typeIDToTypeInfo maps to generic types // that can hold any element type when deserializing into any - {interfaceSliceType, LIST, sliceDynSerializer{}}, + {interfaceSliceType, LIST, mustNewSliceDynSerializer(interfaceType)}, {interfaceMapType, MAP, mapSerializer{type_: interfaceMapType, keyReferencable: true, valueReferencable: true}}, // stringSliceType uses dedicated stringSliceSerializer for optimized serialization // This ensures CollectionIsDeclElementType is set for Java compatibility @@ -1779,7 +1779,7 @@ func (r *TypeResolver) createSerializer(type_ reflect.Type, mapInStruct bool) (s } // For dynamic types, use dynamic slice serializer if isDynamicType(elem) { - return sliceDynSerializer{}, nil + return newSliceDynSerializer(elem) } else { elemSerializer, err := r.getSerializerByType(type_.Elem(), false) if err != nil { diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java index d0ccf91f16..f8348c94db 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java +++ 
b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java @@ -64,6 +64,8 @@ "org.apache.fory.annotation.ForyDebug" }) public final class ForyStructProcessor extends AbstractProcessor { + private static final int OBJECT_SELF_BYTES = 1; + private static final int REFERENCE_BYTES = 4; private static final String ARRAY_TYPE = "org.apache.fory.annotation.ArrayType"; private static final String BFLOAT16_TYPE = "org.apache.fory.annotation.BFloat16Type"; private static final String EXPOSE = "org.apache.fory.annotation.Expose"; @@ -243,10 +245,46 @@ boolean record = isRecord(type); serializerName, record, isForyDebugEnabled(type), + graphMemoryBytes(type), sourceFields, recordConstructorFields); } + private int graphMemoryBytes(TypeElement type) { + int bytes = OBJECT_SELF_BYTES; + for (TypeElement current : hierarchy(type)) { + for (VariableElement field : ElementFilter.fieldsIn(current.getEnclosedElements())) { + if (!field.getModifiers().contains(Modifier.STATIC)) { + bytes = Math.addExact(bytes, fieldGraphMemoryBytes(field.asType())); + } + } + } + return bytes; + } + + private int fieldGraphMemoryBytes(TypeMirror type) { + TypeKind kind = type.getKind(); + if (!kind.isPrimitive()) { + return REFERENCE_BYTES; + } + switch (kind) { + case BOOLEAN: + case BYTE: + return 1; + case CHAR: + case SHORT: + return 2; + case INT: + case FLOAT: + return 4; + case LONG: + case DOUBLE: + return 8; + default: + return 0; + } + } + private boolean isForyDebugEnabled(TypeElement type) { return annotationMirror(type, FORY_DEBUG) != null; } diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceStruct.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceStruct.java index 859ec8b2c2..d451d5220c 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceStruct.java +++ 
b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceStruct.java @@ -31,6 +31,7 @@ final class SourceStruct { final boolean record; final boolean debug; final boolean hasNestedCompatibleStructFields; + final int graphMemoryBytes; final List fields; final List recordConstructorFields; @@ -41,6 +42,7 @@ final class SourceStruct { String serializerName, boolean record, boolean debug, + int graphMemoryBytes, List fields, List recordConstructorFields) { this.packageName = packageName; @@ -49,6 +51,7 @@ final class SourceStruct { this.serializerName = serializerName; this.record = record; this.debug = debug; + this.graphMemoryBytes = graphMemoryBytes; this.fields = Collections.unmodifiableList(new ArrayList<>(fields)); this.recordConstructorFields = Collections.unmodifiableList(new ArrayList<>(recordConstructorFields)); diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java index 4d29797710..fb61507c63 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java +++ b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java @@ -247,6 +247,7 @@ private void writeSchemaConsistentRead() { .append(struct.typeName) .append(" readSchemaConsistent(ReadContext readContext) {\n"); builder.append(" MemoryBuffer buffer = readContext.getBuffer();\n"); + appendGraphMemoryReserve(); builder.append(" if (typeResolver.checkClassVersion()) {\n"); builder.append(" checkClassVersion(buffer.readInt32(), classVersionHash);\n"); builder.append(" }\n"); @@ -794,6 +795,13 @@ private String exactPrimitiveTypeId(SourceField field) { return meta.substring(start + prefix.length(), end); } + private void appendGraphMemoryReserve() { + builder + 
.append(" readContext.reserveGraphMemory(") + .append(struct.graphMemoryBytes) + .append(");\n"); + } + private void writeCompatibleRead() { builder.append(" @Override\n"); builder @@ -803,6 +811,7 @@ private void writeCompatibleRead() { builder.append(" if (sameSchemaCompatible) {\n"); builder.append(" return readSchemaConsistent(readContext);\n"); builder.append(" }\n"); + appendGraphMemoryReserve(); if (struct.record) { for (SourceField field : struct.fields) { builder diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java index b467361ac0..7bb1e213a7 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java @@ -39,6 +39,8 @@ import static org.apache.fory.type.TypeUtils.SHORT_TYPE; import static org.apache.fory.type.TypeUtils.getRawType; +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -65,6 +67,7 @@ import org.apache.fory.logging.LoggerFactory; import org.apache.fory.meta.TypeDef; import org.apache.fory.platform.JdkVersion; +import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.reflect.TypeRef; import org.apache.fory.serializer.ObjectSerializer; import org.apache.fory.type.BFloat16; @@ -94,6 +97,8 @@ */ public class ObjectCodecBuilder extends BaseObjectCodecBuilder { private static final Logger LOG = LoggerFactory.getLogger(ObjectCodecBuilder.class); + private static final int OBJECT_SELF_BYTES = 1; + private static final int REFERENCE_BYTES = 4; private final Literal classVersionHash; protected ObjectCodecOptimizer objectCodecOptimizer; @@ -793,6 +798,7 @@ private Expression getWriterPos(Expression writerPos, long acc) { public Expression buildDecodeExpression() { Reference buffer = new Reference(BUFFER_NAME, 
bufferTypeRef, false); ListExpression expressions = new ListExpression(); + expressions.add(new Expression.Block(graphMemoryReserveCode())); if (typeResolver.checkClassVersion()) { expressions.add(checkClassVersion(buffer)); } @@ -832,6 +838,39 @@ public Expression buildDecodeExpression() { return expressions; } + protected String graphMemoryReserveCode() { + return READ_CONTEXT_NAME + ".reserveGraphMemory(" + objectGraphMemoryBytes() + ");"; + } + + private int objectGraphMemoryBytes() { + int bytes = OBJECT_SELF_BYTES; + for (Field field : ReflectionUtils.getFields(beanClass, true)) { + if (!Modifier.isStatic(field.getModifiers())) { + bytes = Math.addExact(bytes, fieldGraphMemoryBytes(field.getType())); + } + } + return bytes; + } + + private int fieldGraphMemoryBytes(Class fieldType) { + if (!fieldType.isPrimitive()) { + return REFERENCE_BYTES; + } + if (fieldType == boolean.class || fieldType == byte.class) { + return 1; + } + if (fieldType == char.class || fieldType == short.class) { + return 2; + } + if (fieldType == int.class || fieldType == float.class) { + return 4; + } + if (fieldType == long.class || fieldType == double.class) { + return 8; + } + return 0; + } + protected void deserializeReadGroup( List> readGroups, int numGroups, diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java index 4a43269771..9b699d4cdb 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java @@ -160,6 +160,7 @@ private String genObjectCompatibleRead() { ? 
"((" + ctx.type(beanClass) + ") " + beanCode.value() + ")" : beanCode.value().toString(); StringBuilder code = new StringBuilder(); + code.append(graphMemoryReserveCode()).append('\n'); if (StringUtils.isNotBlank(beanCode.code())) { code.append(beanCode.code()).append('\n'); } @@ -189,6 +190,7 @@ private String genObjectCompatibleRead() { private String genRecordCompatibleRead() { RecordComponent[] components = RecordUtils.getRecordComponents(beanClass); StringBuilder code = new StringBuilder(); + code.append(graphMemoryReserveCode()).append('\n'); for (int i = 0; i < components.length; i++) { Class componentType = components[i].getType(); code.append(recordLocalType(componentType)) diff --git a/java/fory-core/src/main/java/org/apache/fory/config/Config.java b/java/fory-core/src/main/java/org/apache/fory/config/Config.java index 2b8db3ec66..8f07945706 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/Config.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/Config.java @@ -68,6 +68,7 @@ public class Config implements Serializable { private final int maxTypeMetaBytes; private final int maxSchemaVersionsPerType; private final int maxAverageSchemaVersionsPerType; + private final long maxGraphMemoryBytes; private final float mapRefLoadFactor; private final boolean forVirtualThread; @@ -114,6 +115,7 @@ public Config(ForyBuilder builder) { maxTypeMetaBytes = builder.maxTypeMetaBytes; maxSchemaVersionsPerType = builder.maxSchemaVersionsPerType; maxAverageSchemaVersionsPerType = builder.maxAverageSchemaVersionsPerType; + maxGraphMemoryBytes = builder.maxGraphMemoryBytes; mapRefLoadFactor = builder.mapRefLoadFactor; forVirtualThread = builder.forVirtualThread; } @@ -320,6 +322,11 @@ public int maxAverageSchemaVersionsPerType() { return maxAverageSchemaVersionsPerType; } + /** Returns the root-operation estimated graph memory limit in bytes. 
*/ + public long maxGraphMemoryBytes() { + return maxGraphMemoryBytes; + } + /** Returns loadFactor of MacRef's writtenObjects. */ public float mapRefLoadFactor() { return mapRefLoadFactor; @@ -368,6 +375,7 @@ public boolean equals(Object o) { && maxTypeMetaBytes == config.maxTypeMetaBytes && maxSchemaVersionsPerType == config.maxSchemaVersionsPerType && maxAverageSchemaVersionsPerType == config.maxAverageSchemaVersionsPerType + && maxGraphMemoryBytes == config.maxGraphMemoryBytes && Objects.equals(defaultJDKStreamSerializerType, config.defaultJDKStreamSerializerType) && longEncoding == config.longEncoding && forVirtualThread == config.forVirtualThread; @@ -403,6 +411,7 @@ public int hashCode() { maxTypeMetaBytes, maxSchemaVersionsPerType, maxAverageSchemaVersionsPerType, + maxGraphMemoryBytes, metaShareEnabled, scopedMetaShareEnabled, metaCompressor, diff --git a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java index 48d9dcb433..30b18257c7 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java @@ -103,6 +103,7 @@ public final class ForyBuilder { int maxTypeMetaBytes = 4096; int maxSchemaVersionsPerType = 10; int maxAverageSchemaVersionsPerType = 3; + long maxGraphMemoryBytes = 128L * 1024 * 1024; float mapRefLoadFactor = 0.51f; boolean forVirtualThread = false; TypeChecker typeChecker; @@ -571,6 +572,18 @@ public ForyBuilder withMaxAverageSchemaVersionsPerType(int maxAverageSchemaVersi return this; } + /** + * Sets the maximum estimated graph memory accepted during one root deserialization. + * + *
<p>
The default is a fixed 128 MiB. Values must be positive byte limits. + */ + public ForyBuilder withMaxGraphMemoryBytes(long maxGraphMemoryBytes) { + Preconditions.checkArgument(maxGraphMemoryBytes > 0, "maxGraphMemoryBytes must be positive"); + this.maxGraphMemoryBytes = maxGraphMemoryBytes; + recordAction(b -> b.withMaxGraphMemoryBytes(maxGraphMemoryBytes)); + return this; + } + /** Set loadFactor of MapRefResolver writtenObjects. Default value is 0.51 */ public ForyBuilder withMapRefLoadFactor(float loadFactor) { Preconditions.checkArgument( diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index b1000161b3..1df6fead40 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -71,6 +71,7 @@ public final class ReadContext { private MetaReadContext metaReadContext; private boolean peerOutOfBandEnabled; private int depth; + private long remainingGraphMemoryBytes; /** * Creates read-side runtime state for one {@code Fory} instance. @@ -112,6 +113,7 @@ public void prepare( this.buffer = buffer; this.peerOutOfBandEnabled = peerOutOfBandEnabled; this.outOfBandBuffers = outOfBandBuffers == null ? null : outOfBandBuffers.iterator(); + remainingGraphMemoryBytes = config.maxGraphMemoryBytes(); } /** @@ -307,6 +309,7 @@ public void reset() { outOfBandBuffers = null; peerOutOfBandEnabled = false; depth = 0; + remainingGraphMemoryBytes = 0; } /** Returns the immutable runtime configuration for this context. */ @@ -314,6 +317,39 @@ public Config getConfig() { return config; } + // Failure may leave the remaining counter dirty; root cleanup resets read state for the + // operation. 
+ public final void reserveGraphMemory(long bytes) { + long remaining = remainingGraphMemoryBytes - bytes; + remainingGraphMemoryBytes = remaining; + if ((bytes | remaining) < 0) { + throwInvalidGraphMemory(bytes, remaining + bytes); + } + } + + public final void reserveGraphMemory(int bytes) { + long remaining = remainingGraphMemoryBytes - bytes; + remainingGraphMemoryBytes = remaining; + if ((bytes | remaining) < 0) { + throwInvalidGraphMemory(bytes, remaining + bytes); + } + } + + private void throwInvalidGraphMemory(long bytes, long remaining) { + if (bytes < 0) { + throw new InsecureException( + "Estimated graph memory must be non-negative, but got " + bytes + " bytes."); + } + throw new InsecureException( + "Estimated graph memory request " + + bytes + + " bytes exceeds maxGraphMemoryBytes remaining budget " + + remaining + + " bytes out of effective limit " + + config.maxGraphMemoryBytes() + + " bytes. If the data is trusted, increase ForyBuilder#withMaxGraphMemoryBytes."); + } + /** Returns the generics stack shared by the owning runtime. */ public Generics getGenerics() { return generics; @@ -360,8 +396,10 @@ public boolean hasPreservedRefId() { } /** Binds the most recently preserved read ref id to {@code object}. */ - public void reference(Object object) { - refReader.reference(object); + public final void reference(Object object) { + if (trackingRef) { + refReader.reference(object); + } } /** Returns a previously read object by ref id. */ @@ -610,6 +648,15 @@ public Object readNonRef(TypeInfoHolder classInfoHolder) { /** Variant of {@link #readNonRef()} that uses already resolved {@link TypeInfo}. */ public Object readNonRef(TypeInfo typeInfo) { + int typeId = typeInfo.getTypeId(); + // User-defined xlang type IDs are contiguous; skip the primitive/string switch on this hot + // path. 
+ if (typeId >= Types.ENUM && typeId <= Types.NAMED_UNION) { + increaseDepth(); + Object read = typeInfo.getSerializer().read(this); + decreaseDepth(); + return read; + } return readDataInternal(typeInfo); } diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index 275ccf1d05..8837fb8e50 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -773,7 +773,7 @@ public final TypeInfo readTypeInfo(ReadContext readContext, TypeInfo typeInfoCac break; case Types.COMPATIBLE_STRUCT: case Types.NAMED_COMPATIBLE_STRUCT: - typeInfo = readSharedClassTypeInfo(readContext, null); + typeInfo = readSharedClassTypeInfo(readContext, null, typeInfoCache); break; case Types.NAMED_ENUM: case Types.NAMED_STRUCT: @@ -782,7 +782,7 @@ public final TypeInfo readTypeInfo(ReadContext readContext, TypeInfo typeInfoCac if (!metaContextShareEnabled) { typeInfo = readTypeInfoFromBytes(readContext, typeInfoCache, typeId); } else { - typeInfo = readSharedClassTypeInfo(readContext, null); + typeInfo = readSharedClassTypeInfo(readContext, null, typeInfoCache); } break; case Types.LIST: diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java index 5312540d41..f147288f4e 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java @@ -69,6 +69,8 @@ public abstract class AbstractObjectSerializer extends Serializer { private static final Logger LOG = LoggerFactory.getLogger(AbstractObjectSerializer.class); private static final Object SELF_REFERENCE = new Object(); + private static final int OBJECT_SELF_BYTES = 1; + private static final 
int REFERENCE_BYTES = 4; // Constructor-bound objects reserve a ref id before constructor arguments are read, but the // object cannot be referenced semantically until the constructor returns. Generated constructor // serializers call the tracker before reading ref-tracking constructor-phase fields so nested @@ -79,6 +81,7 @@ public abstract class AbstractObjectSerializer extends Serializer { protected final TypeResolver typeResolver; protected final boolean isRecord; protected final ObjectInstantiator objectInstantiator; + protected final int objectGraphMemoryBytes; private SerializationFieldInfo[] fieldInfos; private RecordInfo copyRecordInfo; @@ -88,6 +91,7 @@ protected AbstractObjectSerializer() { this.typeResolver = null; this.isRecord = false; this.objectInstantiator = null; + this.objectGraphMemoryBytes = 0; } public AbstractObjectSerializer(TypeResolver typeResolver, Class type) { @@ -101,6 +105,37 @@ public AbstractObjectSerializer( this.typeResolver = typeResolver; this.isRecord = RecordUtils.isRecord(type); this.objectInstantiator = objectInstantiator; + this.objectGraphMemoryBytes = computeObjectGraphMemoryBytes(type); + } + + static int computeObjectGraphMemoryBytes(Class type) { + // One byte is a stable nonzero self cost, not an attempt to model JVM object headers. 
+ int bytes = OBJECT_SELF_BYTES; + for (Field field : ReflectionUtils.getFields(type, true)) { + if (!Modifier.isStatic(field.getModifiers())) { + bytes = Math.addExact(bytes, fieldGraphMemoryBytes(field.getType())); + } + } + return bytes; + } + + private static int fieldGraphMemoryBytes(Class fieldType) { + if (!fieldType.isPrimitive()) { + return REFERENCE_BYTES; + } + if (fieldType == boolean.class || fieldType == byte.class) { + return 1; + } + if (fieldType == char.class || fieldType == short.class) { + return 2; + } + if (fieldType == int.class || fieldType == float.class) { + return 4; + } + if (fieldType == long.class || fieldType == double.class) { + return 8; + } + return 0; } static void writeField( diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java index 9fe08fdfb5..4be9ce969f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java @@ -45,12 +45,28 @@ * object-array paths avoid adapter allocation. */ public final class ArraySerializers { + private static final int REFERENCE_BYTES = 4; + private static final int OBJECT_ARRAY_BYTES = 1; + private ArraySerializers() {} private static void throwInvalidObjectArraySize(int size) { throw new DeserializationException("Object array size must be non-negative: " + size); } + private static int readObjectArraySize(ReadContext readContext) { + MemoryBuffer buffer = readContext.getBuffer(); + int numElements = buffer.readVarUInt32Small7(); + // Keep this as direct primitive branches. Object-array reads allocate immediately; using + // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. 
+ if (numElements < 0) { + throwInvalidObjectArraySize(numElements); + } + readContext.reserveGraphMemory(OBJECT_ARRAY_BYTES + (long) numElements * REFERENCE_BYTES); + buffer.checkReadableBytes(numElements); + return numElements; + } + /** * Returns the object-array serializer for {@code cls}. * @@ -128,14 +144,7 @@ public Object[] copy(CopyContext copyContext, Object[] originArray) { @Override public Object[] read(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); - // Keep this as direct primitive branches. Object-array reads allocate immediately; using - // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. - if (numElements < 0) { - throwInvalidObjectArraySize(numElements); - } - buffer.checkReadableBytes(numElements); + int numElements = readObjectArraySize(readContext); Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { @@ -213,14 +222,7 @@ public Object[] copy(CopyContext copyContext, Object[] originArray) { @Override public Object[] read(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); - // Keep this as direct primitive branches. Object-array reads allocate immediately; using - // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. - if (numElements < 0) { - throwInvalidObjectArraySize(numElements); - } - buffer.checkReadableBytes(numElements); + int numElements = readObjectArraySize(readContext); Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { @@ -654,14 +656,7 @@ public void write(WriteContext writeContext, Object[] value) { @Override public Object[] read(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); - // Keep this as direct primitive branches. 
Object-array reads allocate immediately; using - // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. - if (numElements < 0) { - throwInvalidObjectArraySize(numElements); - } - buffer.checkReadableBytes(numElements); + int numElements = readObjectArraySize(readContext); Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java index 35eeca550a..4fe26cb010 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java @@ -59,6 +59,11 @@ import org.apache.fory.type.Types; final class CompatibleCollectionArrayReader { + // This compatible reader may be reached during native-image analysis. Use the settled + // reference-slot fallback instead of touching MemoryBuffer from class initialization. 
+ private static final int COLLECTION_BYTES = 1; + private static final int REFERENCE_BYTES = 4; + static final int READ_LIST_TO_ARRAY = 1; static final int READ_ARRAY_TO_LIST = 2; static final int READ_LIST_TO_LIST = 3; @@ -343,18 +348,18 @@ private static Object readNotNull( if (array == null) { return null; } - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } if (readMode == READ_LIST_TO_LIST) { return readListBodyAsListTarget(readContext, arrayTypeId, elementTypeId, targetType); } if (readMode == READ_ARRAY_TO_LIST) { Object array = readDenseArrayBody(readContext, arrayTypeId); - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } if (readMode == READ_ARRAY_TO_ARRAY) { Object array = readDenseArrayBody(readContext, arrayTypeId); - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } throw new IllegalStateException("Unexpected compatible read mode " + readMode); } @@ -621,7 +626,7 @@ private static Object readListBodyAsListTarget( validateElementCount(numElements); if (numElements == 0) { Object array = readListPrimitiveElements(buffer, 0, arrayTypeId, elementTypeId, false); - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } int flags = buffer.readByte(); boolean hasNull = (flags & CollectionFlags.HAS_NULL) == CollectionFlags.HAS_NULL; @@ -654,11 +659,11 @@ private static Object readListBodyAsListTarget( throw new DeserializationException( "Cannot read null peer list element into local list field"); } - return readNullableListBoxedElements(buffer, numElements, arrayTypeId, elementTypeId); + return readNullableListBoxedElements(readContext, numElements, arrayTypeId, elementTypeId); } Object array = readListPrimitiveElements(buffer, 
numElements, arrayTypeId, elementTypeId, false); - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } private static Object readDenseArrayBody(ReadContext readContext, int arrayTypeId) { @@ -976,8 +981,11 @@ private static void readNonNullListElement(MemoryBuffer buffer) { } private static List readNullableListBoxedElements( - MemoryBuffer buffer, int numElements, int arrayTypeId, int elementTypeId) { - buffer.checkReadableBytes(minReadablePrimitiveListBytes(numElements, elementTypeId, true)); + ReadContext readContext, int numElements, int arrayTypeId, int elementTypeId) { + MemoryBuffer buffer = readContext.getBuffer(); + int bodyBytes = minReadablePrimitiveListBytes(numElements, elementTypeId, true); + readContext.reserveGraphMemory(COLLECTION_BYTES + (long) numElements * REFERENCE_BYTES); + buffer.checkReadableBytes(bodyBytes); ArrayList values = new ArrayList<>(numElements); for (int i = 0; i < numElements; i++) { byte headFlag = buffer.readByte(); @@ -1043,7 +1051,8 @@ private static Object readBoxedListElement( } } - private static Object materializeTarget(Object array, int arrayTypeId, Class targetType) { + private static Object materializeTarget( + ReadContext readContext, Object array, int arrayTypeId, Class targetType) { if (targetType.isArray()) { return array; } @@ -1058,7 +1067,7 @@ private static Object materializeTarget(Object array, int arrayTypeId, Class return primitiveList; } if (targetType.isAssignableFrom(ArrayList.class)) { - return materializeBoxedList(array, arrayTypeId); + return materializeBoxedList(readContext, array, arrayTypeId); } throw new DeserializationException("Unsupported compatible list/array target " + targetType); } @@ -1172,8 +1181,10 @@ private static boolean canMaterializePrimitiveListTarget(Class targetType, in } } - private static List materializeBoxedList(Object array, int arrayTypeId) { + private static List materializeBoxedList( + 
ReadContext readContext, Object array, int arrayTypeId) { int size = java.lang.reflect.Array.getLength(array); + readContext.reserveGraphMemory(COLLECTION_BYTES + (long) size * REFERENCE_BYTES); ArrayList list = new ArrayList<>(size); switch (arrayTypeId) { case Types.BOOL_ARRAY: diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleLayerSerializerBase.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleLayerSerializerBase.java index e7040f16f4..67fab5f10a 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleLayerSerializerBase.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleLayerSerializerBase.java @@ -135,6 +135,7 @@ public Object[] readFieldValues(ReadContext readContext) { @Override public T read(ReadContext readContext) { checkLayerSerializerMeta(); + readContext.reserveGraphMemory(objectGraphMemoryBytes); T obj = newBean(); readContext.reference(obj); return readAndSetFields(readContext, obj); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleSerializer.java index b22843e05e..1af0432dcc 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleSerializer.java @@ -237,6 +237,7 @@ private T newInstance() { @Override public T read(ReadContext readContext) { + readContext.reserveGraphMemory(objectGraphMemoryBytes); if (isRecord) { Object[] fieldValues = new Object[allFields.length]; if (hasCompatibleCollectionArrayRead) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java index b3b54e18be..345b40ee9f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java +++ 
b/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java @@ -62,6 +62,7 @@ public static final class ExceptionSerializer extends Seria private final TypeResolver typeResolver; private final ObjectInstantiator objectInstantiator; private final Constructor messageConstructor; + private final int graphMemoryBytes; private volatile Serializer[] slotsSerializers; private volatile boolean rebuildSlotsSerializersAtRuntime; @@ -74,6 +75,7 @@ public ExceptionSerializer(TypeResolver typeResolver, Class type) { messageConstructor == null && MemoryUtils.JDK_LANG_FIELD_ACCESS ? createThrowableObjectInstantiator(typeResolver, type) : null; + graphMemoryBytes = AbstractObjectSerializer.computeObjectGraphMemoryBytes(type); slotsSerializers = buildSlotsSerializers(typeResolver, type); if (!MemoryUtils.JDK_LANG_FIELD_ACCESS && isJdkThrowable(type) @@ -117,6 +119,7 @@ public T read(ReadContext readContext) { return readAndroidThrowableWithoutDetailMessageField( readContext, stackTrace, slotsSerializers); } + readContext.reserveGraphMemory(graphMemoryBytes); T obj = newThrowableForRead(); readContext.reference(obj); Throwable cause = (Throwable) readContext.readRef(); @@ -157,6 +160,7 @@ private T readAndroidThrowableWithoutDetailMessageField( + " requires JDK internal field access. 
" + jdkFieldAccessMessage()); } + readContext.reserveGraphMemory(graphMemoryBytes); T obj = newThrowableWithMessage(detailMessage); readContext.reference(obj); if (stackTrace != null) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java index e656359d14..e3f8300398 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java @@ -65,6 +65,8 @@ public final class ObjectSerializer extends AbstractObjectSerializer { private final RecordInfo recordInfo; private final SerializationFieldInfo[] allFields; private final int classVersionHash; + private final boolean trackingRef; + private final boolean checkClassVersion; public ObjectSerializer(TypeResolver typeResolver, Class cls) { this(typeResolver, cls, true); @@ -80,6 +82,8 @@ public ObjectSerializer( boolean resolveParent, ObjectInstantiator objectInstantiator) { super(typeResolver, cls, objectInstantiator); + trackingRef = config.trackingRef(); + checkClassVersion = typeResolver.checkClassVersion(); // avoid recursive building serializers. // Use `setSerializerIfAbsent` to avoid overwriting existing serializer for class when used // as data serializer. 
@@ -208,6 +212,7 @@ private void writeFieldByCodecCategory( @Override public T read(ReadContext readContext) { + readContext.reserveGraphMemory((long) objectGraphMemoryBytes); MemoryBuffer buffer = readContext.getBuffer(); if (isRecord) { Object[] fields = readFields(readContext); @@ -217,14 +222,16 @@ public T read(ReadContext readContext) { return obj; } T obj = newBean(); - readContext.reference(obj); + if (trackingRef) { + readContext.reference(obj); + } return readAndSetFields(readContext, obj); } public Object[] readFields(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); RefReader refReader = readContext.getRefReader(); - if (typeResolver.checkClassVersion()) { + if (checkClassVersion) { int hash = buffer.readInt32(); checkClassVersion(type, hash, classVersionHash); } @@ -241,7 +248,7 @@ public Object[] readFields(ReadContext readContext) { public T readAndSetFields(ReadContext readContext, T obj) { MemoryBuffer buffer = readContext.getBuffer(); RefReader refReader = readContext.getRefReader(); - if (typeResolver.checkClassVersion()) { + if (checkClassVersion) { int hash = buffer.readInt32(); checkClassVersion(type, hash, classVersionHash); } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java index 149e872862..8f310efc7f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java @@ -269,6 +269,7 @@ public void write(WriteContext writeContext, Object value) { @Override public Object read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); + readContext.reserveGraphMemory(objectGraphMemoryBytes); Object obj = objectInstantiator.newInstance(); readContext.reference(obj); int numClasses = buffer.readInt16(); diff --git 
a/java/fory-core/src/main/java/org/apache/fory/serializer/UnknownClassSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/UnknownClassSerializers.java index 1356d80a35..efd02dec7d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/UnknownClassSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/UnknownClassSerializers.java @@ -58,6 +58,11 @@ private ClassFieldsInfo(FieldGroups fieldGroups, int classVersionHash) { } public static final class UnknownStructSerializer extends Serializer { + private static final int UNKNOWN_STRUCT_SELF_BYTES = 1; + private static final int UNKNOWN_STRUCT_REFERENCE_BYTES = 4; + private static final int UNKNOWN_STRUCT_ENTRY_BYTES = + UNKNOWN_STRUCT_SELF_BYTES + 2 * UNKNOWN_STRUCT_REFERENCE_BYTES; + private static final int NONEXISTENT_META_SHARED_ID_SIZE = computeVarUInt32Size(ClassResolver.NONEXISTENT_META_SHARED_ID); private final Config config; @@ -246,11 +251,15 @@ private ClassFieldsInfo getClassFieldsInfo(TypeDef typeDef) { @Override public Object read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); + ClassFieldsInfo allFieldsInfo = getClassFieldsInfo(typeDef); + readContext.reserveGraphMemory( + UNKNOWN_STRUCT_SELF_BYTES + + 2L * UNKNOWN_STRUCT_REFERENCE_BYTES + + (long) allFieldsInfo.allFields.length * UNKNOWN_STRUCT_ENTRY_BYTES); UnknownClass.UnknownStruct obj = new UnknownClass.UnknownStruct(typeDef); readContext.reference(obj); List entries = new ArrayList<>(); // Protocol order: primitive, nullable primitive, then all non-primitives by field identifier. 
- ClassFieldsInfo allFieldsInfo = getClassFieldsInfo(typeDef); Generics generics = readContext.getGenerics(); for (SerializationFieldInfo fieldInfo : allFieldsInfo.allFields) { Object fieldValue = readFieldByCodecCategory(readContext, generics, fieldInfo, buffer); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java index f7840349ef..49c65b56cf 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java @@ -249,7 +249,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public T newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); @@ -295,7 +295,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public T newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); @@ -403,7 +403,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) 
readContext.readRef(); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java index 3915b5d888..2796a9b2bf 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java @@ -46,6 +46,9 @@ */ @SuppressWarnings({"unchecked", "rawtypes"}) public abstract class CollectionLikeSerializer extends Serializer { + private static final int COLLECTION_BYTES = 1; + private static final int REFERENCE_BYTES = 4; + private MethodHandle constructor; private int numElements; protected final Config config; @@ -461,7 +464,7 @@ public T read(ReadContext readContext) { */ public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - numElements = readCollectionSize(buffer); + numElements = readCollectionSize(readContext, buffer); if (AndroidSupport.IS_ANDROID) { try { Constructor constructor = type.getDeclaredConstructor(); @@ -560,9 +563,10 @@ protected void setNumElements(int numElements) { this.numElements = numElements; } - protected final int readCollectionSize(MemoryBuffer buffer) { + protected final int readCollectionSize(ReadContext readContext, MemoryBuffer buffer) { int numElements = buffer.readVarUInt32Small7(); checkCollectionSize(numElements); + readContext.reserveGraphMemory(COLLECTION_BYTES + (long) numElements * REFERENCE_BYTES); buffer.checkReadableBytes(numElements); return numElements; } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java index a81c38298f..69118249a0 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java +++ 
b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java @@ -74,6 +74,8 @@ */ @SuppressWarnings({"unchecked", "rawtypes"}) public class CollectionSerializers { + private static final int REFERENCE_BYTES = 4; + private static final Comparator NATURAL_ORDER_COMPARATOR = Comparator.naturalOrder(); private static void requireXlangNaturalOrdering(Class type, Comparator comparator) { @@ -127,7 +129,7 @@ public ArrayListSerializer(TypeResolver typeResolver) { @Override public ArrayList newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); ArrayList arrayList = new ArrayList(numElements); readContext.reference(arrayList); @@ -189,7 +191,7 @@ public List read(ReadContext readContext) { @Override public ArrayList newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); ArrayList arrayList = new ArrayList(numElements); readContext.reference(arrayList); @@ -205,7 +207,7 @@ public HashSetSerializer(TypeResolver typeResolver) { @Override public HashSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); HashSet hashSet = new HashSet(numElements); readContext.reference(hashSet); @@ -221,7 +223,7 @@ public LinkedHashSetSerializer(TypeResolver typeResolver) { @Override public LinkedHashSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); LinkedHashSet hashSet = new 
LinkedHashSet(numElements); readContext.reference(hashSet); @@ -270,7 +272,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public T newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); T collection; Comparator comparator = config.isXlang() ? null : (Comparator) readContext.readRef(); @@ -335,7 +337,8 @@ public void write(WriteContext writeContext, List value) { @Override public List read(ReadContext readContext) { if (config.isXlang()) { - int numElements = readCollectionSize(readContext.getBuffer()); + MemoryBuffer buffer = readContext.getBuffer(); + int numElements = readCollectionSize(readContext, buffer); if (numElements != 0) { throw new DeserializationException( "Empty list body must have zero elements but got " + numElements); @@ -356,7 +359,7 @@ public CopyOnWriteArrayListSerializer( @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -390,7 +393,7 @@ public CopyOnWriteArraySetSerializer( @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -542,7 +545,7 @@ public CollectionSnapshot onCollectionWrite( @Override public ConcurrentSkipListSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); 
setNumElements(numElements); if (config.isXlang()) { ConcurrentSkipListSet skipListSet = new ConcurrentSkipListSet(); @@ -726,7 +729,7 @@ public VectorSerializer(TypeResolver typeResolver, Class cls) { @Override public Vector newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); Vector vector = new Vector<>(numElements); readContext.reference(vector); @@ -743,7 +746,7 @@ public ArrayDequeSerializer(TypeResolver typeResolver, Class cls) { @Override public ArrayDeque newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); ArrayDeque deque = new ArrayDeque(numElements); readContext.reference(deque); @@ -786,9 +789,9 @@ public void write(WriteContext writeContext, EnumSet object) { public EnumSet read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); Class elemClass = typeResolver.readTypeInfo(readContext).getType(); + int length = readCollectionSize(readContext, buffer); EnumSet object = EnumSet.noneOf(elemClass); Serializer elemSerializer = typeResolver.getSerializer(elemClass); - int length = readCollectionSize(buffer); for (int i = 0; i < length; i++) { object.add(elemSerializer.read(readContext)); } @@ -863,7 +866,7 @@ public Collection newCollection(CopyContext copyContext, Collection collection) public PriorityQueue newCollection(ReadContext readContext) { assert !config.isXlang(); MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); PriorityQueue queue = new PriorityQueue(comparator); @@ -923,10 +926,11 @@ public 
CollectionSnapshot onCollectionWrite( @Override public ArrayBlockingQueue newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); checkBoundedQueueCapacity(numElements, capacity); + readContext.reserveGraphMemory((long) (capacity - numElements) * REFERENCE_BYTES); buffer.checkReadableBytes(capacity); ArrayBlockingQueue queue = new ArrayBlockingQueue<>(capacity); readContext.reference(queue); @@ -990,10 +994,12 @@ public CollectionSnapshot onCollectionWrite( @Override public LinkedBlockingQueue newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); checkBoundedQueueCapacity(numElements, capacity); + // LinkedBlockingQueue capacity is a logical bound, not preallocated backing storage. The + // current node storage for numElements is already reserved by readCollectionSize.
LinkedBlockingQueue queue = new LinkedBlockingQueue<>(capacity); readContext.reference(queue); return queue; @@ -1130,7 +1136,7 @@ public XlangListDefaultSerializer(TypeResolver typeResolver, Class cls) { @Override public List newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); ArrayList list = new ArrayList(numElements); readContext.reference(list); @@ -1146,7 +1152,7 @@ public XlangSetDefaultSerializer(TypeResolver typeResolver, Class cls) { @Override public Set newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); HashSet set = new HashSet(numElements); readContext.reference(set); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java index c28aa04561..52ce79a937 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java @@ -94,7 +94,7 @@ public ImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -127,7 +127,7 @@ public RegularImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = 
readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); return new CollectionContainer(numElements); } @@ -161,7 +161,7 @@ public ImmutableSetSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -203,7 +203,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); return new SortedCollectionContainer(comparator, numElements); @@ -236,7 +236,7 @@ public GuavaMapSerializer(TypeResolver typeResolver, Class cls) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); return new MapContainer(numElements); } @@ -264,7 +264,8 @@ public T onMapRead(Map map) { @Override public T read(ReadContext readContext) { - int size = readMapSize(readContext.getBuffer()); + MemoryBuffer buffer = readContext.getBuffer(); + int size = readMapSize(readContext, buffer); Map map = new HashMap(); readElements(readContext, size, map); return xnewInstance(map); @@ -574,7 +575,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int 
numElements = readMapSize(readContext, buffer); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); return new SortedMapContainer<>(comparator, numElements); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java index cd69f2b6cf..65c8dcdacc 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java @@ -125,7 +125,7 @@ public ImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); if (JdkVersion.MAJOR_VERSION > 8) { return new CollectionContainer<>(numElements); @@ -186,7 +186,7 @@ public ImmutableSetSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); if (JdkVersion.MAJOR_VERSION > 8) { return new CollectionContainer<>(numElements); @@ -247,7 +247,7 @@ public ImmutableMapSerializer(TypeResolver typeResolver, Class cls) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); if (JdkVersion.MAJOR_VERSION > 8) { return new JDKImmutableMapContainer(numElements); diff --git 
a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java index 4f6f828d11..eab1479522 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java @@ -58,6 +58,8 @@ @SuppressWarnings({"unchecked", "rawtypes"}) public abstract class MapLikeSerializer extends Serializer { public static final int MAX_CHUNK_SIZE = 255; + private static final int MAP_BYTES = 1; + private static final int REFERENCE_BYTES = 4; static final class MapTypeCache { final TypeInfoHolder keyTypeInfoWriteCache; @@ -895,7 +897,7 @@ public void onMapWriteFinish(Map map) {} */ public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - numElements = readMapSize(buffer); + numElements = readMapSize(readContext, buffer); if (AndroidSupport.IS_ANDROID) { try { Constructor constructor = type.getDeclaredConstructor(); @@ -964,12 +966,13 @@ public void setNumElements(int numElements) { this.numElements = numElements; } - protected final int readMapSize(MemoryBuffer buffer) { + protected final int readMapSize(ReadContext readContext, MemoryBuffer buffer) { int numElements = buffer.readVarUInt32Small7(); checkMapSize(numElements); if (numElements > Integer.MAX_VALUE / 2) { throwInvalidMapBodySize(numElements); } + readContext.reserveGraphMemory(MAP_BYTES + (long) numElements * 2 * REFERENCE_BYTES); buffer.checkReadableBytes(numElements << 1); return numElements; } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java index 9b3825495a..64234a6a47 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java +++ 
b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java @@ -86,7 +86,7 @@ public HashMapSerializer(TypeResolver typeResolver) { @Override public HashMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); HashMap hashMap = new HashMap(numElements); readContext.reference(hashMap); @@ -107,7 +107,7 @@ public LinkedHashMapSerializer(TypeResolver typeResolver) { @Override public LinkedHashMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); LinkedHashMap hashMap = new LinkedHashMap(numElements); readContext.reference(hashMap); @@ -146,7 +146,7 @@ public LazyMapSerializer(TypeResolver typeResolver) { @Override public LazyMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); LazyMap map = new LazyMap(numElements); readContext.reference(map); @@ -200,7 +200,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - setNumElements(readMapSize(buffer)); + setNumElements(readMapSize(readContext, buffer)); T map; Comparator comparator = config.isXlang() ? 
null : (Comparator) readContext.readRef(); if (type == TreeMap.class) { @@ -322,7 +322,7 @@ public ConcurrentHashMapSerializer(TypeResolver typeResolver, Class keyType = typeResolver.readTypeInfo(readContext).getType(); EnumMap map = new EnumMap(keyType); readContext.reference(map); @@ -619,7 +619,7 @@ public Object onMapCopy(Map map) { public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); HashMap map = new HashMap<>(numElements); readContext.reference(map); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java index f11f4b79a3..3775f7f3b6 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java @@ -158,7 +158,7 @@ public List read(ReadContext readContext) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); ArrayList list = new ArrayList(numElements); readContext.reference(list); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java index b0fb46f0f4..df69696d64 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java @@ -387,8 +387,8 @@ private static Object readTruncatedPrimitiveArrayBody( ReadContext readContext = fory.getReadContext(); MemoryBuffer buffer = 
MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(byteSize); - readContext.prepare( - MemoryBuffer.fromByteArray(buffer.getBytes(0, buffer.writerIndex())), null, false); + MemoryBuffer truncated = MemoryBuffer.fromByteArray(buffer.getBytes(0, buffer.writerIndex())); + readContext.prepare(truncated, null, false); return fory.getSerializer(arrayType).read(readContext); } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java index cd3d31dcac..ad25757355 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java @@ -33,6 +33,7 @@ import org.apache.fory.ForyTestBase; import org.apache.fory.TestUtils; import org.apache.fory.config.Language; +import org.apache.fory.context.ReadContext; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.MemoryUtils; import org.apache.fory.serializer.collection.UnmodifiableSerializersTest; @@ -138,14 +139,21 @@ public void testWriteCompatibleBasic() throws Exception { public void testNullableListBodyBounds() throws Exception { Method method = CompatibleCollectionArrayReader.class.getDeclaredMethod( - "readNullableListBoxedElements", MemoryBuffer.class, int.class, int.class, int.class); + "readNullableListBoxedElements", ReadContext.class, int.class, int.class, int.class); method.setAccessible(true); MemoryBuffer buffer = MemoryUtils.buffer(0); - InvocationTargetException exception = - Assert.expectThrows( - InvocationTargetException.class, - () -> method.invoke(null, buffer, 1024, Types.INT32_ARRAY, Types.INT32)); - Assert.assertTrue(exception.getCause() instanceof IndexOutOfBoundsException); + Fory fory = builder().build(); + ReadContext readContext = fory.getReadContext(); + readContext.prepare(buffer, null, false); + try { + 
InvocationTargetException exception = + Assert.expectThrows( + InvocationTargetException.class, + () -> method.invoke(null, readContext, 1024, Types.INT32_ARRAY, Types.INT32)); + Assert.assertTrue(exception.getCause() instanceof IndexOutOfBoundsException); + } finally { + readContext.reset(); + } } @Test diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java new file mode 100644 index 0000000000..b4aa372fee --- /dev/null +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.fory.serializer; + +import static org.apache.fory.io.ForyStreamReader.of; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; + +import java.io.ByteArrayInputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import org.apache.fory.Fory; +import org.apache.fory.ForyTestBase; +import org.apache.fory.collection.Int32List; +import org.apache.fory.context.ReadContext; +import org.apache.fory.exception.DeserializationException; +import org.apache.fory.exception.InsecureException; +import org.apache.fory.memory.MemoryBuffer; +import org.testng.annotations.Test; + +public class GraphMemoryBudgetTest extends ForyTestBase { + private static final long DEFAULT_GRAPH_MEMORY_BYTES = 128L * 1024 * 1024; + private static final int REFERENCE_BYTES = 4; + private static final int OBJECT_SELF_BYTES = 1; + + @Test + public void testConfigDefaultsAndValidation() { + assertEquals(builder().build().getConfig().maxGraphMemoryBytes(), DEFAULT_GRAPH_MEMORY_BYTES); + assertEquals(newFory(123).getConfig().maxGraphMemoryBytes(), 123); + assertThrows(IllegalArgumentException.class, () -> newFory(0)); + assertThrows(IllegalArgumentException.class, () -> newFory(-2)); + } + + @Test + public void testDefaultFixedBudget() { + ReadContext readContext = prepareContext(builder().build()); + try { + readContext.reserveGraphMemory(DEFAULT_GRAPH_MEMORY_BYTES); + assertThrows(InsecureException.class, () -> readContext.reserveGraphMemory(1)); + } finally { + readContext.reset(); + } + } + + @Test + public void testExplicitBudgetWins() { + Fory fory = newFory(7); + ReadContext readContext = prepareContext(fory); + try { + readContext.reserveGraphMemory(7); + assertThrows(InsecureException.class, () -> readContext.reserveGraphMemory(1)); + } finally { + readContext.reset(); + } + } + + @Test + public void 
testNestedEmptyContainers() { + List value = emptyLists(1); + byte[] bytes = builder().build().serialize(value); + long required = collectionBytes(1) + collectionBytes(0); + + assertThrows(InsecureException.class, () -> newFory(required - 1).deserialize(bytes)); + assertThrows( + InsecureException.class, + () -> newFory(required - 1).deserialize(of(new ByteArrayInputStream(bytes)))); + assertEquals(newFory(required).deserialize(bytes), value); + assertEquals(newFory(required).deserialize(of(new ByteArrayInputStream(bytes))), value); + } + + @Test + public void testSiblingBudgetIsCumulative() { + List value = nullLists(2, 64); + byte[] bytes = builder().build().serialize(value); + long firstChildOnly = collectionBytes(2) + collectionBytes(64); + + assertThrows(InsecureException.class, () -> newFory(firstChildOnly).deserialize(bytes)); + assertEquals(newFory(collectionBytes(2) + 2L * collectionBytes(64)).deserialize(bytes), value); + } + + @Test + public void testMapBudgetAndOverflow() { + Fory fory = newFory(mapBytes(1) - 1); + ReadContext readContext = prepareContext(fory); + try { + assertThrows(InsecureException.class, () -> readContext.reserveGraphMemory(mapBytes(1))); + } finally { + readContext.reset(); + } + + Fory exactFory = newFory(mapBytes(1)); + ReadContext exactContext = prepareContext(exactFory); + try { + exactContext.reserveGraphMemory(mapBytes(1)); + assertThrows(InsecureException.class, () -> exactContext.reserveGraphMemory(1)); + } finally { + exactContext.reset(); + } + + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(8); + buffer.writeVarUInt32Small7(Integer.MAX_VALUE); + buffer = trimBuffer(buffer); + Fory reader = newFory(DEFAULT_GRAPH_MEMORY_BYTES); + ReadContext mapContext = reader.getReadContext(); + mapContext.prepare(buffer, null, false); + try { + assertThrows( + DeserializationException.class, + () -> reader.getSerializer(HashMap.class).read(mapContext)); + } finally { + mapContext.reset(); + } + } + + @Test + public void 
testObjectArrayBudget() { + Fory exactFory = newFory(1); + ReadContext exactContext = exactFory.getReadContext(); + MemoryBuffer exactBuffer = objectArraySizeBuffer(0); + exactContext.prepare(exactBuffer, null, false); + try { + Object[] array = (Object[]) exactFory.getSerializer(Object[].class).read(exactContext); + assertEquals(array.length, 0); + } finally { + exactContext.reset(); + } + + Fory slotFory = newFory(objectArrayBytes(2) - 1); + ReadContext slotContext = slotFory.getReadContext(); + MemoryBuffer slotBuffer = objectArraySizeBuffer(2); + slotContext.prepare(slotBuffer, null, false); + try { + assertThrows( + InsecureException.class, () -> slotFory.getSerializer(Object[].class).read(slotContext)); + } finally { + slotContext.reset(); + } + } + + @Test + public void testPojoGraphBudget() { + Pojo value = new Pojo(7, 9L, "child string is skipped as a leaf"); + byte[] bytes = builder().build().serialize(value); + long required = pojoBytes(); + + assertThrows(InsecureException.class, () -> newFory(required - 1, false).deserialize(bytes)); + assertEquals(newFory(required, false).deserialize(bytes), value); + + assertThrows(InsecureException.class, () -> newFory(required - 1, true).deserialize(bytes)); + assertEquals(newFory(required, true).deserialize(bytes), value); + } + + @Test + public void testNestedEmptyPojoGraphBudget() { + ArrayList value = new ArrayList<>(); + value.add(new EmptyPojo()); + value.add(new EmptyPojo()); + byte[] bytes = builder().build().serialize(value); + long required = collectionBytes(2) + 2L * emptyPojoBytes(); + + assertThrows(InsecureException.class, () -> newFory(required - 1).deserialize(bytes)); + List decoded = (List) newFory(required).deserialize(bytes); + assertEquals(decoded.size(), 2); + assertTrue(decoded.get(0) instanceof EmptyPojo); + assertTrue(decoded.get(1) instanceof EmptyPojo); + } + + @Test + public void testScalarOwnersSkipBudget() { + Fory fory = newFory(1); + 
assertEquals(fory.deserialize(fory.serialize("graph budget")), "graph budget"); + + byte[] bytes = new byte[] {1, 2, 3}; + assertTrue(Arrays.equals((byte[]) fory.deserialize(fory.serialize(bytes)), bytes)); + + int[] ints = new int[] {4, 5, 6}; + assertTrue(Arrays.equals((int[]) fory.deserialize(fory.serialize(ints)), ints)); + + Int32List denseList = new Int32List(new int[] {7, 8, 9}); + assertEquals(fory.deserialize(fory.serialize(denseList)), denseList); + } + + @Test + public void testTruncatedCollectionStillFails() { + Fory fory = newFory(collectionBytes(3)); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(8); + buffer.writeVarUInt32Small7(3); + buffer.writeByte(0); + buffer.writeByte(0); + buffer = trimBuffer(buffer); + ReadContext readContext = fory.getReadContext(); + readContext.prepare(buffer, null, false); + try { + assertThrows( + IndexOutOfBoundsException.class, + () -> fory.getSerializer(ArrayList.class).read(readContext)); + } finally { + readContext.reset(); + } + } + + private static Fory newFory(long maxGraphMemoryBytes) { + return newFory(maxGraphMemoryBytes, true); + } + + private static Fory newFory(long maxGraphMemoryBytes, boolean codegen) { + return builder().withMaxGraphMemoryBytes(maxGraphMemoryBytes).withCodegen(codegen).build(); + } + + private static ReadContext prepareContext(Fory fory) { + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(0); + ReadContext readContext = fory.getReadContext(); + readContext.prepare(buffer, null, false); + return readContext; + } + + private static long collectionBytes(int numElements) { + return OBJECT_SELF_BYTES + (long) numElements * REFERENCE_BYTES; + } + + private static long mapBytes(int numElements) { + return OBJECT_SELF_BYTES + (long) numElements * 2 * REFERENCE_BYTES; + } + + private static long objectArrayBytes(int numElements) { + return OBJECT_SELF_BYTES + (long) numElements * REFERENCE_BYTES; + } + + private static long emptyPojoBytes() { + return OBJECT_SELF_BYTES; + } + + private 
static long pojoBytes() { + return OBJECT_SELF_BYTES + 4 + 8 + REFERENCE_BYTES; + } + + private static List emptyLists(int numElements) { + List root = new ArrayList<>(numElements); + for (int i = 0; i < numElements; i++) { + root.add(new ArrayList<>()); + } + return root; + } + + private static List nullLists(int siblings, int childElements) { + List root = new ArrayList<>(siblings); + for (int i = 0; i < siblings; i++) { + List child = new ArrayList<>(childElements); + for (int j = 0; j < childElements; j++) { + child.add(null); + } + root.add(child); + } + return root; + } + + private static MemoryBuffer objectArraySizeBuffer(int numElements) { + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(8); + buffer.writeVarUInt32Small7(numElements); + return trimBuffer(buffer); + } + + private static MemoryBuffer trimBuffer(MemoryBuffer buffer) { + return MemoryBuffer.fromByteArray(buffer.getBytes(0, buffer.writerIndex())); + } + + public static final class EmptyPojo {} + + public static final class Pojo { + public int intValue; + public long longValue; + public String name; + + public Pojo() {} + + Pojo(int intValue, long longValue, String name) { + this.intValue = intValue; + this.longValue = longValue; + this.name = name; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Pojo)) { + return false; + } + Pojo other = (Pojo) obj; + return intValue == other.intValue + && longValue == other.longValue + && java.util.Objects.equals(name, other.name); + } + + @Override + public int hashCode() { + return java.util.Objects.hash(intValue, longValue, name); + } + } +} diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index 3b50a58337..1b34ebdd71 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -539,6 +539,8 @@ export class ReadContext { private _depth = 0; private _maxDepth: number; + private readonly maxGraphMemoryBytes: number; + private 
remainingGraphMemoryBytes = 0; private remoteSchemaVersionsByType: Map | undefined = undefined; constructor( @@ -549,6 +551,7 @@ export class ReadContext { this.refReader = new RefReader(this.reader); this.metaStringReader = new MetaStringReader(); this._maxDepth = config.maxDepth ?? 50; + this.maxGraphMemoryBytes = config.maxGraphMemoryBytes; } reset(bytes: Uint8Array) { @@ -557,6 +560,30 @@ export class ReadContext { this.metaStringReader.reset(); this.typeMeta = []; this._depth = 0; + this.remainingGraphMemoryBytes = this.maxGraphMemoryBytes; + } + + reserveGraphMemory(bytes: number) { + if (!Number.isSafeInteger(bytes) || bytes < 0) { + this.throwGraphMemoryOverflow(bytes); + } + const remaining = this.remainingGraphMemoryBytes - bytes; + if (remaining < 0) { + this.throwGraphBudgetExceeded(bytes); + } + this.remainingGraphMemoryBytes = remaining; + } + + private throwGraphMemoryOverflow(bytes: number): never { + throw new Error(`maxGraphMemoryBytes overflow: requested ${bytes} estimated graph bytes`); + } + + private throwGraphBudgetExceeded(bytes: number): never { + throw new Error( + `maxGraphMemoryBytes exceeded: requested ${bytes} estimated graph bytes, ` + + `${this.remainingGraphMemoryBytes} remaining, effective limit ` + + `${this.maxGraphMemoryBytes}`, + ); } isCompatible() { diff --git a/javascript/packages/core/lib/fory.ts b/javascript/packages/core/lib/fory.ts index f50d2fcebf..abef11da3a 100644 --- a/javascript/packages/core/lib/fory.ts +++ b/javascript/packages/core/lib/fory.ts @@ -38,6 +38,7 @@ const DEFAULT_MAX_TYPE_FIELDS = 512 as const; const DEFAULT_MAX_TYPE_META_BYTES = 4096 as const; const DEFAULT_MAX_SCHEMA_VERSIONS_PER_TYPE = 10 as const; const DEFAULT_MAX_AVERAGE_SCHEMA_VERSIONS_PER_TYPE = 3 as const; +const DEFAULT_MAX_GRAPH_MEMORY_BYTES = 128 * 1024 * 1024; export default class Fory { readonly typeResolver: TypeResolver; readonly anySerializer: Serializer; @@ -88,10 +89,17 @@ export default class Fory { `maxAverageSchemaVersionsPerType 
must be a positive integer but got ${maxAverageSchemaVersionsPerType}`, ); } + const maxGraphMemoryBytes = config?.maxGraphMemoryBytes ?? DEFAULT_MAX_GRAPH_MEMORY_BYTES; + if (!Number.isSafeInteger(maxGraphMemoryBytes) || maxGraphMemoryBytes <= 0) { + throw new Error( + `maxGraphMemoryBytes must be a positive safe integer but got ${maxGraphMemoryBytes}`, + ); + } return { ref: Boolean(config?.ref), useSliceString: Boolean(config?.useSliceString), maxDepth: config?.maxDepth, + maxGraphMemoryBytes, maxTypeFields, maxTypeMetaBytes, maxSchemaVersionsPerType, diff --git a/javascript/packages/core/lib/gen/collection.ts b/javascript/packages/core/lib/gen/collection.ts index 03139bebb1..04c3f293d8 100644 --- a/javascript/packages/core/lib/gen/collection.ts +++ b/javascript/packages/core/lib/gen/collection.ts @@ -26,6 +26,9 @@ import { Scope } from "./scope"; import { AnyHelper } from "./any"; import type { ReadContext, WriteContext } from "../context"; +const REFERENCE_BYTES = 4; +const COLLECTION_BYTES = 1; + export type CompatibleCollectionArrayReadAction = { target: "array" | "list"; elementTypeId: number; @@ -234,10 +237,12 @@ class CollectionAnySerializer { ): any { void fromRef; const len = this.readContext.reader.readVarUint32Small7(); + this.readContext.reserveGraphMemory(COLLECTION_BYTES + len * REFERENCE_BYTES); if (len === 0) { return createCollection(len); } const flags = this.readContext.reader.readUint8(); + this.readContext.reader.checkReadableBytes(len); const result = createCollection(len); // IMPORTANT: collection readers must obey the ref/null bits written on the // wire, not local TypeScript metadata that may imply a different ref @@ -418,6 +423,9 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera const newCollection = compatibleListToArray ? compatibleArrayCollectionExpr(compatibleReadAction!.elementTypeId, len) : this.newCollection(len); + const reserveMemory = compatibleListToArray + ? 
"" + : `${readContextName}.reserveGraphMemory(${COLLECTION_BYTES} + ${len} * ${REFERENCE_BYTES});`; const putAccessor = (item: string, index: string) => compatibleListToArray ? compatibleArrayPutAccessor(compatibleReadAction!.elementTypeId, result, item, index) @@ -449,6 +457,7 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera : `${elemSerializer} = ${anyHelper}.detectSerializer(${readContextName});`; return ` const ${len} = ${this.builder.reader.readVarUint32Small7()}; + ${reserveMemory} let ${flags} = 0; if (${len} > 0) { ${flags} = ${this.builder.reader.readUint8()}; diff --git a/javascript/packages/core/lib/gen/ext.ts b/javascript/packages/core/lib/gen/ext.ts index 5d8af562a4..a45d3df4df 100644 --- a/javascript/packages/core/lib/gen/ext.ts +++ b/javascript/packages/core/lib/gen/ext.ts @@ -25,6 +25,9 @@ import { CodegenRegistry } from "./router"; import { BaseSerializerGenerator } from "./serializer"; import { TypeMeta } from "../meta/TypeMeta"; +const OBJECT_BYTES = 1; +const REFERENCE_BYTES = 4; + class ExtSerializerGenerator extends BaseSerializerGenerator { typeInfo: TypeInfo; typeMeta: TypeMeta; @@ -41,6 +44,10 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { this.ownTypeInfoExpr = `${this.serializerExpr}.getTypeInfo()`; } + private objectGraphBytes(): number { + return OBJECT_BYTES + Object.keys(this.typeInfo.options?.props ?? {}).length * REFERENCE_BYTES; + } + write(accessor: string): string { return ` ${this.builder.getOptions("customSerializer")}.write(${this.builder.getWriteContextName()}, ${accessor}) @@ -50,6 +57,7 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { read(accessor: (expr: string) => string, refState: string): string { const result = this.scope.uniqueName("result"); return ` + ${this.builder.getReadContextName()}.reserveGraphMemory(${this.objectGraphBytes()}); ${ this.typeInfo.options!.withConstructor ? 
` diff --git a/javascript/packages/core/lib/gen/map.ts b/javascript/packages/core/lib/gen/map.ts index db0e147a4d..824d59a37d 100644 --- a/javascript/packages/core/lib/gen/map.ts +++ b/javascript/packages/core/lib/gen/map.ts @@ -26,6 +26,9 @@ import { Scope } from "./scope"; import { AnyHelper } from "./any"; import { ReadContext, WriteContext } from "../context"; +const REFERENCE_BYTES = 4; +const MAP_BYTES = 1; + const MapFlags = { /** Whether track elements ref. */ TRACKING_REF: 0b1, @@ -272,6 +275,7 @@ class MapAnySerializer { read(fromRef: boolean): any { let count = this.readContext.reader.readVarUint32Small7(); + this.readContext.reserveGraphMemory(MAP_BYTES + count * 2 * REFERENCE_BYTES); const result = new Map(); if (fromRef) { this.readContext.reference(result); @@ -491,6 +495,7 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { return ` let ${count} = ${this.builder.reader.readVarUint32Small7()}; + ${readContextName}.reserveGraphMemory(${MAP_BYTES} + ${count} * 2 * ${REFERENCE_BYTES}); const ${result} = new Map(); if (${refState}) { ${this.builder.referenceResolver.reference(result)} diff --git a/javascript/packages/core/lib/gen/struct.ts b/javascript/packages/core/lib/gen/struct.ts index 69ecb5dc1b..d4824fd086 100644 --- a/javascript/packages/core/lib/gen/struct.ts +++ b/javascript/packages/core/lib/gen/struct.ts @@ -28,6 +28,9 @@ import { getCompatibleCollectionArrayReadAction } from "./collection"; import { CompatibleScalarConverter, getCompatibleScalarReadAction } from "../compatible/scalar"; import { shouldSkipCompatibleRead } from "../compatible/field"; +const OBJECT_BYTES = 1; +const REFERENCE_BYTES = 4; + /** * Returns true when a field's read cannot recurse and needs no depth tracking. * Covers leaf scalars, typed arrays, and collections/maps whose elements are all leaf types. 
@@ -561,6 +564,10 @@ class StructSerializerGenerator extends BaseSerializerGenerator { ); } + private objectGraphBytes(): number { + return OBJECT_BYTES + this.sortedProps.length * REFERENCE_BYTES; + } + readField( fieldName: string, fieldTypeInfo: TypeInfo, @@ -809,13 +816,28 @@ class StructSerializerGenerator extends BaseSerializerGenerator { read(accessor: (expr: string) => string, refState: string): string { const result = this.scope.uniqueName("result"); + const hash = this.typeMeta.computeStructHash(); if (!this.typeInfo.options?.props || Object.keys(this.typeInfo.options.props).length === 0) { return ` - let ${result} = ${this.serializerExpr}.read(${refState}); + ${ + !this.builder.resolver.isCompatible() + ? ` + if(${this.builder.reader.readInt32()} !== ${hash}) { + throw new Error("Read class version is not consistent with ${hash} ") + } + ` + : "" + } + ${this.builder.getReadContextName()}.reserveGraphMemory(${OBJECT_BYTES}); + ${ + this.typeInfo.options?.withConstructor + ? `const ${result} = new ${this.builder.getOptions("creator")}();` + : `const ${result} = {};` + } + ${this.maybeReference(result, refState)} ${accessor(result)}; `; } - const hash = this.typeMeta.computeStructHash(); const directNumericObjectRead = this.readDirectNumericObject(accessor, refState); if (directNumericObjectRead !== null) { return ` @@ -841,6 +863,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { ` : "" } + ${this.builder.getReadContextName()}.reserveGraphMemory(${this.objectGraphBytes()}); ${ this.typeInfo.options!.withConstructor ? 
` @@ -911,6 +934,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { } const result = this.scope.uniqueName("result"); return ` + ${this.builder.getReadContextName()}.reserveGraphMemory(${this.objectGraphBytes()}); const ${result} = { ${fields.map(({ key, expr }) => `${CodecBuilder.safePropName(key)}: ${expr}`).join(",\n")} }; @@ -1003,6 +1027,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { let ${value}; ${reads} ${this.builder.reader.readSetCursor(cursor)} + ${this.builder.getReadContextName()}.reserveGraphMemory(${this.objectGraphBytes()}); const ${result} = { ${fields.map(({ key, local }) => `${CodecBuilder.safePropName(key)}: ${local}`).join(",\n")} }; diff --git a/javascript/packages/core/lib/type.ts b/javascript/packages/core/lib/type.ts index ddbef54ec9..eb65815987 100644 --- a/javascript/packages/core/lib/type.ts +++ b/javascript/packages/core/lib/type.ts @@ -291,6 +291,7 @@ export interface Config { ref: boolean; useSliceString: boolean; maxDepth?: number; + maxGraphMemoryBytes: number; maxTypeFields: number; maxTypeMetaBytes: number; maxSchemaVersionsPerType: number; diff --git a/javascript/test/graphMemoryBudget.test.ts b/javascript/test/graphMemoryBudget.test.ts new file mode 100644 index 0000000000..8a0f633cca --- /dev/null +++ b/javascript/test/graphMemoryBudget.test.ts @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import Fory, { Type } from "../packages/core/index"; +import { describe, expect, test } from "@jest/globals"; + +const DEFAULT_GRAPH_MEMORY_BYTES = 128 * 1024 * 1024; +const OBJECT_BYTES = 1; +const REFERENCE_BYTES = 4; + +const objectBytes = (fields: number) => OBJECT_BYTES + fields * REFERENCE_BYTES; +const listBytes = (count: number) => OBJECT_BYTES + count * REFERENCE_BYTES; +const mapBytes = (count: number) => OBJECT_BYTES + count * 2 * REFERENCE_BYTES; + +function serializeAny(value: unknown) { + return new Fory({ compatible: false, ref: true }).serialize(value); +} + +function deserializeAny(bytes: Uint8Array, maxGraphMemoryBytes: number) { + return new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes, + }).deserialize(bytes); +} + +describe("graph memory budget", () => { + test("uses fixed default budget", () => { + const fory = new Fory({ compatible: false }); + + fory.readContext.reset(new Uint8Array(17)); + expect(() => fory.readContext.reserveGraphMemory(DEFAULT_GRAPH_MEMORY_BYTES)).not.toThrow(); + expect(() => fory.readContext.reserveGraphMemory(1)).toThrow(/maxGraphMemoryBytes/); + }); + + test("handles explicit config and validation", () => { + const fory = new Fory({ maxGraphMemoryBytes: 24 }); + fory.readContext.reset(new Uint8Array(1)); + expect(() => fory.readContext.reserveGraphMemory(0)).not.toThrow(); + expect(() => fory.readContext.reserveGraphMemory(24)).not.toThrow(); + expect(() => fory.readContext.reserveGraphMemory(1)).toThrow(/maxGraphMemoryBytes/); + + expect(() => new Fory({ maxGraphMemoryBytes: 
0 })).toThrow(/maxGraphMemoryBytes/); + expect(() => new Fory({ maxGraphMemoryBytes: -2 })).toThrow(/maxGraphMemoryBytes/); + }); + + test("uses parent storage for nested empty containers", () => { + const typeInfo = Type.struct("budget.nested.empty", { + values: Type.list(Type.list(Type.int32({ encoding: "fixed" }))).setId(1), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ values: [[]] }); + const passingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: objectBytes(1) + listBytes(1) + listBytes(0), + }).register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: objectBytes(1) + listBytes(1) + listBytes(0) - 1, + }).register(typeInfo); + + expect(() => failingReader.deserialize(bytes)).toThrow(/maxGraphMemoryBytes/); + expect(passingReader.deserialize(bytes)).toEqual({ values: [[]] }); + }); + + test("reserves sibling containers cumulatively", () => { + const typeInfo = Type.struct("budget.sibling.empty", { + values: Type.list(Type.list(Type.int32({ encoding: "fixed" }))).setId(1), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + values: [[], [], []], + }); + const passingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: objectBytes(1) + listBytes(3) + 3 * listBytes(0), + }).register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: objectBytes(1) + listBytes(3) + 3 * listBytes(0) - 1, + }).register(typeInfo); + + expect(() => failingReader.deserialize(bytes)).toThrow(/maxGraphMemoryBytes/); + expect(passingReader.deserialize(bytes)).toEqual({ + values: [[], [], []], + }); + }); + + test("reserves empty generated object owner", () => { + const childType = Type.struct("budget.empty.object", {}); + const typeInfo = Type.struct("budget.empty.parent", { + first: 
Type.struct("budget.empty.object").setId(1), + second: Type.struct("budget.empty.object").setId(2), + }); + const writer = new Fory({ compatible: false, ref: true }); + writer.register(childType); + const bytes = writer.register(typeInfo).serialize({ + first: {}, + second: {}, + }); + const required = objectBytes(2) + 2 * OBJECT_BYTES; + const passingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: required, + }); + passingReader.register(childType); + passingReader.register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: required - 1, + }); + failingReader.register(childType); + failingReader.register(typeInfo); + + expect(() => failingReader.deserialize(bytes)).toThrow(/maxGraphMemoryBytes/); + expect(passingReader.deserialize(bytes)).toEqual({ + first: {}, + second: {}, + }); + }); + + test("preserves constructor for empty generated object owner", () => { + class EmptyChild {} + Type.struct("budget.empty.ctor.child", {})(EmptyChild); + class EmptyParent { + child = new EmptyChild(); + } + Type.struct("budget.empty.ctor.parent", { + child: Type.struct("budget.empty.ctor.child").setId(1), + })(EmptyParent); + + const writer = new Fory({ compatible: false, ref: true }); + writer.register(EmptyChild); + const bytes = writer.register(EmptyParent).serialize(new EmptyParent()); + const required = objectBytes(1) + OBJECT_BYTES; + const passingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: required, + }); + passingReader.register(EmptyChild); + passingReader.register(EmptyParent); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: required - 1, + }); + failingReader.register(EmptyChild); + failingReader.register(EmptyParent); + + expect(() => failingReader.deserialize(bytes)).toThrow(/maxGraphMemoryBytes/); + const decoded = passingReader.deserialize(bytes); + expect(decoded).toBeInstanceOf(EmptyParent); + 
expect(decoded.child).toBeInstanceOf(EmptyChild); + }); + + test("reserves map entries", () => { + const bytes = serializeAny(new Map([[1, 2]])); + + expect(() => deserializeAny(bytes, mapBytes(1) - 1)).toThrow(/maxGraphMemoryBytes/); + expect(deserializeAny(bytes, mapBytes(1))).toEqual(new Map([[1, 2]])); + }); + + test("reserves generated containers", () => { + const typeInfo = Type.struct("budget.generated", { + list: Type.list(Type.int32({ encoding: "fixed" })).setId(1), + set: Type.set(Type.string()).setId(2), + map: Type.map(Type.string(), Type.int32({ encoding: "fixed" })).setId(3), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + list: [1], + set: new Set(["a"]), + map: new Map([["k", 1]]), + }); + const passingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: objectBytes(3) + listBytes(1) + listBytes(1) + mapBytes(1), + }).register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: objectBytes(3) + listBytes(1) + listBytes(1) + mapBytes(1) - 1, + }).register(typeInfo); + + expect(() => failingReader.deserialize(bytes)).toThrow(/maxGraphMemoryBytes/); + expect(passingReader.deserialize(bytes)).toEqual({ + list: [1], + set: new Set(["a"]), + map: new Map([["k", 1]]), + }); + }); + + test("skips compatible list to typed array leaf", () => { + const writerType = Type.struct(9010, { + values: Type.list(Type.int32({ encoding: "fixed" })).setId(1), + }); + const readerType = Type.struct(9010, { + values: Type.int32Array().setId(1), + }); + const writer = new Fory({ compatible: true }); + const bytes = writer.register(writerType).serialize({ values: [1, 2, 3] }); + const passingReader = new Fory({ + compatible: true, + maxGraphMemoryBytes: objectBytes(1), + }).register(readerType); + const failingReader = new Fory({ + compatible: true, + maxGraphMemoryBytes: objectBytes(1) - 1, + }).register(readerType); + + 
expect(() => failingReader.deserialize(bytes)).toThrow(/maxGraphMemoryBytes/); + expect(Array.from(passingReader.deserialize(bytes).values)).toEqual([1, 2, 3]); + }); + + test("skips scalar dense owners", () => { + const typeInfo = Type.struct("budget.skipped", { + text: Type.string().setId(1), + binary: Type.binary().setId(2), + values: Type.int32Array().setId(3), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + text: "hello", + binary: new Uint8Array([1, 2, 3]), + values: new Int32Array([1, 2, 3]), + }); + const reader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: objectBytes(3), + }).register(typeInfo); + + expect(reader.deserialize(bytes)).toEqual({ + text: "hello", + binary: new Uint8Array([1, 2, 3]), + values: new Int32Array([1, 2, 3]), + }); + }); + + test("keeps byte checks", () => { + const typeInfo = Type.struct("budget.bytecheck", { + values: Type.int32Array().setId(1), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + values: new Int32Array([1, 2, 3]), + }); + const reader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: 1024 * 1024, + }).register(typeInfo); + + expect(() => reader.deserialize(bytes.slice(0, bytes.length - 1))).toThrow(); + }); +}); diff --git a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt index 827e4bec27..f582200185 100644 --- a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt +++ b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt @@ -249,11 +249,32 @@ internal class ForyKotlinSymbolProcessor(private val environment: SymbolProcesso KotlinSerializerVisibility.PUBLIC }, construction = parsed.construction, + 
graphMemoryBytes = graphMemoryBytes(fields), fields = fields, originatingFiles = listOfNotNull(declaration.containingFile), ) } + private fun graphMemoryBytes(fields: List): Int { + var bytes = 1 + for (field in fields) { + bytes = Math.addExact(bytes, fieldGraphMemoryBytes(field.type)) + } + return bytes + } + + private fun fieldGraphMemoryBytes(type: KotlinSourceTypeNode): Int = + when (type.typeName) { + "boolean", + "byte" -> 1 + "short" -> 2 + "int", + "float" -> 4 + "long", + "double" -> 8 + else -> 4 + } + private fun parseStructFields( declaration: KSClassDeclaration, primaryConstructor: KSFunctionDeclaration, diff --git a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt index 7517ca54ac..08c1ff0799 100644 --- a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt +++ b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt @@ -386,11 +386,20 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru writeConstructorRead() } + private fun appendGraphMemoryReserve(indent: String) { + builder + .append(indent) + .append("readContext.reserveGraphMemory(") + .append(struct.graphMemoryBytes) + .append(")\n") + } + private fun writeConstructorRead() { builder .append(" private fun readSchemaConstructor(readContext: ReadContext): ") .append(struct.typeName) .append(" {\n") + appendGraphMemoryReserve(" ") builder.append(" val fieldValues = arrayOfNulls(DESCRIPTORS.size)\n") builder.append(" val bufferedFields = newFieldBits(DESCRIPTORS.size)\n") builder.append(" beginConstructorRef(readContext)\n") @@ -654,6 +663,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru } private fun writeMutableReadBody() { + appendGraphMemoryReserve(" ") builder.append(" val value = 
").append(struct.typeName).append("()\n") builder.append(" if (readContext.hasPreservedRefId()) {\n") builder.append(" readContext.reference(value)\n") @@ -700,6 +710,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru builder.append(" }\n\n") return } + appendGraphMemoryReserve(" ") writeCompatibleValueReadBody(" ", constructorRefs = false) builder.append(" }\n\n") } @@ -709,6 +720,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru .append(" private fun readCompatibleConstructor(readContext: ReadContext): ") .append(struct.typeName) .append(" {\n") + appendGraphMemoryReserve(" ") builder.append(" beginConstructorRef(readContext)\n") builder.append(" try {\n") writeCompatibleValueReadBody(" ", constructorRefs = true) @@ -829,6 +841,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru private fun writeMutableCompatibleReadBody() { writePresenceVars() + appendGraphMemoryReserve(" ") builder.append(" val value = ").append(struct.typeName).append("()\n") builder.append(" if (readContext.hasPreservedRefId()) {\n") builder.append(" readContext.reference(value)\n") diff --git a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/Model.kt b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/Model.kt index a50b0ca62f..547f4e7dc6 100644 --- a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/Model.kt +++ b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/Model.kt @@ -28,6 +28,7 @@ internal data class KotlinSourceStruct( val construction: KotlinStructConstruction = KotlinStructConstruction.CONSTRUCTOR, val fields: List, val originatingFiles: List, + val graphMemoryBytes: Int = 1, ) { val qualifiedSerializerName: String = if (packageName.isEmpty()) serializerName else "$packageName.$serializerName" diff --git a/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt 
b/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt index e3b36d4785..9b6498e21a 100644 --- a/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt +++ b/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt @@ -21,7 +21,6 @@ package org.apache.fory.serializer.kotlin import org.apache.fory.context.ReadContext import org.apache.fory.context.WriteContext -import org.apache.fory.exception.DeserializationException import org.apache.fory.resolver.TypeResolver import org.apache.fory.serializer.collection.CollectionLikeSerializer @@ -57,15 +56,8 @@ public class KotlinArrayDequeSerializer( } override fun newCollection(readContext: ReadContext): Collection { - val buffer = readContext.buffer - val numElements = buffer.readVarUInt32Small7() - if (numElements < 0) { - throw DeserializationException("Collection size must be non-negative: $numElements") - } + val numElements = readCollectionSize(readContext, readContext.buffer) setNumElements(numElements) - if (numElements != 0) { - buffer.checkReadableBytes(numElements) - } return ArrayDequeBuilder(ArrayDeque(numElements)) } } diff --git a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt index 5684821ebe..c66dc6582a 100644 --- a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt +++ b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt @@ -20,8 +20,10 @@ package org.apache.fory.serializer.kotlin import org.apache.fory.Fory +import org.apache.fory.exception.InsecureException import org.apache.fory.kotlin.ForyKotlin import org.testng.Assert.assertEquals +import org.testng.Assert.fail import org.testng.annotations.Test class CollectionSerializerTest { @@ -33,6 +35,22 @@ class 
CollectionSerializerTest { assertEquals(arrayDeque, fory.deserialize(fory.serialize(arrayDeque))) } + @Test + fun testArrayDequeGraphMemoryBudget() { + val writer: Fory = ForyKotlin.builder().withXlang(false).requireClassRegistration(true).build() + val reader: Fory = + ForyKotlin.builder() + .withXlang(false) + .requireClassRegistration(true) + .withMaxGraphMemoryBytes(23) + .build() + + try { + reader.deserialize(writer.serialize(ArrayDeque(listOf(1, 2, 3, 4, 5, 6)))) + fail("Expected graph memory budget failure") + } catch (ignored: InsecureException) {} + } + @Test fun testSerializeArrayList() { val fory: Fory = ForyKotlin.builder().withXlang(false).requireClassRegistration(true).build() diff --git a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/GenericDataClassTest.kt b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/GenericDataClassTest.kt index 4172768b3e..47313bc953 100644 --- a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/GenericDataClassTest.kt +++ b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/GenericDataClassTest.kt @@ -113,7 +113,11 @@ class GenericDataClassTest { assertEquals(typeInfo.decodeTypeName(), "Change") assertThrows(IllegalArgumentException::class.java) { KotlinSerializers.registerType( - Fory.builder().withXlang(false).requireClassRegistration(true).withCompatible(false).build(), + Fory.builder() + .withXlang(false) + .requireClassRegistration(true) + .withCompatible(false) + .build(), Change::class.java, "kotlin", "Bad.Name", diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index e4819ba424..9f5615565f 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -124,6 +124,7 @@ class Fory: "strict", "buffer", "max_depth", + "max_graph_memory_bytes", "field_nullable", "policy", ) @@ -139,6 +140,7 @@ def __init__( max_type_meta_bytes: int = 4096, max_schema_versions_per_type: int = 10, max_average_schema_versions_per_type: int = 3, 
+ max_graph_memory_bytes: int = 128 * 1024 * 1024, policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, @@ -183,6 +185,9 @@ def __init__( max_average_schema_versions_per_type: Average remote metadata versions allowed across accepted remote types. + max_graph_memory_bytes: Maximum estimated graph memory per root + deserialization. Defaults to 128 MiB and must be a positive byte limit. + policy: Custom deserialization policy for security checks. When provided, it controls which types can be deserialized, overriding the default policy. **Strongly recommended** when strict=False to maintain security controls. @@ -213,6 +218,9 @@ def __init__( raise ValueError("max_schema_versions_per_type must be a positive integer") if not isinstance(max_average_schema_versions_per_type, int) or max_average_schema_versions_per_type <= 0: raise ValueError("max_average_schema_versions_per_type must be a positive integer") + if not isinstance(max_graph_memory_bytes, int) or max_graph_memory_bytes <= 0 or max_graph_memory_bytes > (1 << 63) - 1: + raise ValueError("max_graph_memory_bytes must be a positive 63-bit integer") + self.max_graph_memory_bytes = max_graph_memory_bytes self.config = Config( xlang=xlang, track_ref=ref, @@ -225,6 +233,7 @@ def __init__( max_type_meta_bytes=max_type_meta_bytes, max_schema_versions_per_type=max_schema_versions_per_type, max_average_schema_versions_per_type=max_average_schema_versions_per_type, + max_graph_memory_bytes=max_graph_memory_bytes, field_nullable=field_nullable, policy=self.policy, meta_compressor=meta_compressor, diff --git a/python/pyfory/buffer.pxi b/python/pyfory/buffer.pxi index 4fa77e9e68..53ff441f25 100644 --- a/python/pyfory/buffer.pxi +++ b/python/pyfory/buffer.pxi @@ -335,6 +335,15 @@ cdef class Buffer: f"Address range {offset, offset + length} out of bound {0, size_}", ) + cpdef inline ensure_readable(self, int32_t length): + if length < 0: + raise_fory_error(CErrorCode.InvalidData, f"Readable 
byte count {length} is negative") + if length == 0: + return + if not self.c_buffer.ensure_readable(length, self._error): + if not self._error.ok(): + self._raise_if_error() + cpdef inline write_bool(self, c_bool value): self.c_buffer.write_uint8(value) diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi index 0183b26231..62ef71b6dc 100644 --- a/python/pyfory/collection.pxi +++ b/python/pyfory/collection.pxi @@ -41,6 +41,8 @@ cdef int8_t NULL_KEY_VALUE_DECL_TYPE = KEY_HAS_NULL | VALUE_DECL_TYPE cdef int8_t NULL_KEY_VALUE_DECL_TYPE_TRACKING_REF = KEY_HAS_NULL | VALUE_DECL_TYPE | TRACKING_VALUE_REF cdef int8_t NULL_VALUE_KEY_DECL_TYPE = VALUE_HAS_NULL | KEY_DECL_TYPE cdef int8_t NULL_VALUE_KEY_DECL_TYPE_TRACKING_REF = VALUE_HAS_NULL | KEY_DECL_TYPE | TRACKING_KEY_REF +cdef int64_t _REFERENCE_BYTES = sizeof(PyObject*) +cdef int64_t _OWNER_BYTES = 1 ctypedef PyObject *PyObjectPtr cdef class ListSerializer @@ -466,7 +468,11 @@ cdef class ListSerializer(CollectionSerializer): cdef int8_t head_flag cdef int32_t ref_id cdef int64_t i - + cdef int64_t graph_bytes + if len_ < 0: + raise ValueError("Container element count is negative") + graph_bytes = _OWNER_BYTES + len_ * _REFERENCE_BYTES + read_context.reserve_graph_memory_c(graph_bytes) if len_ == 0: list_ = PyList_New(0) return list_ @@ -583,7 +589,11 @@ cdef class TupleSerializer(CollectionSerializer): cdef bint has_null cdef int8_t head_flag cdef int64_t i - + cdef int64_t graph_bytes + if len_ < 0: + raise ValueError("Container element count is negative") + graph_bytes = _OWNER_BYTES + len_ * _REFERENCE_BYTES + read_context.reserve_graph_memory_c(graph_bytes) if len_ == 0: tuple_ = PyTuple_New(0) return tuple_ @@ -684,7 +694,7 @@ cdef class StringArraySerializer(ListSerializer): @cython.final cdef class SetSerializer(CollectionSerializer): cpdef read(self, ReadContext read_context): - cdef set instance = set() + cdef set instance cdef int32_t len_ cdef int8_t collect_flag cdef TypeInfo typeinfo @@ 
-701,11 +711,20 @@ cdef class SetSerializer(CollectionSerializer): cdef int8_t head_flag cdef int32_t ref_id cdef int64_t i + cdef int64_t graph_bytes - read_context.reference(instance) len_ = buffer.read_var_uint32() + if len_ < 0: + raise ValueError("Container element count is negative") + graph_bytes = _OWNER_BYTES + len_ * _REFERENCE_BYTES + read_context.reserve_graph_memory_c(graph_bytes) if len_ == 0: + instance = set() + read_context.reference(instance) return instance + read_context.check_readable_bytes(len_) + instance = set() + read_context.reference(instance) collect_flag = buffer.read_int8() if (collect_flag & COLL_IS_SAME_TYPE) != 0: @@ -1048,6 +1067,11 @@ cdef class MapSerializer(Serializer): cdef int32_t ref_id cdef dict map_ cdef int8_t chunk_header = 0 + cdef int64_t graph_bytes + if size < 0: + raise ValueError("Map entry count is negative") + graph_bytes = _OWNER_BYTES + size * (2 * _REFERENCE_BYTES) + read_context.reserve_graph_memory_c(graph_bytes) if size == 0: map_ = {} else: diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py index d78673a6dc..b0021d74d8 100644 --- a/python/pyfory/collection.py +++ b/python/pyfory/collection.py @@ -23,6 +23,8 @@ fallback only. 
""" +import struct + from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION from pyfory._serializer import Serializer, StringSerializer from pyfory.resolver import NOT_NULL_VALUE_FLAG, NULL_FLAG @@ -33,6 +35,8 @@ COLL_HAS_NULL = 0b10 COLL_IS_DECL_ELEMENT_TYPE = 0b100 COLL_IS_SAME_TYPE = 0b1000 +_REFERENCE_BYTES = struct.calcsize("P") +_OWNER_BYTES = 1 def _needs_element_type_info(type_id): @@ -176,6 +180,9 @@ def _write_different_types(self, write_context, value, collect_flag=0): def read(self, read_context): length = read_context.read_var_uint32() + read_context.reserve_graph_memory(_OWNER_BYTES + length * _REFERENCE_BYTES) + if length != 0: + read_context.check_readable_bytes(length) collection_ = self.new_instance(read_context, self.type_) if length == 0: return collection_ @@ -455,6 +462,9 @@ def write(self, write_context, obj): def read(self, read_context): size = read_context.read_var_uint32() + read_context.reserve_graph_memory(_OWNER_BYTES + size * 2 * _REFERENCE_BYTES) + if size != 0: + read_context.check_readable_bytes(size) map_ = {} ref_reader = read_context.ref_reader read_context.reference(map_) diff --git a/python/pyfory/context.pxi b/python/pyfory/context.pxi index 702f09769c..72e6c50077 100644 --- a/python/pyfory/context.pxi +++ b/python/pyfory/context.pxi @@ -30,6 +30,7 @@ STRING_TYPE_ID = TypeId.STRING SMALL_STRING_THRESHOLD = 16 cdef int32_t MAX_CACHED_META_STRINGS = 8192 cdef int32_t MAX_CACHED_META_STRING_LENGTH = 2048 +cdef int64_t _MAX_GRAPH_MEMORY_BYTES = 9223372036854775807 cdef inline uint64_t _mix64(uint64_t x): @@ -746,6 +747,8 @@ cdef class ReadContext: cdef readonly bint field_nullable cdef readonly object policy cdef readonly int32_t max_depth + cdef int64_t max_graph_memory_bytes + cdef int64_t remaining_graph_memory_bytes cdef readonly RefReader ref_reader cdef readonly MetaStringReader meta_string_reader cdef readonly MetaShareReadContext meta_share_context @@ -766,6 +769,8 @@ cdef class ReadContext: self.field_nullable 
= config.field_nullable self.policy = config.policy self.max_depth = config.max_depth + self.max_graph_memory_bytes = config.max_graph_memory_bytes + self.remaining_graph_memory_bytes = 0 self.ref_reader = RefReader(self.track_ref) self.meta_string_reader = MetaStringReader(self.type_resolver.shared_registry) self.meta_share_context = MetaShareReadContext() if config.scoped_meta_share_enabled else None @@ -789,6 +794,7 @@ cdef class ReadContext: self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled + self.remaining_graph_memory_bytes = self.max_graph_memory_bytes self.depth = 0 cpdef inline reset(self): @@ -803,8 +809,31 @@ cdef class ReadContext: self.buffers = None self.unsupported_objects = None self.peer_out_of_band_enabled = False + self.remaining_graph_memory_bytes = 0 self.depth = 0 + cdef inline void reserve_graph_memory_c(self, int64_t num_bytes): + cdef int64_t used + if num_bytes < 0: + raise ValueError("Estimated graph memory is negative") + if num_bytes > _MAX_GRAPH_MEMORY_BYTES: + raise ValueError("Estimated graph memory overflow") + if num_bytes > self.remaining_graph_memory_bytes: + used = self.max_graph_memory_bytes - self.remaining_graph_memory_bytes + raise ValueError( + f"Estimated graph memory budget exceeded: requested {num_bytes} bytes, " + f"used {used} bytes, limit {self.max_graph_memory_bytes} bytes. " + "Increase Fory(..., max_graph_memory_bytes=...) for trusted larger payloads." 
+ ) + self.remaining_graph_memory_bytes -= num_bytes + + cpdef inline reserve_graph_memory(self, num_bytes): + if num_bytes < 0: + raise ValueError("Estimated graph memory is negative") + if num_bytes > _MAX_GRAPH_MEMORY_BYTES: + raise ValueError("Estimated graph memory overflow") + self.reserve_graph_memory_c(num_bytes) + cpdef inline add_context_object(self, key, obj): self.context_objects[id(key)] = obj diff --git a/python/pyfory/context.py b/python/pyfory/context.py index 3abfb46e3d..b4dbc0889e 100644 --- a/python/pyfory/context.py +++ b/python/pyfory/context.py @@ -37,6 +37,7 @@ FLOAT64_TYPE_ID = TypeId.FLOAT64 BOOL_TYPE_ID = TypeId.BOOL STRING_TYPE_ID = TypeId.STRING +_MAX_GRAPH_MEMORY_BYTES = (1 << 63) - 1 def _mix64(x: int) -> int: @@ -470,6 +471,8 @@ class ReadContext: "field_nullable", "policy", "max_depth", + "_max_graph_memory_bytes", + "_remaining_graph_memory_bytes", "ref_reader", "meta_string_reader", "meta_share_context", @@ -490,6 +493,8 @@ def __init__(self, config: Config, type_resolver): self.field_nullable = config.field_nullable self.policy = config.policy self.max_depth = config.max_depth + self._max_graph_memory_bytes = config.max_graph_memory_bytes + self._remaining_graph_memory_bytes = 0 self.ref_reader = MapRefReader() if self.track_ref else NoRefReader() self.meta_string_reader = MetaStringReader(type_resolver.shared_registry) self.meta_share_context = MetaShareReadContext() if config.scoped_meta_share_enabled else None @@ -511,8 +516,7 @@ def check_readable_bytes(self, length): raise ValueError(f"Readable byte count {length} is negative") if length == 0: return - reader_index = self.buffer.get_reader_index() - self.buffer.check_bound(reader_index, length) + self.buffer.ensure_readable(length) def prepare( self, @@ -525,6 +529,7 @@ def prepare( self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = 
peer_out_of_band_enabled + self._remaining_graph_memory_bytes = self._max_graph_memory_bytes self.depth = 0 def reset(self): @@ -538,8 +543,24 @@ def reset(self): self.buffers = None self.unsupported_objects = None self.peer_out_of_band_enabled = False + self._remaining_graph_memory_bytes = 0 self.depth = 0 + def reserve_graph_memory(self, num_bytes): + if num_bytes < 0: + raise ValueError("Estimated graph memory is negative") + if num_bytes > _MAX_GRAPH_MEMORY_BYTES: + raise ValueError("Estimated graph memory overflow") + remaining = self._remaining_graph_memory_bytes + if num_bytes > remaining: + used = self._max_graph_memory_bytes - remaining + raise ValueError( + f"Estimated graph memory budget exceeded: requested {num_bytes} bytes, " + f"used {used} bytes, limit {self._max_graph_memory_bytes} bytes. " + "Increase Fory(..., max_graph_memory_bytes=...) for trusted larger payloads." + ) + self._remaining_graph_memory_bytes = remaining - num_bytes + def add_context_object(self, key, obj): self.context_objects[id(key)] = obj diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 899adcaf3c..6f00480793 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -113,6 +113,8 @@ cdef class Config: max_type_meta_bytes: Maximum accepted body size in one received TypeDef. max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. + max_graph_memory_bytes: Maximum estimated graph memory per root + deserialization. Defaults to 128 MiB and must be a positive byte limit. field_nullable: Treats struct/dataclass fields as nullable by default. policy: Deserialization policy used for security-sensitive checks. meta_compressor: Optional typedef/meta compressor implementation. 
@@ -129,6 +131,7 @@ cdef class Config: cdef public int32_t max_type_meta_bytes cdef public int32_t max_schema_versions_per_type cdef public int32_t max_average_schema_versions_per_type + cdef public int64_t max_graph_memory_bytes cdef public bint field_nullable cdef public object policy cdef public object meta_compressor @@ -147,6 +150,7 @@ cdef class Config: max_type_meta_bytes, max_schema_versions_per_type, max_average_schema_versions_per_type, + max_graph_memory_bytes, field_nullable, policy, meta_compressor, @@ -166,6 +170,8 @@ cdef class Config: max_type_meta_bytes: Maximum accepted body size in one received TypeDef. max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. + max_graph_memory_bytes: Maximum estimated graph memory per root + deserialization. Defaults to 128 MiB and must be a positive byte limit. field_nullable: Treat all struct fields as nullable by default. policy: Deserialization policy implementation. meta_compressor: Optional typedef/meta compressor. 
@@ -185,10 +191,17 @@ cdef class Config: raise ValueError("max_schema_versions_per_type must be a positive integer") if max_average_schema_versions_per_type <= 0: raise ValueError("max_average_schema_versions_per_type must be a positive integer") + if ( + not isinstance(max_graph_memory_bytes, int) + or max_graph_memory_bytes <= 0 + or max_graph_memory_bytes > 9223372036854775807 + ): + raise ValueError("max_graph_memory_bytes must be a positive 63-bit integer") self.max_type_fields = max_type_fields self.max_type_meta_bytes = max_type_meta_bytes self.max_schema_versions_per_type = max_schema_versions_per_type self.max_average_schema_versions_per_type = max_average_schema_versions_per_type + self.max_graph_memory_bytes = max_graph_memory_bytes self.field_nullable = field_nullable self.policy = policy self.meta_compressor = meta_compressor @@ -829,6 +842,7 @@ cdef class Fory: cdef public bint compatible cdef public bint field_nullable cdef public int32_t max_depth + cdef public int64_t max_graph_memory_bytes cdef public object policy cdef public Config config cdef public TypeResolver type_resolver @@ -847,6 +861,7 @@ cdef class Fory: max_type_meta_bytes=4096, max_schema_versions_per_type=10, max_average_schema_versions_per_type=3, + max_graph_memory_bytes=128 * 1024 * 1024, policy=None, field_nullable=False, meta_compressor=None, @@ -865,6 +880,8 @@ cdef class Fory: max_type_meta_bytes: Maximum accepted body size in one received TypeDef. max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. + max_graph_memory_bytes: Maximum estimated graph memory per root + deserialization. Defaults to 128 MiB and must be a positive byte limit. policy: Optional deserialization policy implementation. field_nullable: Treat struct fields as nullable by default. meta_compressor: Optional typedef/meta compressor implementation. 
@@ -882,6 +899,13 @@ cdef class Fory: self.compatible = compatible self.field_nullable = field_nullable self.max_depth = max_depth + if ( + not isinstance(max_graph_memory_bytes, int) + or max_graph_memory_bytes <= 0 + or max_graph_memory_bytes > 9223372036854775807 + ): + raise ValueError("max_graph_memory_bytes must be a positive 63-bit integer") + self.max_graph_memory_bytes = max_graph_memory_bytes self.config = Config( xlang=xlang, track_ref=ref, @@ -894,6 +918,7 @@ cdef class Fory: max_type_meta_bytes=max_type_meta_bytes, max_schema_versions_per_type=max_schema_versions_per_type, max_average_schema_versions_per_type=max_average_schema_versions_per_type, + max_graph_memory_bytes=max_graph_memory_bytes, field_nullable=field_nullable, policy=self.policy, meta_compressor=meta_compressor, @@ -1075,6 +1100,7 @@ cdef class Fory: iter(unsupported_objects) if unsupported_objects is not None else None ) read_context.peer_out_of_band_enabled = peer_out_of_band_enabled + read_context.remaining_graph_memory_bytes = self.max_graph_memory_bytes read_context.depth = 0 return read_context.read_ref() diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index d3e43de30f..5145b1e335 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -24,6 +24,7 @@ import marshal import os import pickle +import struct import types from typing import Tuple @@ -42,6 +43,8 @@ ) _WINDOWS = os.name == "nt" +_REFERENCE_BYTES = struct.calcsize("P") +_OWNER_BYTES = 1 from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION from pyfory.types import TypeId @@ -933,6 +936,7 @@ def read(self, read_context): if dtype.kind == "O": length = read_context.read_varint32() _check_non_negative_size(length, "ndarray object") + read_context.reserve_graph_memory(_OWNER_BYTES + length * _REFERENCE_BYTES) read_context.check_readable_bytes(length) items = [read_context.read_ref() for _ in range(length)] return np.array(items, dtype=object) @@ -1713,10 +1717,11 @@ def 
write(self, write_context, value): def read(self, read_context): policy = read_context.policy policy.authorize_instantiation(self.type_) - obj = self.type_.__new__(self.type_) - read_context.reference(obj) num_fields = read_context.read_var_uint32() _check_non_negative_size(num_fields, "object field") + read_context.reserve_graph_memory(_OWNER_BYTES + num_fields * _REFERENCE_BYTES) + obj = self.type_.__new__(self.type_) + read_context.reference(obj) state = {} for _ in range(num_fields): field_name = read_context.read_string() @@ -1730,10 +1735,11 @@ def read(self, read_context): class _DefaultPolicyObjectSerializer(ObjectSerializer): def read(self, read_context): - obj = self.type_.__new__(self.type_) - read_context.reference(obj) num_fields = read_context.read_var_uint32() _check_non_negative_size(num_fields, "object field") + read_context.reserve_graph_memory(_OWNER_BYTES + num_fields * _REFERENCE_BYTES) + obj = self.type_.__new__(self.type_) + read_context.reference(obj) for _ in range(num_fields): field_name = read_context.read_string() field_value = read_context.read_ref() diff --git a/python/pyfory/struct.pxi b/python/pyfory/struct.pxi index 96b53d0fa9..7010a5ee8c 100644 --- a/python/pyfory/struct.pxi +++ b/python/pyfory/struct.pxi @@ -22,6 +22,8 @@ from cpython.unicode cimport PyUnicode_InternFromString cdef uint8_t _BASIC_FIELD_NOT_INLINE = 0xFF +cdef int64_t _STRUCT_OWNER_BYTES = 1 +cdef int64_t _STRUCT_REFERENCE_BYTES = sizeof(PyObject*) cdef struct FieldRuntimeInfo: @@ -422,6 +424,9 @@ cdef class DataClassSerializer(Serializer): f"Hash {read_hash} is not consistent with {self._hash} for type {self.type_}" ) + read_context.reserve_graph_memory_c( + _STRUCT_OWNER_BYTES + self._field_runtime_infos.size() * _STRUCT_REFERENCE_BYTES + ) obj = self.type_.__new__(self.type_) read_context.reference(obj) if self._has_slots: diff --git a/python/pyfory/struct.py b/python/pyfory/struct.py index a8daf46687..478cda5b43 100644 --- a/python/pyfory/struct.py +++ 
b/python/pyfory/struct.py @@ -24,6 +24,7 @@ import inspect import logging import os +import struct import sys import typing from typing import List, Dict @@ -80,6 +81,9 @@ logger = logging.getLogger(__name__) +_REFERENCE_BYTES = struct.calcsize("P") +_OWNER_BYTES = 1 + _MISSING_DEFAULT_INT_TYPES = { int, TypeId.INT8, @@ -651,6 +655,7 @@ def read(self, read_context): raise TypeNotCompatibleError( f"Hash {hash_} is not consistent with {self._hash} for type {self.type_}", ) + read_context.reserve_graph_memory(_OWNER_BYTES + len(self._field_names) * _REFERENCE_BYTES) obj = self.type_.__new__(self.type_) read_context.reference(obj) obj_dict = obj.__dict__ if not self._has_slots else None diff --git a/python/pyfory/tests/test_graph_memory_budget.py b/python/pyfory/tests/test_graph_memory_budget.py new file mode 100644 index 0000000000..6f95acb7c0 --- /dev/null +++ b/python/pyfory/tests/test_graph_memory_budget.py @@ -0,0 +1,253 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import array +import dataclasses +import struct + +import pytest + +import pyfory +from pyfory.serialization import Buffer +from pyfory.serializer import ListSerializer + +try: + import numpy as np +except ImportError: + np = None + + +DEFAULT_GRAPH_MEMORY_BYTES = 128 * 1024 * 1024 +REFERENCE_BYTES = struct.calcsize("P") +OWNER_BYTES = 1 +MAX_GRAPH_MEMORY_BYTES = (1 << 63) - 1 + + +class OneByteStream: + def __init__(self, data: bytes): + self._data = data + self._offset = 0 + + def read(self, size=-1): + if self._offset >= len(self._data): + return b"" + if size < 0: + size = len(self._data) - self._offset + if size == 0: + return b"" + read_size = min(1, size, len(self._data) - self._offset) + start = self._offset + self._offset += read_size + return self._data[start : start + read_size] + + def readinto(self, buffer): + if self._offset >= len(self._data): + return 0 + view = memoryview(buffer).cast("B") + if len(view) == 0: + return 0 + read_size = min(1, len(view), len(self._data) - self._offset) + start = self._offset + self._offset += read_size + view[:read_size] = self._data[start : start + read_size] + return read_size + + def recv_into(self, buffer, size=-1): + if self._offset >= len(self._data): + return 0 + view = memoryview(buffer).cast("B") + if size < 0 or size > len(view): + size = len(view) + if size == 0: + return 0 + read_size = min(1, size, len(self._data) - self._offset) + start = self._offset + self._offset += read_size + view[:read_size] = self._data[start : start + read_size] + return read_size + + +@dataclasses.dataclass +class BudgetItem: + value: int + + +class BudgetObject: + pass + + +def collection_memory(num_elements): + return OWNER_BYTES + num_elements * REFERENCE_BYTES + + +def map_memory(num_entries): + return OWNER_BYTES + num_entries * 2 * REFERENCE_BYTES + + +def object_memory(num_fields): + return OWNER_BYTES + num_fields * REFERENCE_BYTES + + +def new_fory(limit=DEFAULT_GRAPH_MEMORY_BYTES, *, xlang=True): + return 
pyfory.Fory( + xlang=xlang, + ref=True, + strict=False, + compatible=xlang, + max_graph_memory_bytes=limit, + ) + + +def expect_budget(value, budget, *, xlang=True): + writer = new_fory(xlang=xlang) + data = writer.serialize(value) + with pytest.raises(ValueError, match="Estimated graph memory budget exceeded"): + new_fory(budget - 1, xlang=xlang).deserialize(data) + return new_fory(budget, xlang=xlang).deserialize(data) + + +def varuint_payload(value): + buffer = Buffer.allocate(16) + buffer.write_var_uint32(value) + return buffer.to_bytes(0, buffer.get_writer_index()) + + +def test_fixed_default_budget(): + assert pyfory.Fory(xlang=False, ref=True).max_graph_memory_bytes == DEFAULT_GRAPH_MEMORY_BYTES + fory = new_fory(xlang=False) + value = [[], [], []] + assert fory.deserialize(fory.serialize(value)) == value + + +def test_stream_default_budget(): + fory = new_fory(xlang=False) + value = [[], [], []] + data = fory.serialize(value) + assert fory.deserialize(Buffer.from_stream(OneByteStream(data))) == value + + +def test_explicit_budget(): + value = [1] + budget = collection_memory(1) + assert expect_budget(value, budget) == value + + +def test_nested_empty_containers(): + value = [[]] + budget = collection_memory(1) + collection_memory(0) + assert expect_budget(value, budget) == value + + +def test_sibling_cumulative_budget(): + value = [[], [], []] + budget = collection_memory(3) + 3 * collection_memory(0) + assert expect_budget(value, budget) == value + + +def test_empty_object_owner_is_charged(): + fory = new_fory(xlang=False) + fory.register_type(BudgetItem) + value = BudgetItem(1) + budget = object_memory(1) + data = fory.serialize(value) + with pytest.raises(ValueError, match="Estimated graph memory budget exceeded"): + reader = new_fory(budget - 1, xlang=False) + reader.register_type(BudgetItem) + reader.deserialize(data) + reader = new_fory(budget, xlang=False) + reader.register_type(BudgetItem) + assert reader.deserialize(data) == value + + +def 
test_dynamic_object_budget(): + value = BudgetObject() + value.left = 1 + value.right = "x" + budget = object_memory(2) + + writer = new_fory(xlang=False) + writer.register_type(BudgetObject) + data = writer.serialize(value) + with pytest.raises(ValueError, match="Estimated graph memory budget exceeded"): + reader = new_fory(budget - 1, xlang=False) + reader.register_type(BudgetObject) + reader.deserialize(data) + reader = new_fory(budget, xlang=False) + reader.register_type(BudgetObject) + restored = reader.deserialize(data) + assert restored.left == value.left + assert restored.right == value.right + + +def test_map_entry_budget_and_overflow(): + value = {"a": 1} + assert expect_budget(value, map_memory(1)) == value + + fory = new_fory(xlang=False) + try: + with pytest.raises(ValueError, match="Estimated graph memory overflow"): + fory.read_context.reserve_graph_memory(MAX_GRAPH_MEMORY_BYTES + 1) + finally: + fory.reset_read() + + +def test_object_reference_array_budget(): + value = (1, 2, 3) + assert expect_budget(value, collection_memory(3), xlang=False) == value + + +def test_object_ndarray_budget(): + if np is None: + pytest.skip("numpy is not installed") + value = np.array([1, 2, 3], dtype=object) + restored = expect_budget(value, collection_memory(3), xlang=False) + np.testing.assert_array_equal(restored, value) + + +def test_dense_leaf_owners_skipped(): + values = [ + "x" * 256, + b"x" * 256, + array.array("i", range(32)), + ] + if np is not None: + values.append(np.array(list(range(32)), dtype=np.int32)) + for value in values: + fory = new_fory(1, xlang=False) + restored = fory.deserialize(fory.serialize(value)) + if np is not None and isinstance(value, np.ndarray): + np.testing.assert_array_equal(restored, value) + else: + assert restored == value + + +def test_large_list_needs_bytes(): + fory = new_fory(10_000_000, xlang=False) + serializer = ListSerializer(fory.type_resolver, list) + try: + fory.read_context.prepare(Buffer(varuint_payload(1000))) + 
with pytest.raises(Exception) as exc_info: + serializer.read(fory.read_context) + assert "Estimated graph memory" not in str(exc_info.value) + finally: + fory.reset_read() + + +@pytest.mark.parametrize("limit", [0, -2, 1 << 63]) +def test_invalid_config(limit): + with pytest.raises(ValueError, match="max_graph_memory_bytes"): + new_fory(limit) diff --git a/rust/fory-core/src/config.rs b/rust/fory-core/src/config.rs index 167d67e72b..aeeae86625 100644 --- a/rust/fory-core/src/config.rs +++ b/rust/fory-core/src/config.rs @@ -40,6 +40,9 @@ pub struct Config { /// When enabled, shared references and circular references are tracked /// and preserved during serialization/deserialization. pub track_ref: bool, + /// Maximum estimated graph memory accepted during one root deserialization. + /// Defaults to 128 MiB. Value must be a positive byte limit. + pub max_graph_memory_bytes: i64, /// Maximum accepted field count in one received struct TypeMeta. pub max_type_fields: u32, /// Maximum accepted body size in one received TypeMeta. @@ -61,6 +64,7 @@ impl Default for Config { max_dyn_depth: 5, check_struct_version: false, track_ref: false, + max_graph_memory_bytes: 128 * 1024 * 1024, max_type_fields: 512, max_type_meta_bytes: 4096, max_schema_versions_per_type: 10, @@ -123,6 +127,12 @@ impl Config { self.track_ref } + /// Get maximum estimated graph memory per root deserialization. + #[inline(always)] + pub fn max_graph_memory_bytes(&self) -> i64 { + self.max_graph_memory_bytes + } + /// Get maximum accepted field count in one received struct TypeMeta. 
#[inline(always)] pub fn max_type_fields(&self) -> usize { diff --git a/rust/fory-core/src/context.rs b/rust/fory-core/src/context.rs index 260f94ea4c..eab540e7dd 100644 --- a/rust/fory-core/src/context.rs +++ b/rust/fory-core/src/context.rs @@ -359,6 +359,7 @@ pub struct ReadContext<'a> { max_dyn_depth: u32, check_struct_version: bool, check_string_read: bool, + pub(crate) remaining_graph_memory_bytes: usize, // Context-specific fields pub reader: Reader<'a>, @@ -388,6 +389,7 @@ impl<'a> ReadContext<'a> { max_dyn_depth: config.max_dyn_depth, check_struct_version: config.check_struct_version, check_string_read: config.check_string_read, + remaining_graph_memory_bytes: 0, reader: Reader::default(), meta_resolver: MetaReaderResolver::default(), meta_string_resolver: MetaStringReaderResolver::default(), @@ -443,6 +445,21 @@ impl<'a> ReadContext<'a> { self.reader = reader; } + #[inline(always)] + #[doc(hidden)] + pub fn reserve_graph_memory(&mut self, bytes: usize) -> Result<(), Error> { + let remaining = self.remaining_graph_memory_bytes; + if bytes > remaining { + return Err(graph_memory_exceeded( + bytes, + remaining, + self.config.max_graph_memory_bytes, + )); + } + self.remaining_graph_memory_bytes = remaining - bytes; + Ok(()) + } + #[inline(always)] pub fn detach_reader(&mut self) -> Reader<'_> { mem::take(&mut self.reader) @@ -552,3 +569,12 @@ impl<'a> ReadContext<'a> { self.current_depth = 0; } } + +#[cold] +#[inline(never)] +fn graph_memory_exceeded(bytes: usize, remaining: usize, limit: i64) -> Error { + Error::invalid_data(format!( + "estimated graph memory request {} bytes exceeds max_graph_memory_bytes remaining budget {} bytes out of effective limit {} bytes", + bytes, remaining, limit + )) +} diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index 4b6c98419a..ee09e865c4 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -261,6 +261,19 @@ impl ForyBuilder { self } + /// Sets the maximum estimated graph memory 
accepted during one root deserialization. + /// + /// Defaults to 128 MiB. Values must be positive byte limits. + pub fn max_graph_memory_bytes(mut self, max_bytes: i64) -> Self { + assert!(max_bytes > 0, "max_graph_memory_bytes must be positive"); + assert!( + usize::try_from(max_bytes).is_ok(), + "max_graph_memory_bytes does not fit usize" + ); + self.config.max_graph_memory_bytes = max_bytes; + self + } + /// Sets the maximum depth for nested dynamic object serialization. /// /// # Arguments @@ -988,7 +1001,18 @@ impl Fory { self.with_read_context(|context| { let outlive_buffer = unsafe { mem::transmute::<&[u8], &[u8]>(bf) }; context.attach_reader(Reader::new(outlive_buffer)); - let result = self.deserialize_with_context(context); + let result = match usize::try_from(self.config.max_graph_memory_bytes) + .map_err(|_| Error::invalid_data("max_graph_memory_bytes does not fit usize")) + { + Ok(limit) => { + context.remaining_graph_memory_bytes = limit; + self.deserialize_with_context(context) + } + Err(err) => { + context.reset(); + Err(err) + } + }; context.detach_reader(); result }) @@ -1051,7 +1075,18 @@ impl Fory { let mut new_reader = Reader::new(outlive_buffer); new_reader.set_cursor(reader.cursor); context.attach_reader(new_reader); - let result = self.deserialize_with_context(context); + let result = match usize::try_from(self.config.max_graph_memory_bytes) + .map_err(|_| Error::invalid_data("max_graph_memory_bytes does not fit usize")) + { + Ok(limit) => { + context.remaining_graph_memory_bytes = limit; + self.deserialize_with_context(context) + } + Err(err) => { + context.reset(); + Err(err) + } + }; let end = context.detach_reader().get_cursor(); reader.set_cursor(end); result @@ -1109,7 +1144,6 @@ impl Fory { } else { RefMode::NullOnly }; - // TypeMeta is read inline during deserialization (streaming protocol) let result = ::fory_read(context, ref_mode, true); context.ref_reader.resolve_callbacks(); result diff --git 
a/rust/fory-core/src/resolver/type_resolver.rs b/rust/fory-core/src/resolver/type_resolver.rs index 41e6c53262..39424ba499 100644 --- a/rust/fory-core/src/resolver/type_resolver.rs +++ b/rust/fory-core/src/resolver/type_resolver.rs @@ -21,7 +21,6 @@ use crate::meta::{ MetaString, TypeMeta, NAMESPACE_ENCODER, NAMESPACE_ENCODINGS, TYPE_NAME_ENCODER, TYPE_NAME_ENCODINGS, }; -use crate::resolver::RefMode; use crate::serializer::{ForyDefault, Serializer, StructSerializer}; use crate::type_id::{get_ext_actual_type_id, is_enum_type_id}; use crate::types::{Date, Duration, Timestamp}; @@ -50,16 +49,6 @@ fn supports_type_def(type_id: u32) -> bool { ) } -type WriteFn = fn( - &dyn Any, - &mut WriteContext, - ref_mode: RefMode, - write_type_info: bool, - has_enerics: bool, -) -> Result<(), Error>; -type ReadFn = - fn(&mut ReadContext, ref_mode: RefMode, read_type_info: bool) -> Result, Error>; - type WriteDataFn = fn(&dyn Any, &mut WriteContext, has_generics: bool) -> Result<(), Error>; type ReadDataFn = fn(&mut ReadContext) -> Result, Error>; type ReadDataAsSendSyncAnyFn = fn(&mut ReadContext) -> Result, Error>; @@ -94,8 +83,6 @@ fn split_named_registration<'a>(name: &'a str, api: &str) -> Result<(&'a str, &' #[derive(Clone, Debug)] pub struct Harness { - write_fn: WriteFn, - read_fn: ReadFn, write_data_fn: WriteDataFn, read_data_fn: ReadDataFn, read_data_as_send_sync_any_fn: ReadDataAsSendSyncAnyFn, @@ -108,8 +95,6 @@ pub struct Harness { impl Harness { pub fn stub() -> Harness { Harness { - write_fn: stub_write_fn, - read_fn: stub_read_fn, write_data_fn: stub_write_data_fn, read_data_fn: stub_read_data_fn, read_data_as_send_sync_any_fn: stub_read_data_as_send_sync_any_fn, @@ -120,16 +105,6 @@ impl Harness { } } - #[inline(always)] - pub fn get_write_fn(&self) -> WriteFn { - self.write_fn - } - - #[inline(always)] - pub fn get_read_fn(&self) -> ReadFn { - self.read_fn - } - #[inline(always)] pub fn get_write_data_fn(&self) -> WriteDataFn { self.write_data_fn @@ -338,8 +313,6 
@@ impl TypeInfo { } else { // Create a stub harness that returns errors when called Harness { - write_fn: stub_write_fn, - read_fn: stub_read_fn, write_data_fn: stub_write_data_fn, read_data_fn: stub_read_data_fn, read_data_as_send_sync_any_fn: stub_read_data_as_send_sync_any_fn, @@ -365,24 +338,6 @@ impl TypeInfo { } // Stub functions for when a type doesn't exist locally -fn stub_write_fn( - _: &dyn Any, - _: &mut WriteContext, - _: RefMode, - _: bool, - _: bool, -) -> Result<(), Error> { - Err(Error::type_error( - "Cannot serialize unknown remote type - type not registered locally", - )) -} - -fn stub_read_fn(_: &mut ReadContext, _: RefMode, _: bool) -> Result, Error> { - Err(Error::type_error( - "Cannot deserialize unknown remote type - type not registered locally", - )) -} - fn stub_write_data_fn(_: &dyn Any, _: &mut WriteContext, _: bool) -> Result<(), Error> { Err(Error::type_error( "Cannot serialize unknown remote type - type not registered locally", @@ -926,32 +881,6 @@ impl TypeResolver { || x == TypeId::NAMED_COMPATIBLE_STRUCT as u32 ); - fn write( - this: &dyn Any, - context: &mut WriteContext, - ref_mode: RefMode, - write_type_info: bool, - has_generics: bool, - ) -> Result<(), Error> { - let this = this.downcast_ref::(); - match this { - Some(v) => T2::fory_write(v, context, ref_mode, write_type_info, has_generics), - None => Err(Error::type_error(format!( - "Cast type to {:?} error when writing: {:?}", - std::any::type_name::(), - T2::fory_static_type_id() - ))), - } - } - - fn read( - context: &mut ReadContext, - ref_mode: RefMode, - read_type_info: bool, - ) -> Result, Error> { - Ok(Box::new(T2::fory_read(context, ref_mode, read_type_info)?)) - } - fn write_data( this: &dyn Any, context: &mut WriteContext, @@ -971,6 +900,10 @@ impl TypeResolver { fn read_data( context: &mut ReadContext, ) -> Result, Error> { + let graph_self_size = T2::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } match 
T2::fory_read_data(context) { Ok(v) => Ok(Box::new(v)), Err(e) => Err(e), @@ -1002,6 +935,10 @@ impl TypeResolver { context: &mut ReadContext, type_info: Rc, ) -> Result, Error> { + let graph_self_size = T2::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } Ok(Box::new(T2::fory_read_compatible(context, type_info)?)) } @@ -1013,8 +950,6 @@ impl TypeResolver { } let harness = Harness { - write_fn: write::, - read_fn: read::, write_data_fn: write_data::, read_data_fn: read_data::, read_data_as_send_sync_any_fn: read_data_as_send_sync_any::, @@ -1169,32 +1104,6 @@ impl TypeResolver { ))); } - fn write( - this: &dyn Any, - context: &mut WriteContext, - ref_mode: RefMode, - write_type_info: bool, - has_generics: bool, - ) -> Result<(), Error> { - let this = this.downcast_ref::(); - match this { - Some(v) => v.fory_write(context, ref_mode, write_type_info, has_generics), - None => Err(Error::type_error(format!( - "Cast type to {:?} error when writing: {:?}", - std::any::type_name::(), - T2::fory_static_type_id() - ))), - } - } - - fn read( - context: &mut ReadContext, - ref_mode: RefMode, - read_type_info: bool, - ) -> Result, Error> { - Ok(Box::new(T2::fory_read(context, ref_mode, read_type_info)?)) - } - fn write_data( this: &dyn Any, context: &mut WriteContext, @@ -1214,6 +1123,10 @@ impl TypeResolver { fn read_data( context: &mut ReadContext, ) -> Result, Error> { + let graph_self_size = T2::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } match T2::fory_read_data(context) { Ok(v) => Ok(Box::new(v)), Err(e) => Err(e), @@ -1252,8 +1165,6 @@ impl TypeResolver { // EXT types don't support fory_read_compatible let harness = Harness { - write_fn: write::, - read_fn: read::, write_data_fn: write_data::, read_data_fn: read_data::, read_data_as_send_sync_any_fn: read_data_as_send_sync_any::, diff --git a/rust/fory-core/src/serializer/arc.rs 
b/rust/fory-core/src/serializer/arc.rs index 1e69ae15c3..25677c9c34 100644 --- a/rust/fory-core/src/serializer/arc.rs +++ b/rust/fory-core/src/serializer/arc.rs @@ -217,6 +217,10 @@ fn read_arc_inner( ) -> Result { // Read type info if needed, then read data directly // No recursive ref handling needed since Arc only wraps allowed types + let graph_self_size = T::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } if let Some(typeinfo) = typeinfo { return T::fory_read_with_type_info(context, RefMode::None, typeinfo); } diff --git a/rust/fory-core/src/serializer/array.rs b/rust/fory-core/src/serializer/array.rs index 35e6a31069..c0257beb13 100644 --- a/rust/fory-core/src/serializer/array.rs +++ b/rust/fory-core/src/serializer/array.rs @@ -269,6 +269,18 @@ impl Serializer for [T; N] { } } + #[inline(always)] + fn fory_graph_self_size() -> usize + where + Self: Sized, + { + if is_primitive_type::() { + 0 + } else { + mem::size_of::() + } + } + #[inline(always)] fn fory_static_type_id() -> TypeId where diff --git a/rust/fory-core/src/serializer/box_.rs b/rust/fory-core/src/serializer/box_.rs index da7a378eb2..1ccd22eb65 100644 --- a/rust/fory-core/src/serializer/box_.rs +++ b/rust/fory-core/src/serializer/box_.rs @@ -29,6 +29,10 @@ impl Serializer for Box { where Self: Sized + ForyDefault, { + let graph_self_size = T::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } Ok(Box::new(T::fory_read_data(context)?)) } diff --git a/rust/fory-core/src/serializer/codec.rs b/rust/fory-core/src/serializer/codec.rs index 34059103f5..e103c27b29 100644 --- a/rust/fory-core/src/serializer/codec.rs +++ b/rust/fory-core/src/serializer/codec.rs @@ -52,6 +52,20 @@ const DECL_VALUE_TYPE: u8 = 0b100000; const MAX_CHUNK_SIZE: u8 = 255; +#[inline(always)] +fn reserve_graph_storage( + context: &mut ReadContext, + len: u32, + elem_bytes: usize, +) -> Result { + let len = len as usize; + 
let bytes = len + .checked_mul(elem_bytes) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + context.reserve_graph_memory(bytes)?; + Ok(len) +} + #[inline(always)] pub fn field_ref_mode(field_type: &FieldType) -> RefMode { if field_type.track_ref { @@ -474,6 +488,11 @@ pub trait Codec: 'static { std::mem::size_of::() } + #[inline(always)] + fn graph_storage_size() -> usize { + std::mem::size_of::() + } + fn write_field(value: &T, context: &mut WriteContext) -> Result<(), Error>; fn read_field(context: &mut ReadContext) -> Result; @@ -615,6 +634,11 @@ where T::fory_reserved_space() + SIZE_OF_REF_AND_TYPE } + #[inline(always)] + fn graph_storage_size() -> usize { + T::fory_graph_storage_size() + } + #[inline(always)] fn write_field(value: &T, context: &mut WriteContext) -> Result<(), Error> { T::fory_write( @@ -1700,6 +1724,7 @@ where fn read_data(context: &mut ReadContext) -> Result, Error> { let len = context.reader.read_var_u32()?; + reserve_graph_storage(context, len, C::graph_storage_size())?; if len == 0 { return Ok(Vec::new()); } @@ -1728,6 +1753,7 @@ where remote_field_type: &FieldType, ) -> Result, Error> { let len = context.reader.read_var_u32()?; + reserve_graph_storage(context, len, C::graph_storage_size())?; if len == 0 { return Ok(Vec::new()); } @@ -2270,6 +2296,10 @@ where fn read_data(context: &mut ReadContext) -> Result, Error> { let len = context.reader.read_var_u32()?; + let elem_bytes = KC::graph_storage_size() + .checked_add(VC::graph_storage_size()) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + reserve_graph_storage(context, len, elem_bytes)?; if len == 0 { return Ok(HashMap::new()); } @@ -2289,6 +2319,10 @@ where remote_field_type: &FieldType, ) -> Result, Error> { let len = context.reader.read_var_u32()?; + let elem_bytes = KC::graph_storage_size() + .checked_add(VC::graph_storage_size()) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + let capacity = 
reserve_graph_storage(context, len, elem_bytes)?; if len == 0 { return Ok(HashMap::new()); } @@ -2299,7 +2333,8 @@ where { return read_map_dynamic::(context, len, remote_field_type); } - let mut map = HashMap::with_capacity(check_map_len(context, len)?); + context.reader.check_bound(capacity)?; + let mut map = HashMap::with_capacity(capacity); let mut len_counter = 0; while len_counter < len { let header = context.reader.read_u8()?; diff --git a/rust/fory-core/src/serializer/collection.rs b/rust/fory-core/src/serializer/collection.rs index ee16166bb4..ca26c2b050 100644 --- a/rust/fory-core/src/serializer/collection.rs +++ b/rust/fory-core/src/serializer/collection.rs @@ -42,6 +42,20 @@ fn check_collection_len(context: &ReadContext, len: u32) -> Result Ok(len) } +#[inline(always)] +fn reserve_collection_storage( + context: &mut ReadContext, + len: u32, + elem_bytes: usize, +) -> Result { + let len = len as usize; + let bytes = len + .checked_mul(elem_bytes) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + context.reserve_graph_memory(bytes)?; + Ok(len) +} + pub fn write_collection_type_info( context: &mut WriteContext, collection_type_id: u32, @@ -239,6 +253,7 @@ where C: FromIterator, { let len = context.reader.read_var_u32()?; + let len_usize = reserve_collection_storage(context, len, T::fory_graph_storage_size())?; if len == 0 { return Ok(C::from_iter(std::iter::empty())); } @@ -257,7 +272,9 @@ where (header & IS_SAME_TYPE) != 0, Error::type_error("Type inconsistent, target type is not polymorphic") ); - let _ = check_collection_len(context, len)?; + if std::mem::size_of::() != 0 { + context.reader.check_bound(len_usize)?; + } if !has_null { (0..len) .map(|_| T::fory_read_data(context)) @@ -281,6 +298,7 @@ where T: Serializer + ForyDefault, { let len = context.reader.read_var_u32()?; + let len_usize = reserve_collection_storage(context, len, T::fory_graph_storage_size())?; if len == 0 { return Ok(Vec::new()); } @@ -297,7 +315,10 @@ 
where (header & IS_SAME_TYPE) != 0, Error::type_error("Type inconsistent, target type is not polymorphic") ); - let mut vec = Vec::with_capacity(check_collection_len(context, len)?); + if std::mem::size_of::() != 0 { + context.reader.check_bound(len_usize)?; + } + let mut vec = Vec::with_capacity(len_usize); if !has_null { for _ in 0..len { vec.push(T::fory_read_data(context)?); @@ -343,7 +364,8 @@ where } else { T::fory_get_type_info(context.get_type_resolver())? }; - let mut vec = Vec::with_capacity(check_collection_len(context, len)?); + let len_usize = check_collection_len(context, len)?; + let mut vec = Vec::with_capacity(len_usize); if elem_ref_mode == RefMode::None { for _ in 0..len { vec.push(T::fory_read_with_type_info( @@ -363,7 +385,8 @@ where } Ok(vec) } else { - let mut vec = Vec::with_capacity(check_collection_len(context, len)?); + let len_usize = check_collection_len(context, len)?; + let mut vec = Vec::with_capacity(len_usize); for _ in 0..len { vec.push(T::fory_read(context, elem_ref_mode, true)?); } @@ -724,6 +747,7 @@ where { let element_type = generic_field_type(remote_field_type, 0, "list")?; let len = context.reader.read_var_u32()?; + let len_usize = reserve_collection_storage(context, len, T::fory_graph_storage_size())?; if len == 0 { return Ok(Vec::new()); } @@ -748,8 +772,8 @@ where "array-compatible list must declare element type", )); } - context.reader.check_bound(len as usize)?; - let mut vec = Vec::with_capacity(len as usize); + context.reader.check_bound(len_usize)?; + let mut vec = Vec::with_capacity(len_usize); for _ in 0..len { vec.push(T::read_list_array_element(context, element_type.type_id)?); } diff --git a/rust/fory-core/src/serializer/core.rs b/rust/fory-core/src/serializer/core.rs index 2f155792d1..09a3d2adf8 100644 --- a/rust/fory-core/src/serializer/core.rs +++ b/rust/fory-core/src/serializer/core.rs @@ -585,6 +585,10 @@ pub trait Serializer: 'static { if read_type_info { Self::fory_read_type_info(context)?; } + let 
graph_self_size = Self::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } Self::fory_read_data(context) } @@ -1048,6 +1052,31 @@ pub trait Serializer: 'static { Self::fory_is_shared_ref() } + /// Shallow value-owner self storage for root and box/shared-reference allocation sites. + /// + /// Value serializers must not reserve this unconditionally because parent structs, + /// arrays, maps, and collections may already own the inline storage. + #[inline(always)] + fn fory_graph_self_size() -> usize + where + Self: Sized, + { + 0 + } + + #[inline(always)] + fn fory_graph_storage_size() -> usize + where + Self: Sized, + { + let inline_size = std::mem::size_of::(); + if inline_size == 0 { + Self::fory_graph_self_size() + } else { + inline_size + } + } + /// Get the static Fory type ID for this type. /// /// Type IDs are Fory's internal type identification system, separate from diff --git a/rust/fory-core/src/serializer/heap.rs b/rust/fory-core/src/serializer/heap.rs index 7203c9545e..4c7b9483ba 100644 --- a/rust/fory-core/src/serializer/heap.rs +++ b/rust/fory-core/src/serializer/heap.rs @@ -58,6 +58,10 @@ impl Serializer for BinaryHeap { mem::size_of::() } + fn fory_graph_self_size() -> usize { + mem::size_of::() + } + fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::SET) } diff --git a/rust/fory-core/src/serializer/list.rs b/rust/fory-core/src/serializer/list.rs index 435ec97e9e..175ac41e21 100644 --- a/rust/fory-core/src/serializer/list.rs +++ b/rust/fory-core/src/serializer/list.rs @@ -154,6 +154,15 @@ impl Serializer for Vec { } } + #[inline(always)] + fn fory_graph_self_size() -> usize { + if is_primitive_type::() { + 0 + } else { + mem::size_of::() + } + } + #[inline(always)] fn fory_get_type_id(_: &TypeResolver) -> Result { let id = get_primitive_type_id::(); @@ -235,6 +244,11 @@ impl Serializer for VecDeque { mem::size_of::() } + #[inline(always)] + fn fory_graph_self_size() -> usize { + 
mem::size_of::() + } + #[inline(always)] fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::LIST) @@ -298,6 +312,11 @@ impl Serializer for LinkedList { mem::size_of::() } + #[inline(always)] + fn fory_graph_self_size() -> usize { + mem::size_of::() + } + #[inline(always)] fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::LIST) diff --git a/rust/fory-core/src/serializer/map.rs b/rust/fory-core/src/serializer/map.rs index 3d0dc094e7..79bb8c53b3 100644 --- a/rust/fory-core/src/serializer/map.rs +++ b/rust/fory-core/src/serializer/map.rs @@ -35,16 +35,24 @@ const TRACKING_VALUE_REF: u8 = 0b1000; pub const VALUE_NULL: u8 = 0b10000; pub const DECL_VALUE_TYPE: u8 = 0b100000; -fn check_map_len(context: &ReadContext, len: u32) -> Result { - let len = len as usize; - context.reader.check_bound(len)?; - Ok(len) -} - fn write_chunk_size(context: &mut WriteContext, header_offset: usize, size: u8) { context.writer.set_bytes(header_offset + 1, &[size]); } +#[inline(always)] +fn reserve_map_storage( + context: &mut ReadContext, + len: u32, + elem_bytes: usize, +) -> Result { + let len = len as usize; + let bytes = len + .checked_mul(elem_bytes) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + context.reserve_graph_memory(bytes)?; + Ok(len) +} + pub fn write_map_data<'a, K, V, I>( iter: I, length: usize, @@ -559,10 +567,14 @@ impl Result { let len = context.reader.read_var_u32()?; + let elem_bytes = K::fory_graph_storage_size() + .checked_add(V::fory_graph_storage_size()) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + let capacity = reserve_map_storage(context, len, elem_bytes)?; if len == 0 { return Ok(HashMap::new()); } - let capacity = check_map_len(context, len)?; + context.reader.check_bound(capacity)?; if K::fory_is_polymorphic() || K::fory_is_shared_ref() || V::fory_is_polymorphic() @@ -659,6 +671,10 @@ impl() } + fn fory_graph_self_size() -> usize { + size_of::() + } + fn fory_get_type_id(_: 
&TypeResolver) -> Result { Ok(TypeId::MAP) } @@ -711,10 +727,14 @@ impl Result { let len = context.reader.read_var_u32()?; + let elem_bytes = K::fory_graph_storage_size() + .checked_add(V::fory_graph_storage_size()) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + let len_usize = reserve_map_storage(context, len, elem_bytes)?; if len == 0 { return Ok(BTreeMap::new()); } - let _ = check_map_len(context, len)?; + context.reader.check_bound(len_usize)?; let mut map = BTreeMap::::new(); if K::fory_is_polymorphic() || K::fory_is_shared_ref() @@ -810,6 +830,10 @@ impl() } + fn fory_graph_self_size() -> usize { + size_of::() + } + fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::MAP) } diff --git a/rust/fory-core/src/serializer/rc.rs b/rust/fory-core/src/serializer/rc.rs index 7c0f46861e..eb8522d503 100644 --- a/rust/fory-core/src/serializer/rc.rs +++ b/rust/fory-core/src/serializer/rc.rs @@ -205,6 +205,10 @@ fn read_rc_inner( ) -> Result { // Read type info if needed, then read data directly // No recursive ref handling needed since Rc only wraps allowed types + let graph_self_size = T::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } if let Some(typeinfo) = typeinfo { return T::fory_read_with_type_info(context, RefMode::None, typeinfo); } diff --git a/rust/fory-core/src/serializer/set.rs b/rust/fory-core/src/serializer/set.rs index d49bc5559f..856fd66c9a 100644 --- a/rust/fory-core/src/serializer/set.rs +++ b/rust/fory-core/src/serializer/set.rs @@ -58,6 +58,10 @@ impl Serializer for HashSet< mem::size_of::() } + fn fory_graph_self_size() -> usize { + mem::size_of::() + } + fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::SET) } @@ -113,6 +117,10 @@ impl Serializer for BTreeSet { mem::size_of::() } + fn fory_graph_self_size() -> usize { + mem::size_of::() + } + fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::SET) } diff --git 
a/rust/fory-core/src/serializer/tuple.rs b/rust/fory-core/src/serializer/tuple.rs index bbda7383b2..6dc730a0b9 100644 --- a/rust/fory-core/src/serializer/tuple.rs +++ b/rust/fory-core/src/serializer/tuple.rs @@ -181,6 +181,11 @@ impl Serializer for (T0,) { mem::size_of::() } + #[inline(always)] + fn fory_graph_self_size() -> usize { + mem::size_of::() + } + #[inline(always)] fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::LIST) @@ -449,6 +454,11 @@ macro_rules! impl_tuple_serializer { mem::size_of::() // Size for length } + #[inline(always)] + fn fory_graph_self_size() -> usize { + mem::size_of::() + } + #[inline(always)] fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::LIST) diff --git a/rust/fory-derive/src/object/read.rs b/rust/fory-derive/src/object/read.rs index c4addc8593..9d3a20d819 100644 --- a/rust/fory-derive/src/object/read.rs +++ b/rust/fory-derive/src/object/read.rs @@ -195,6 +195,10 @@ pub fn gen_read(_struct_ident: &Ident) -> TokenStream { if ref_flag == (fory_core::RefFlag::RefValue as i8) && ref_mode == fory_core::RefMode::Tracking { context.ref_reader.reserve_ref_id(); } + let graph_self_size = ::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } if context.is_compatible() { let type_info = if read_type_info { context.read_any_type_info()? diff --git a/rust/fory-derive/src/object/serializer.rs b/rust/fory-derive/src/object/serializer.rs index 9624ebce10..922b32cf34 100644 --- a/rust/fory-derive/src/object/serializer.rs +++ b/rust/fory-derive/src/object/serializer.rs @@ -119,6 +119,7 @@ pub fn derive_serializer( read_type_info_ts, reserved_space_ts, static_type_id_ts, + graph_self_size_ts, ) = match &ast.data { syn::Data::Struct(s) => { let source_fields = source_fields(&s.fields); @@ -132,6 +133,10 @@ pub fn derive_serializer( read::gen_read_type_info(), write::gen_reserved_space(&source_fields), quote! { fory_core::TypeId::STRUCT }, + quote! 
{ + let bytes = ::std::mem::size_of::(); + if bytes == 0 { 1 } else { bytes } + }, ) } syn::Data::Enum(e) => ( @@ -144,6 +149,7 @@ pub fn derive_serializer( derive_enum::gen_read_type_info(e), derive_enum::gen_reserved_space(), derive_enum::gen_static_type_id(e), + quote! { 0 }, ), syn::Data::Union(_) => { panic!("Union is not supported") @@ -226,6 +232,14 @@ pub fn derive_serializer( #reserved_space_ts } + #[inline(always)] + fn fory_graph_self_size() -> usize + where + Self: Sized, + { + #graph_self_size_ts + } + #[inline(always)] fn fory_write(&self, context: &mut fory_core::WriteContext, ref_mode: fory_core::RefMode, write_type_info: bool, _: bool) -> ::std::result::Result<(), fory_core::error::Error> { #write_ts @@ -289,6 +303,10 @@ fn generate_send_sync_tokens(ast: &syn::DeriveInput) -> SendSyncTokens { context: &mut fory_core::ReadContext, type_info: ::std::rc::Rc, ) -> ::std::result::Result<::std::boxed::Box, fory_core::error::Error> { + let graph_self_size = ::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } let value = ::fory_read_compatible(context, type_info)?; ::std::result::Result::Ok(fory_core::serializer::box_send_sync(value)) } @@ -305,6 +323,10 @@ fn generate_send_sync_tokens(ast: &syn::DeriveInput) -> SendSyncTokens { where Self: Sized + fory_core::ForyDefault, { + let graph_self_size = ::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } let value = ::fory_read_data(context)?; ::std::result::Result::Ok(fory_core::serializer::box_send_sync(value)) } diff --git a/rust/tests/tests/mod.rs b/rust/tests/tests/mod.rs index c66d727c8a..1a3c78828f 100644 --- a/rust/tests/tests/mod.rs +++ b/rust/tests/tests/mod.rs @@ -19,5 +19,6 @@ mod compatible; mod test_any; mod test_collection; mod test_field_meta; +mod test_graph_memory_budget; mod test_max_dyn_depth; mod test_tuple; diff --git a/rust/tests/tests/test_graph_memory_budget.rs 
b/rust/tests/tests/test_graph_memory_budget.rs new file mode 100644 index 0000000000..8abd761558 --- /dev/null +++ b/rust/tests/tests/test_graph_memory_budget.rs @@ -0,0 +1,307 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use fory_core::{Error, Fory, Reader}; +use fory_derive::ForyStruct; +use std::collections::HashMap; +use std::mem; + +const DEFAULT_GRAPH_MEMORY_BYTES: i64 = 128 * 1024 * 1024; + +#[derive(ForyStruct, Debug, PartialEq)] +struct BudgetSiblings { + first: Vec, + second: Vec, +} + +#[derive(ForyStruct, Debug, PartialEq)] +struct BudgetItem { + left: u64, + right: u64, +} + +#[derive(ForyStruct, Debug, PartialEq)] +struct BudgetEmpty; + +#[derive(ForyStruct, Debug)] +struct ListWireInts { + values: Vec>, +} + +#[derive(ForyStruct, Debug, PartialEq)] +struct DenseWireInts { + values: Vec, +} + +fn fory_with_budget(max_graph_memory_bytes: i64) -> Fory { + let mut fory = Fory::builder() + .xlang(false) + .compatible(false) + .max_graph_memory_bytes(max_graph_memory_bytes) + .build(); + fory.register_by_name::("BudgetSiblings") + .unwrap(); + fory.register_by_name::("BudgetItem").unwrap(); + fory.register_by_name::("BudgetEmpty").unwrap(); + fory +} + +fn compatible_fory(max_graph_memory_bytes: 
i64) -> Fory +where + T: fory_core::Serializer + fory_core::StructSerializer + fory_core::ForyDefault, +{ + let mut fory = Fory::builder() + .xlang(false) + .compatible(true) + .max_graph_memory_bytes(max_graph_memory_bytes) + .build(); + fory.register::(88_001).unwrap(); + fory +} + +fn compact_empty_lists(count: usize) -> Vec> { + (0..count).map(|_| Vec::new()).collect() +} + +#[test] +fn config_validation() { + assert_eq!( + Fory::builder().build().config().max_graph_memory_bytes, + DEFAULT_GRAPH_MEMORY_BYTES + ); + assert!( + std::panic::catch_unwind(|| Fory::builder().max_graph_memory_bytes(0).build()).is_err() + ); + assert!( + std::panic::catch_unwind(|| Fory::builder().max_graph_memory_bytes(-2).build()).is_err() + ); + let _ = Fory::builder().max_graph_memory_bytes(1).build(); +} + +#[test] +fn byte_root_uses_fixed_default_budget() { + let value = compact_empty_lists(12000); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); + let bytes = writer.serialize(&value).unwrap(); + let decoded = writer.deserialize::>>(&bytes).unwrap(); + assert_eq!(decoded, value); +} + +#[test] +fn reader_root_default_budget() { + let value = compact_empty_lists(12000); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); + let bytes = writer.serialize(&value).unwrap(); + + let mut reader = Reader::new(&bytes); + let decoded = writer + .deserialize_from::>>(&mut reader) + .unwrap(); + assert_eq!(decoded, value); +} + +#[test] +fn explicit_override() { + let value = compact_empty_lists(12000); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); + let bytes = writer.serialize(&value).unwrap(); + + let vec_bytes = mem::size_of::>(); + let estimate = mem::size_of::>>() + value.len() * vec_bytes; + let limited = fory_with_budget((estimate - 1) as i64); + assert!(limited.deserialize::>>(&bytes).is_err()); + let explicit = fory_with_budget(estimate as i64); + let decoded: Vec> = explicit.deserialize(&bytes).unwrap(); + assert_eq!(decoded, value); +} + 
+#[test] +fn empty_collection_owner_self() { + let value: Vec = Vec::new(); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); + let bytes = writer.serialize(&value).unwrap(); + + let limited = fory_with_budget((mem::size_of::>() - 1) as i64); + assert!(limited.deserialize::>(&bytes).is_err()); + + let limited = fory_with_budget(mem::size_of::>() as i64); + let decoded: Vec = limited.deserialize(&bytes).unwrap(); + assert!(decoded.is_empty()); +} + +#[test] +fn empty_struct_owner_self() { + let value = BudgetEmpty; + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); + let bytes = writer.serialize(&value).unwrap(); + + assert_eq!( + fory_with_budget(1) + .deserialize::(&bytes) + .unwrap(), + value + ); + + let values = vec![BudgetEmpty, BudgetEmpty, BudgetEmpty]; + let bytes = writer.serialize(&values).unwrap(); + let required = mem::size_of::>() + values.len(); + assert!(fory_with_budget((required - 1) as i64) + .deserialize::>(&bytes) + .is_err()); + assert_eq!( + fory_with_budget(required as i64) + .deserialize::>(&bytes) + .unwrap(), + values + ); +} + +#[test] +fn sibling_cumulative_budget() { + let value = BudgetSiblings { + first: vec!["a".to_string()], + second: vec!["b".to_string()], + }; + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); + let bytes = writer.serialize(&value).unwrap(); + let root = mem::size_of::() as i64; + let one_vec = mem::size_of::() as i64; + + let limited = fory_with_budget(root + one_vec); + assert!(limited.deserialize::(&bytes).is_err()); + let enough = fory_with_budget(root + one_vec * 2); + assert_eq!(enough.deserialize::(&bytes).unwrap(), value); +} + +#[test] +fn map_budget() { + let value: HashMap = HashMap::from([("a".to_string(), 1)]); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); + let bytes = writer.serialize(&value).unwrap(); + let required = (mem::size_of::>() + + mem::size_of::() + + mem::size_of::()) as i64; + + let limited = fory_with_budget(required - 1); + 
assert!(limited.deserialize::>(&bytes).is_err()); + assert_eq!( + fory_with_budget(required) + .deserialize::>(&bytes) + .unwrap(), + value + ); +} + +#[test] +fn inline_value_vec_budget() { + let value = (0..16) + .map(|i| BudgetItem { + left: i, + right: i + 1, + }) + .collect::>(); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); + let bytes = writer.serialize(&value).unwrap(); + let under_inline = mem::size_of::>() + value.len() * mem::size_of::(); + + let limited = fory_with_budget(under_inline as i64); + assert!(limited.deserialize::>(&bytes).is_err()); +} + +#[test] +fn box_vector_owner_self() { + let value = Box::new( + (0..4) + .map(|i| BudgetItem { + left: i, + right: i + 1, + }) + .collect::>(), + ); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); + let bytes = writer.serialize(&value).unwrap(); + let required = mem::size_of::>() + value.len() * mem::size_of::(); + + assert!(fory_with_budget((required - 1) as i64) + .deserialize::>>(&bytes) + .is_err()); + assert_eq!( + fory_with_budget(required as i64) + .deserialize::>>(&bytes) + .unwrap(), + value + ); +} + +#[test] +fn compatible_list_array_budget() { + let value = ListWireInts { + values: (0..64).map(Some).collect(), + }; + let writer = compatible_fory::(DEFAULT_GRAPH_MEMORY_BYTES); + let bytes = writer.serialize(&value).unwrap(); + + let required = mem::size_of::() + 64 * mem::size_of::(); + let limited = compatible_fory::((required - 1) as i64); + assert!(limited.deserialize::(&bytes).is_err()); + + let enough = compatible_fory::(required as i64); + let decoded = enough.deserialize::(&bytes).unwrap(); + assert_eq!( + decoded, + DenseWireInts { + values: (0..64).collect() + } + ); +} + +#[test] +fn dense_paths_skipped() { + let fory = fory_with_budget(1); + + let string_bytes = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES) + .serialize(&"hello".to_string()) + .unwrap(); + let decoded: String = fory.deserialize(&string_bytes).unwrap(); + assert_eq!(decoded, "hello"); + + let 
binary = vec![1_u8, 2, 3, 4]; + let binary_bytes = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES) + .serialize(&binary) + .unwrap(); + let decoded: Vec = fory.deserialize(&binary_bytes).unwrap(); + assert_eq!(decoded, binary); + + let ints = vec![1_i32, 2, 3, 4]; + let int_bytes = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES) + .serialize(&ints) + .unwrap(); + let decoded: Vec = fory.deserialize(&int_bytes).unwrap(); + assert_eq!(decoded, ints); +} + +#[test] +fn byte_check_preserved() { + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); + let mut bytes = writer.serialize(&Vec::::new()).unwrap(); + let last = bytes.len() - 1; + bytes[last] = 64; + + let reader = fory_with_budget(i64::MAX); + let err = reader.deserialize::>(&bytes).unwrap_err(); + assert!(matches!(err, Error::BufferOutOfBound(..)), "{err}"); +} diff --git a/scala/README.md b/scala/README.md index a1149ccf60..daffbc05a9 100644 --- a/scala/README.md +++ b/scala/README.md @@ -188,9 +188,12 @@ sbt test ## Code Format ```bash -sbt scalafmt +cd .. +ci/format.sh ``` +The Scala module does not currently wire a `scalafmt` sbt command. 
+
 ## Additional Notes
 
 - **Fory Reuse**: Always reuse Fory instances; creation is expensive
diff --git a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala
index 025ff14c82..cbefa8ccdb 100644
--- a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala
+++ b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala
@@ -109,6 +109,19 @@ object ForySerializerMacros {
       annotations.foldRight(boxed)((annotation, current) => AnnotatedType(current, annotation))
     }
 
+    def graphFieldBytes(tpe: TypeRepr): Long = {
+      val base = peelAnnotations(tpe.widen)._1.dealias
+      if base =:= TypeRepr.of[Boolean] then 1L
+      else if base =:= TypeRepr.of[Byte] then 1L
+      else if base =:= TypeRepr.of[Char] then 2L
+      else if base =:= TypeRepr.of[Short] then 2L
+      else if base =:= TypeRepr.of[Int] then 4L
+      else if base =:= TypeRepr.of[Float] then 4L
+      else if base =:= TypeRepr.of[Long] then 8L
+      else if base =:= TypeRepr.of[Double] then 8L
+      else 4L
+    }
+
     def classFor(tpe: TypeRepr): Expr[Class[?]] = {
       val normalized = peelAnnotations(tpe.widen)._1.dealias
       val fullName = normalized.typeSymbol.fullName
@@ -207,6 +220,7 @@ object ForySerializerMacros {
         !privateField,
         constructorOwned || (field.flags.is(Flags.Mutable) && !privateField))
     }
+    val objectGraphMemoryBytes: Long = 1L + fields.map(field => graphFieldBytes(field.sourceType)).sum
     val hasNestedCompatibleStructFields =
       fields.exists(field => hasNestedCompatibleStruct(field.sourceType))
 
@@ -1129,7 +1143,8 @@ object ForySerializerMacros {
       Block(
         localDefs ++ maskDefs,
         Block(
-          readLoop.asTerm :: defaultAssignments.toList,
+          '{ $readContextExpr.reserveGraphMemory(${ Expr(objectGraphMemoryBytes) }) }.asTerm ::
+            readLoop.asTerm :: defaultAssignments.toList,
           constructFromLocals(localFields, instantiatorExpr, fieldAccessorsExpr).asTerm))
         .asExprOf[T]
     }
@@ -1151,6 +1166,7 @@ object ForySerializerMacros {
         if $resolverExpr.checkClassVersion() then {
           $serializerExpr.checkClassVersion(buffer.readInt32(), $classVersionHashExpr)
         }
+        $readContextExpr.reserveGraphMemory(${ Expr(objectGraphMemoryBytes) })
         val obj = $instantiatorExpr.newInstance()
         $readContextExpr.reference(obj)
         var i = 0
@@ -1176,6 +1192,7 @@ object ForySerializerMacros {
         if $resolverExpr.checkClassVersion() then {
           $serializerExpr.checkClassVersion(buffer.readInt32(), $classVersionHashExpr)
         }
+        $readContextExpr.reserveGraphMemory(${ Expr(objectGraphMemoryBytes) })
         val values = new Array[Any]($descriptorsExpr.size())
         var i = 0
         while i < $allFieldsExpr.length do {
@@ -1204,6 +1221,7 @@ object ForySerializerMacros {
         if $sameSchemaCompatibleExpr then {
           $serializerExpr.read($readContextExpr)
         } else {
+          $readContextExpr.reserveGraphMemory(${ Expr(objectGraphMemoryBytes) })
           val obj = $instantiatorExpr.newInstance()
           $readContextExpr.reference(obj)
           val remoteFields = $serializerExpr.getRemoteFields()
diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala
index 066e24c629..6bfcefaa9b 100644
--- a/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala
+++ b/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala
@@ -53,15 +53,10 @@ abstract class AbstractScalaCollectionSerializer[A, T <: Iterable[A]](
       value: T): util.Collection[_]
 
   override def newCollection(readContext: ReadContext): util.Collection[_] = {
-    val buffer = readContext.getBuffer
-    val numElements = buffer.readVarUInt32()
-    checkCollectionSize(numElements)
+    val numElements = readCollectionSize(readContext, readContext.getBuffer)
     setNumElements(numElements)
     val factory = readContext.readRef().asInstanceOf[Factory[A, T]]
     val builder = factory.newBuilder
-    if (numElements != 0) {
-      buffer.checkReadableBytes(numElements)
-    }
     builder.sizeHint(numElements)
     new JavaCollectionBuilder[A, T](builder)
   }
diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala
index 9c21954b7d..145c4d47e8 100644
--- a/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala
+++ b/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala
@@ -50,15 +50,10 @@ abstract class AbstractScalaMapSerializer[K, V, T](typeResolver: TypeResolver, c
   def onMapWrite(writeContext: WriteContext, value: T): util.Map[_, _]
 
   override def newMap(readContext: ReadContext): util.Map[_, _] = {
-    val buffer = readContext.getBuffer
-    val numElements = buffer.readVarUInt32()
-    checkMapSize(numElements)
+    val numElements = readMapSize(readContext, readContext.getBuffer)
     setNumElements(numElements)
     val factory = readContext.readRef().asInstanceOf[Factory[(K, V), T]]
     val builder = factory.newBuilder
-    if (numElements != 0) {
-      buffer.checkReadableBytes(numElements)
-    }
     builder.sizeHint(numElements)
     new MapBuilder[K, V, T](builder)
   }
diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala
index 9eeab286d2..4e34d77a20 100644
--- a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala
+++ b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala
@@ -43,12 +43,8 @@ abstract class AbstractScalaXlangCollectionSerializer[A, T <: scala.collection.I
   }
 
   override def newCollection(readContext: ReadContext): util.Collection[_] = {
-    val buffer = readContext.getBuffer
-    val numElements = readCollectionSize(buffer)
+    val numElements = readCollectionSize(readContext, readContext.getBuffer)
     setNumElements(numElements)
-    if (numElements != 0) {
-      buffer.checkReadableBytes(numElements)
-    }
     val builder = newBuilder(numElements)
     if (ScalaXlangCollectionShape.hasOptionElement(readContext)) {
       new XlangOptionCollectionBuilder[A, T](builder)
@@ -368,12 +364,8 @@ abstract class AbstractScalaXlangMapSerializer[K, V, T <: scala.collection.Map[K
   }
 
   override def newMap(readContext: ReadContext): util.Map[_, _] = {
-    val buffer = readContext.getBuffer
-    val numElements = readMapSize(buffer)
+    val numElements = readMapSize(readContext, readContext.getBuffer)
     setNumElements(numElements)
-    if (numElements != 0) {
-      buffer.checkReadableBytes(numElements)
-    }
     val builder = ScalaXlangCollectionShape.mapBuilder[K, V, T](cls, numElements)
 
     val optionKey = ScalaXlangCollectionShape.hasOptionKey(readContext)
diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala
index d1b7e67952..6489118eef 100644
--- a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala
+++ b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala
@@ -20,6 +20,7 @@
 package org.apache.fory.serializer.scala
 
 import org.apache.fory.Fory
+import org.apache.fory.exception.InsecureException
 import org.apache.fory.scala.ForyScala
 import org.scalatest.matchers.should.Matchers
 import org.scalatest.wordspec.AnyWordSpec
@@ -89,6 +90,35 @@ class CollectionSerializerTest extends AnyWordSpec with Matchers {
       }
     }
   }
+
+  "fory scala graph memory budget" should {
+    def runtime(maxGraphMemoryBytes: Option[Long] = None): Fory = {
+      val builder = ForyScala.builder()
+        .withXlang(false)
+        .withRefTracking(true)
+        .requireClassRegistration(false)
+        .suppressClassRegistrationWarnings(false)
+        .withSerializerFactory(new ScalaSerializerFactory())
+      maxGraphMemoryBytes.foreach(builder.withMaxGraphMemoryBytes)
+      builder.build()
+    }
+
+    "reserve scala collection storage" in {
+      val writer = runtime()
+      val reader = runtime(maxGraphMemoryBytes = Some(23))
+      intercept[InsecureException] {
+        reader.deserialize(writer.serialize(List.fill(6)("v")))
+      }
+    }
+
+    "reserve scala map storage" in {
+      val writer = runtime()
+      val reader = runtime(maxGraphMemoryBytes = Some(23))
+      intercept[InsecureException] {
+        reader.deserialize(writer.serialize(Map("a" -> 1, "b" -> 2, "c" -> 3)))
+      }
+    }
+  }
 }
 
 case class CollectionStruct1(list: List[String])
diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala
index 6cc4f880a1..6269061555 100644
--- a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala
+++ b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala
@@ -20,6 +20,7 @@
 package org.apache.fory.serializer.scala
 
 import org.apache.fory.Fory
+import org.apache.fory.exception.InsecureException
 import org.apache.fory.scala.ForyScala
 import org.scalatest.matchers.should.Matchers
 import org.scalatest.wordspec.AnyWordSpec
@@ -120,5 +121,24 @@ class ScalaXlangSerializerTest extends AnyWordSpec with Matchers {
       copiedCyclic should not be theSameInstanceAs(cyclic)
       copiedCyclic(0) shouldBe theSameInstanceAs(copiedCyclic)
     }
+
+    "enforce graph memory budget" in {
+      val writer = fory
+      val reader = ForyScala.builder()
+        .withXlang(true)
+        .withRefTracking(true)
+        .withRefCopy(true)
+        .requireClassRegistration(false)
+        .suppressClassRegistrationWarnings(false)
+        .withMaxGraphMemoryBytes(23)
+        .build()
+
+      intercept[InsecureException] {
+        reader.deserialize(writer.serialize(List.fill(6)("v")))
+      }
+      intercept[InsecureException] {
+        reader.deserialize(writer.serialize(Map("a" -> 1, "b" -> 2, "c" -> 3)))
+      }
+    }
   }
 }
diff --git a/swift/Sources/Fory/AnySerializer.swift b/swift/Sources/Fory/AnySerializer.swift
index fdd99b38a4..0b582a0d86 100644
--- a/swift/Sources/Fory/AnySerializer.swift
+++ b/swift/Sources/Fory/AnySerializer.swift
@@ -71,7 +71,11 @@ extension AnyHashable: Serializer {
         )
     }
 
-    public static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> AnyHashable {
+    public static func foryReadCompatibleData(
+        _ context: ReadContext, remoteTypeInfo: TypeInfo
+    ) throws
+        -> AnyHashable
+    {
         let typeInfo = remoteTypeInfo
         if typeInfo.typeID == .none {
             throw ForyError.invalidData("dynamic AnyHashable key cannot be null")
@@ -179,7 +183,11 @@ struct SerializableAny: Serializer {
         )
     }
 
-    static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> SerializableAny {
+    static func foryReadCompatibleData(
+        _ context: ReadContext, remoteTypeInfo: TypeInfo
+    ) throws
+        -> SerializableAny
+    {
         let typeInfo = remoteTypeInfo
         if typeInfo.typeID == .none {
             return .foryDefault()
@@ -541,7 +549,7 @@ public func readAny(
     refMode: RefMode,
     readTypeInfo: Bool = true
 ) throws -> Any? {
-    try SerializableAny.foryRead(context, refMode: refMode, readTypeInfo: readTypeInfo).anyValue()
+    try context.readAny(refMode: refMode, readTypeInfo: readTypeInfo)
 }
 
 public func writeListOfAny(
@@ -565,12 +573,7 @@ public func readListOfAny(
     refMode: RefMode,
     readTypeInfo: Bool = false
 ) throws -> [Any]? {
-    let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead(
-        context,
-        refMode: refMode,
-        readTypeInfo: readTypeInfo
-    )
-    return wrapped?.map { $0.anyValueForCollection() }
+    try context.readListOfAny(refMode: refMode, readTypeInfo: readTypeInfo)
 }
 
 public func writeMapStringToAny(
@@ -596,20 +599,7 @@ public func readMapStringToAny(
     refMode: RefMode,
     readTypeInfo: Bool = false
 ) throws -> [String: Any]? {
-    let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead(
-        context,
-        refMode: refMode,
-        readTypeInfo: readTypeInfo
-    )
-    guard let wrapped else {
-        return nil
-    }
-    var map: [String: Any] = [:]
-    map.reserveCapacity(wrapped.count)
-    for pair in wrapped {
-        map[pair.key] = pair.value.anyValueForCollection()
-    }
-    return map
+    try context.readMapStringToAny(refMode: refMode, readTypeInfo: readTypeInfo)
 }
 
 public func writeMapInt32ToAny(
@@ -635,20 +625,7 @@ public func readMapInt32ToAny(
     refMode: RefMode,
     readTypeInfo: Bool = false
 ) throws -> [Int32: Any]? {
-    let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead(
-        context,
-        refMode: refMode,
-        readTypeInfo: readTypeInfo
-    )
-    guard let wrapped else {
-        return nil
-    }
-    var map: [Int32: Any] = [:]
-    map.reserveCapacity(wrapped.count)
-    for pair in wrapped {
-        map[pair.key] = pair.value.anyValueForCollection()
-    }
-    return map
+    try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: readTypeInfo)
 }
 
 public func writeMapAnyHashableToAny(
@@ -674,27 +651,40 @@ public func readMapAnyHashableToAny(
     refMode: RefMode,
     readTypeInfo: Bool = false
 ) throws -> [AnyHashable: Any]? {
-    let wrapped: [AnyHashable: SerializableAny]?
= try [AnyHashable: SerializableAny]?.foryRead( - context, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil + try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: readTypeInfo) +} + +private let anyMapReferenceBytes = 4 + +@inline(never) +private func throwAnyMapGraphMemoryOverflow() throws -> Never { + throw ForyError.invalidData("graph memory estimate overflows") +} + +@inline(__always) +private func reserveAnyMapMemory( + _ context: ReadContext, _ type: Map.Type, count: Int +) throws { + let (slotBytes, overflow) = count.multipliedReportingOverflow( + by: 2 * anyMapReferenceBytes) + if overflow { + try throwAnyMapGraphMemoryOverflow() } - var map: [AnyHashable: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() + let ownerBytes = max(1, MemoryLayout.stride) + let (bytes, addOverflow) = ownerBytes.addingReportingOverflow(slotBytes) + if addOverflow { + try throwAnyMapGraphMemoryOverflow() } - return map + try context.reserveGraphMemory(bytes) } func readDynamicAnyMapValue(context: ReadContext) throws -> Any { - let map = try readMapAnyHashableToAny(context: context, refMode: .none) ?? [:] + let map = try context.readMapAnyHashableToAny(refMode: .none) ?? 
[:] if map.isEmpty { + try reserveAnyMapMemory(context, [String: Any].self, count: 0) return [String: Any]() } + try reserveAnyMapMemory(context, [String: Any].self, count: map.count) var stringMap: [String: Any] = [:] stringMap.reserveCapacity(map.count) for pair in map { @@ -708,6 +698,7 @@ func readDynamicAnyMapValue(context: ReadContext) throws -> Any { return stringMap } + try reserveAnyMapMemory(context, [Int32: Any].self, count: map.count) var int32Map: [Int32: Any] = [:] int32Map.reserveCapacity(map.count) for pair in map { diff --git a/swift/Sources/Fory/CollectionSerializers.swift b/swift/Sources/Fory/CollectionSerializers.swift index 1be59fb6b4..73afdfec31 100644 --- a/swift/Sources/Fory/CollectionSerializers.swift +++ b/swift/Sources/Fory/CollectionSerializers.swift @@ -34,21 +34,65 @@ enum MapHeader { static let declaredValueType: UInt8 = 0b0010_0000 } -private func primitiveArrayTypeID(for _: Element.Type) -> TypeId? { - if Element.self == UInt8.self { return .uint8Array } - if Element.self == Bool.self { return .boolArray } - if Element.self == Int8.self { return .int8Array } - if Element.self == Int16.self { return .int16Array } - if Element.self == Int32.self { return .int32Array } - if Element.self == Int64.self { return .int64Array } - if Element.self == UInt16.self { return .uint16Array } - if Element.self == UInt32.self { return .uint32Array } - if Element.self == UInt64.self { return .uint64Array } - if Element.self == Float16.self { return .float16Array } - if Element.self == BFloat16.self { return .bfloat16Array } - if Element.self == Float.self { return .float32Array } - if Element.self == Double.self { return .float64Array } - return nil +private let storedReferenceBytes = 4 + +@inline(__always) +private func storedElementBytes(_ type: Element.Type) -> Int { + type.isRefType ? 
storedReferenceBytes : max(1, MemoryLayout.stride) +} + +@inline(__always) +private func storedOwnerBytes(_ type: T.Type) -> Int { + max(1, MemoryLayout.stride) +} + +@inline(__always) +private func reserveGraphElements( + _ context: ReadContext, + ownerBytes: Int, + count: Int, + elementBytes: Int +) throws { + if ownerBytes < 0 || count < 0 || elementBytes < 0 { + throw ForyError.invalidData("graph memory estimate overflows") + } + let (storageBytes, overflow) = count.multipliedReportingOverflow(by: elementBytes) + if overflow { + throw ForyError.invalidData("graph memory estimate overflows") + } + let (bytes, addOverflow) = ownerBytes.addingReportingOverflow(storageBytes) + if addOverflow { + throw ForyError.invalidData("graph memory estimate overflows") + } + try context.reserveGraphMemory(bytes) +} + +@inline(__always) +private func reserveGraphArrayMemory( + _ context: ReadContext, + _ type: Element.Type, + ownerBytes: Int, + count: Int +) throws { + try reserveGraphElements( + context, ownerBytes: ownerBytes, count: count, elementBytes: storedElementBytes(type)) +} + +@inline(__always) +private func reserveGraphMapMemory( + _ context: ReadContext, + key: Key.Type, + value: Value.Type, + ownerBytes: Int, + count: Int +) throws { + let keyBytes = storedElementBytes(key) + let valueBytes = storedElementBytes(value) + let (elementBytes, overflow) = keyBytes.addingReportingOverflow(valueBytes) + if overflow { + throw ForyError.invalidData("graph memory estimate overflows") + } + try reserveGraphElements(context, ownerBytes: ownerBytes, count: count, elementBytes: elementBytes) } private let hostIsLittleEndian = Int(littleEndian: 1) == 1 @@ -234,18 +278,40 @@ func writePrimitiveArray(_ value: [Element], context: Write } } -func readPrimitiveArray(_ context: ReadContext) throws -> [Element] { +@inline(__always) +private func preparePrimitiveArray( + _ context: ReadContext, + reserveGraphStorage: Bool, + type: Element.Type, + count: Int, + label: String +) throws { 
+ try context.ensureCollectionLength(count, label: label) + if reserveGraphStorage { + try reserveGraphArrayMemory( + context, type, ownerBytes: storedOwnerBytes([Element].self), count: count) + } +} + +func readPrimitiveArray( + _ context: ReadContext, + reserveGraphStorage: Bool = false +) throws -> [Element] { let byteSize = Int(try context.buffer.readVarUInt32()) try context.ensureRemainingBytes(byteSize, label: "primitive_array_bytes") if Element.self == UInt8.self { - try context.ensureCollectionLength(byteSize, label: "uint8_array") + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: byteSize, + label: "uint8_array") let bytes = try context.buffer.readBytes(count: byteSize) return uncheckedArrayCast(bytes, to: Element.self) } if Element.self == Bool.self { - try context.ensureCollectionLength(byteSize, label: "bool_array") + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: byteSize, + label: "bool_array") let out = try readArrayUninitialized(count: byteSize) { destination in for index in 0..(_ context: ReadContext) throws -> [ } if Element.self == Int8.self { - try context.ensureCollectionLength(byteSize, label: "int8_array") + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: byteSize, + label: "int8_array") var out = Array(repeating: Int8(0), count: byteSize) try out.withUnsafeMutableBytes { rawBytes in try context.buffer.readBytes(into: rawBytes) @@ -266,7 +334,9 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == Int16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("int16 array byte size mismatch") } let count = byteSize / 2 - try context.ensureCollectionLength(count, label: "int16_array") + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: count, + label: "int16_array") if 
hostIsLittleEndian { var out = Array(repeating: Int16(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -285,7 +355,9 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == Int32.self { if byteSize % 4 != 0 { throw ForyError.invalidData("int32 array byte size mismatch") } let count = byteSize / 4 - try context.ensureCollectionLength(count, label: "int32_array") + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: count, + label: "int32_array") if hostIsLittleEndian { var out = Array(repeating: Int32(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -304,7 +376,9 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == UInt32.self { if byteSize % 4 != 0 { throw ForyError.invalidData("uint32 array byte size mismatch") } let count = byteSize / 4 - try context.ensureCollectionLength(count, label: "uint32_array") + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: count, + label: "uint32_array") if hostIsLittleEndian { var out = Array(repeating: UInt32(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -323,7 +397,9 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == Int64.self { if byteSize % 8 != 0 { throw ForyError.invalidData("int64 array byte size mismatch") } let count = byteSize / 8 - try context.ensureCollectionLength(count, label: "int64_array") + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: count, + label: "int64_array") if hostIsLittleEndian { var out = Array(repeating: Int64(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -342,7 +418,9 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == UInt64.self { if byteSize % 8 != 0 { throw ForyError.invalidData("uint64 array byte size mismatch") } let count = byteSize / 8 - 
try context.ensureCollectionLength(count, label: "uint64_array") + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: count, + label: "uint64_array") if hostIsLittleEndian { var out = Array(repeating: UInt64(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -361,7 +439,9 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == UInt16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("uint16 array byte size mismatch") } let count = byteSize / 2 - try context.ensureCollectionLength(count, label: "uint16_array") + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: count, + label: "uint16_array") if hostIsLittleEndian { var out = Array(repeating: UInt16(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -380,10 +460,13 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == Float16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("float16 array byte size mismatch") } let count = byteSize / 2 - try context.ensureCollectionLength(count, label: "float16_array") + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: count, + label: "float16_array") let values = try readArrayUninitialized(count: count) { destination in for index in 0..(_ context: ReadContext) throws -> [ if Element.self == BFloat16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("bfloat16 array byte size mismatch") } let count = byteSize / 2 - try context.ensureCollectionLength(count, label: "bfloat16_array") + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: count, + label: "bfloat16_array") let values = try readArrayUninitialized(count: count) { destination in for index in 0..(_ context: ReadContext) throws -> [ if Element.self == Float.self { if byteSize % 4 != 0 { throw 
ForyError.invalidData("float32 array byte size mismatch") } let count = byteSize / 4 - try context.ensureCollectionLength(count, label: "float32_array") + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: count, + label: "float32_array") if hostIsLittleEndian { var out = Array(repeating: Float(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -422,7 +510,9 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if byteSize % 8 != 0 { throw ForyError.invalidData("float64 array byte size mismatch") } let count = byteSize / 8 - try context.ensureCollectionLength(count, label: "float64_array") + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: count, + label: "float64_array") if hostIsLittleEndian { var out = Array(repeating: Double(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -502,14 +592,16 @@ extension Array: Serializer where Element: Serializer { refMode = .none } for element in self { - try element.foryWrite(context, refMode: refMode, writeTypeInfo: true, hasGenerics: hasGenerics) + try element.foryWrite( + context, refMode: refMode, writeTypeInfo: true, hasGenerics: hasGenerics) } return } if trackRef { for element in self { - try element.foryWrite(context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) + try element.foryWrite( + context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) } } else if hasNull { for element in self { @@ -528,10 +620,22 @@ extension Array: Serializer where Element: Serializer { } public static func foryReadData(_ context: ReadContext) throws -> [Element] { + try readData(context, reserveGraphStorage: true) + } + + fileprivate static func readData( + _ context: ReadContext, + reserveGraphStorage: Bool + ) throws -> [Element] { let buffer = context.buffer let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, 
label: "array") + let ownerBytes = reserveGraphStorage ? storedOwnerBytes([Element].self) : 0 if length == 0 { + if reserveGraphStorage { + try reserveGraphArrayMemory( + context, Element.self, ownerBytes: ownerBytes, count: length) + } return [] } @@ -541,6 +645,10 @@ extension Array: Serializer where Element: Serializer { let declared = (header & CollectionHeader.declaredElementType) != 0 let sameType = (header & CollectionHeader.sameType) != 0 if !sameType { + if reserveGraphStorage { + try reserveGraphArrayMemory( + context, Element.self, ownerBytes: ownerBytes, count: length) + } try context.ensureRemainingBytes(length, label: "array") if trackRef { return try readArrayUninitialized(count: length) { destination in @@ -579,6 +687,9 @@ extension Array: Serializer where Element: Serializer { } let elementTypeInfo = declared ? nil : try Element.foryReadTypeInfo(context) + if reserveGraphStorage { + try reserveGraphArrayMemory(context, Element.self, ownerBytes: ownerBytes, count: length) + } try context.ensureRemainingBytes(length, label: "array") return try context.withTypeInfo(elementTypeInfo, for: Element.self) { if trackRef { @@ -637,7 +748,10 @@ extension Set: Serializer where Element: Serializer & Hashable { } public static func foryReadData(_ context: ReadContext) throws -> Set { - Set(try [Element].foryReadData(context)) + let values = try [Element].readData(context, reserveGraphStorage: false) + try reserveGraphArrayMemory( + context, Element.self, ownerBytes: storedOwnerBytes(Set.self), count: values.count) + return Set(values) } } @@ -802,7 +916,8 @@ extension Dictionary: Serializer where Key: Serializer & Hashable, Value: Serial try Key.foryWriteStaticTypeInfo(context) } if trackKeyRef { - try pair.key.foryWrite(context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) + try pair.key.foryWrite( + context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) } else { try pair.key.foryWriteData(context, hasGenerics: 
hasGenerics) } @@ -812,7 +927,8 @@ extension Dictionary: Serializer where Key: Serializer & Hashable, Value: Serial try Value.foryWriteStaticTypeInfo(context) } if trackValueRef { - try pair.value.foryWrite(context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) + try pair.value.foryWrite( + context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) } else { try pair.value.foryWriteData(context, hasGenerics: hasGenerics) } @@ -844,12 +960,14 @@ extension Dictionary: Serializer where Key: Serializer & Hashable, Value: Serial break } if trackKeyRef { - try current.key.foryWrite(context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) + try current.key.foryWrite( + context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) } else { try current.key.foryWriteData(context, hasGenerics: hasGenerics) } if trackValueRef { - try current.value.foryWrite(context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) + try current.value.foryWrite( + context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) } else { try current.value.foryWriteData(context, hasGenerics: hasGenerics) } @@ -863,12 +981,17 @@ extension Dictionary: Serializer where Key: Serializer & Hashable, Value: Serial public static func foryReadData(_ context: ReadContext) throws -> [Key: Value] { let totalLength = Int(try context.buffer.readVarUInt32()) try context.ensureCollectionLength(totalLength, label: "map") + let ownerBytes = storedOwnerBytes(Dictionary.self) if totalLength == 0 { + try reserveGraphMapMemory( + context, key: Key.self, value: Value.self, ownerBytes: ownerBytes, count: totalLength) return [:] } - var map: [Key: Value] = [:] + try reserveGraphMapMemory( + context, key: Key.self, value: Value.self, ownerBytes: ownerBytes, count: totalLength) try context.ensureRemainingBytes(totalLength, label: "map") + var map: [Key: Value] = [:] map.reserveCapacity(totalLength) let keyDynamicType = 
Key.staticTypeId == .unknown let valueDynamicType = Value.staticTypeId == .unknown diff --git a/swift/Sources/Fory/FieldCodecs.swift b/swift/Sources/Fory/FieldCodecs.swift index 3db01e7c3c..5285cfe1dc 100644 --- a/swift/Sources/Fory/FieldCodecs.swift +++ b/swift/Sources/Fory/FieldCodecs.swift @@ -17,6 +17,83 @@ import Foundation +private let fieldReferenceBytes = 4 + +@inline(__always) +private func fieldElementBytes(_ codec: ElementCodec.Type) -> Int { + codec.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) +} + +@inline(__always) +private func serializerElementBytes(_ type: Element.Type) -> Int { + type.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) +} + +@inline(__always) +private func fieldOwnerBytes(_ type: T.Type) -> Int { + max(1, MemoryLayout.stride) +} + +@inline(__always) +private func reserveFieldStorage( + _ context: ReadContext, + ownerBytes: Int, + count: Int, + elementBytes: Int +) throws { + if ownerBytes < 0 || count < 0 || elementBytes < 0 { + throw ForyError.invalidData("graph memory estimate overflows") + } + let (storageBytes, overflow) = count.multipliedReportingOverflow(by: elementBytes) + if overflow { + throw ForyError.invalidData("graph memory estimate overflows") + } + let (bytes, addOverflow) = ownerBytes.addingReportingOverflow(storageBytes) + if addOverflow { + throw ForyError.invalidData("graph memory estimate overflows") + } + try context.reserveGraphMemory(bytes) +} + +@inline(__always) +private func reserveFieldArrayStorage( + _ context: ReadContext, + _ codec: ElementCodec.Type, + ownerBytes: Int, + count: Int +) throws { + try reserveFieldStorage( + context, ownerBytes: ownerBytes, count: count, elementBytes: fieldElementBytes(codec)) +} + +@inline(__always) +private func reserveSerializerArrayMemory( + _ context: ReadContext, + _ type: Element.Type, + ownerBytes: Int, + count: Int +) throws { + try reserveFieldStorage( + context, ownerBytes: ownerBytes, count: count, elementBytes: 
serializerElementBytes(type)) +} + +@inline(__always) +private func reserveFieldMapStorage( + _ context: ReadContext, + key: KeyCodec.Type, + value: ValueCodec.Type, + ownerBytes: Int, + count: Int +) throws { + let keyBytes = fieldElementBytes(key) + let valueBytes = fieldElementBytes(value) + let (elementBytes, overflow) = keyBytes.addingReportingOverflow(valueBytes) + if overflow { + throw ForyError.invalidData("graph memory estimate overflows") + } + try reserveFieldStorage(context, ownerBytes: ownerBytes, count: count, elementBytes: elementBytes) +} + public protocol FieldCodec { associatedtype Value @@ -31,7 +108,10 @@ public protocol FieldCodec { static func readPayload(_ context: ReadContext) throws -> Value static func writeStaticTypeInfo(_ context: WriteContext) throws static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? - static func withTypeInfo(_ typeInfo: TypeInfo?, _ context: ReadContext, _ body: () throws -> R) rethrows -> R + static func withTypeInfo( + _ typeInfo: TypeInfo?, _ context: ReadContext, _ body: () throws -> R + ) + rethrows -> R static func readCompatibleField( _ context: ReadContext, remoteFieldType: TypeMeta.FieldType, @@ -39,25 +119,25 @@ public protocol FieldCodec { ) throws -> Value } -public extension FieldCodec { - static var isNullableType: Bool { false } - static var isRefType: Bool { false } +extension FieldCodec { + public static var isNullableType: Bool { false } + public static var isRefType: Bool { false } - static func isNone(_: Value) -> Bool { false } + public static func isNone(_: Value) -> Bool { false } - static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { TypeMeta.FieldType(typeID: typeId.rawValue, nullable: nullable, trackRef: trackRef) } - static func writeStaticTypeInfo(_ context: WriteContext) throws { + public static func writeStaticTypeInfo(_ context: WriteContext) throws { 
context.writeStaticTypeInfo(typeId) } - static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + public static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { try context.readStaticTypeInfo(typeId) } - static func withTypeInfo( + public static func withTypeInfo( _ typeInfo: TypeInfo?, _ context: ReadContext, _ body: () throws -> R @@ -67,7 +147,7 @@ public extension FieldCodec { return try body() } - static func readCompatibleField( + public static func readCompatibleField( _ context: ReadContext, remoteFieldType: TypeMeta.FieldType, refMode: RefMode @@ -75,11 +155,12 @@ public extension FieldCodec { try read( context, refMode: refMode, - readTypeInfo: TypeId.needsTypeInfoForField(TypeId(rawValue: remoteFieldType.typeID) ?? .unknown) + readTypeInfo: TypeId.needsTypeInfoForField( + TypeId(rawValue: remoteFieldType.typeID) ?? .unknown) ) } - static func write( + public static func write( _ value: Value, _ context: WriteContext, refMode: RefMode, @@ -101,7 +182,7 @@ public extension FieldCodec { try writePayload(value, context) } - static func read( + public static func read( _ context: ReadContext, refMode: RefMode, readTypeInfo: Bool @@ -190,7 +271,8 @@ private enum FieldCodecDefault { try Codec.read( context, refMode: refMode, - readTypeInfo: TypeId.needsTypeInfoForField(TypeId(rawValue: remoteFieldType.typeID) ?? .unknown) + readTypeInfo: TypeId.needsTypeInfoForField( + TypeId(rawValue: remoteFieldType.typeID) ?? .unknown) ) } } @@ -208,7 +290,8 @@ public enum SerializerCodec: FieldCodec { } public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - let fieldTypeID = T.staticTypeId == .structType ? TypeId.compatibleStruct.rawValue : T.staticTypeId.rawValue + let fieldTypeID = + T.staticTypeId == .structType ? 
TypeId.compatibleStruct.rawValue : T.staticTypeId.rawValue return TypeMeta.FieldType(typeID: fieldTypeID, nullable: nullable, trackRef: trackRef) } @@ -656,7 +739,11 @@ public enum ListFieldCodec: FieldCodec { } public static func readPayload(_ context: ReadContext) throws -> Value { - return try readCollectionPayload(context, elementCodec: ElementCodec.self) + return try readCollectionPayload( + context, + elementCodec: ElementCodec.self, + ownerBytes: fieldOwnerBytes([ElementCodec.Value].self) + ) } public static func readCompatibleField( @@ -665,7 +752,8 @@ public enum ListFieldCodec: FieldCodec { refMode: RefMode ) throws -> Value { if isCompatiblePackedArrayTypeID(remoteFieldType.typeID, elementCodec: ElementCodec.self) { - return try readCompatiblePackedArrayField(context, refMode: refMode, elementCodec: ElementCodec.self) + return try readCompatiblePackedArrayField( + context, refMode: refMode, elementCodec: ElementCodec.self) } return try FieldCodecDefault.readCompatibleField( codec: Self.self, @@ -841,7 +929,18 @@ public enum SetFieldCodec: FieldCodec where ElementCod } public static func readPayload(_ context: ReadContext) throws -> Value { - Set(try readCollectionPayload(context, elementCodec: ElementCodec.self)) + let values = try readCollectionPayload( + context, + elementCodec: ElementCodec.self, + ownerBytes: 0 + ) + try reserveFieldArrayStorage( + context, + ElementCodec.self, + ownerBytes: fieldOwnerBytes(Set.self), + count: values.count + ) + return Set(values) } } @@ -959,12 +1058,19 @@ where KeyCodec.Value: Hashable { public static func readPayload(_ context: ReadContext) throws -> Value { let totalLength = Int(try context.buffer.readVarUInt32()) try context.ensureCollectionLength(totalLength, label: "map") + let ownerBytes = fieldOwnerBytes(Dictionary.self) if totalLength == 0 { + try reserveFieldMapStorage( + context, key: KeyCodec.self, value: ValueCodec.self, ownerBytes: ownerBytes, + count: totalLength) return [:] } - var map: Value = [:] + 
try reserveFieldMapStorage( + context, key: KeyCodec.self, value: ValueCodec.self, ownerBytes: ownerBytes, + count: totalLength) try context.ensureRemainingBytes(totalLength, label: "map") + var map: Value = [:] map.reserveCapacity(totalLength) var readCount = 0 while readCount < totalLength { @@ -1263,49 +1369,62 @@ private func readPackedArrayPayload( elementCodec _: ElementCodec.Type ) throws -> [ElementCodec.Value]? { if ElementCodec.self == BoolCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Bool], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Bool], to: ElementCodec.Value.self) } if ElementCodec.self == Int8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int8], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Int8], to: ElementCodec.Value.self) } if ElementCodec.self == Int16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Int16], to: ElementCodec.Value.self) } if ElementCodec.self == Int32FixedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int32], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Int32], to: ElementCodec.Value.self) } if ElementCodec.self == Int64FixedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int64], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Int64], to: ElementCodec.Value.self) } if ElementCodec.self == IntFixedCodec.self { return uncheckedPackedArrayCast(try readIntArrayPayload(context), to: ElementCodec.Value.self) } if ElementCodec.self == UInt8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt8], 
to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [UInt8], to: ElementCodec.Value.self) } if ElementCodec.self == UInt16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [UInt16], to: ElementCodec.Value.self) } if ElementCodec.self == UInt32FixedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt32], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [UInt32], to: ElementCodec.Value.self) } if ElementCodec.self == UInt64FixedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt64], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [UInt64], to: ElementCodec.Value.self) } if ElementCodec.self == UIntFixedCodec.self { return uncheckedPackedArrayCast(try readUIntArrayPayload(context), to: ElementCodec.Value.self) } if ElementCodec.self == Float16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Float16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Float16], to: ElementCodec.Value.self) } if ElementCodec.self == BFloat16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [BFloat16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [BFloat16], to: ElementCodec.Value.self) } if ElementCodec.self == FloatCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Float], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Float], to: ElementCodec.Value.self) } if ElementCodec.self == DoubleCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as 
[Double], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Double], to: ElementCodec.Value.self) } return nil } @@ -1324,8 +1443,16 @@ private func writeUIntArrayPayload(_ value: [UInt], _ context: WriteContext) { } } -private func readIntArrayPayload(_ context: ReadContext) throws -> [Int] { +private func readIntArrayPayload( + _ context: ReadContext, reserveGraphStorage: Bool = false +) throws + -> [Int] +{ let count = try readPackedArrayElementCount(context, width: 8, label: "int64_array") + if reserveGraphStorage { + try reserveSerializerArrayMemory( + context, Int.self, ownerBytes: fieldOwnerBytes([Int].self), count: count) + } var values: [Int] = [] values.reserveCapacity(count) for _ in 0.. [Int] { return values } -private func readUIntArrayPayload(_ context: ReadContext) throws -> [UInt] { +private func readUIntArrayPayload( + _ context: ReadContext, reserveGraphStorage: Bool = false +) throws + -> [UInt] +{ let count = try readPackedArrayElementCount(context, width: 8, label: "uint64_array") + if reserveGraphStorage { + try reserveSerializerArrayMemory( + context, UInt.self, ownerBytes: fieldOwnerBytes([UInt].self), count: count) + } var values: [UInt] = [] values.reserveCapacity(count) for _ in 0..( elementCodec _: ElementCodec.Type ) throws -> [ElementCodec.Value] { if ElementCodec.self == BoolCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Bool], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Bool], + to: ElementCodec.Value.self) } if ElementCodec.self == Int8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int8], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Int8], + to: ElementCodec.Value.self) } if ElementCodec.self == Int16Codec.self { - return 
uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Int16], + to: ElementCodec.Value.self) } if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int32], to: ElementCodec.Value.self) - } - if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self || ElementCodec.self == Int64TaggedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int64], to: ElementCodec.Value.self) - } - if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self || ElementCodec.self == IntTaggedCodec.self { - return uncheckedPackedArrayCast(try readIntArrayPayload(context), to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Int32], + to: ElementCodec.Value.self) + } + if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self + || ElementCodec.self == Int64TaggedCodec.self + { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Int64], + to: ElementCodec.Value.self) + } + if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self + || ElementCodec.self == IntTaggedCodec.self + { + return uncheckedPackedArrayCast( + try readIntArrayPayload(context, reserveGraphStorage: true), to: ElementCodec.Value.self) } if ElementCodec.self == UInt8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt8], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt8], + to: ElementCodec.Value.self) } if ElementCodec.self == UInt16Codec.self { - return uncheckedPackedArrayCast(try 
readPrimitiveArray(context) as [UInt16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt16], + to: ElementCodec.Value.self) } if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt32], to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self || ElementCodec.self == UInt64TaggedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt64], to: ElementCodec.Value.self) - } - if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self || ElementCodec.self == UIntTaggedCodec.self { - return uncheckedPackedArrayCast(try readUIntArrayPayload(context), to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt32], + to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self + || ElementCodec.self == UInt64TaggedCodec.self + { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt64], + to: ElementCodec.Value.self) + } + if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self + || ElementCodec.self == UIntTaggedCodec.self + { + return uncheckedPackedArrayCast( + try readUIntArrayPayload(context, reserveGraphStorage: true), to: ElementCodec.Value.self) } if ElementCodec.self == Float16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Float16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Float16], + to: ElementCodec.Value.self) } if ElementCodec.self == BFloat16Codec.self { - return uncheckedPackedArrayCast(try 
readPrimitiveArray(context) as [BFloat16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [BFloat16], + to: ElementCodec.Value.self) } if ElementCodec.self == FloatCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Float], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Float], + to: ElementCodec.Value.self) } if ElementCodec.self == DoubleCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Double], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Double], + to: ElementCodec.Value.self) } - throw ForyError.invalidData("unsupported compatible array-to-list field element codec \(ElementCodec.self)") + throw ForyError.invalidData( + "unsupported compatible array-to-list field element codec \(ElementCodec.self)") } private func readCompatibleElementPayload( @@ -1441,33 +1613,44 @@ private func readCompatibleElementPayload( if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { switch remoteTypeID { case .int32: - return uncheckedScalarCast(try context.buffer.readInt32() as Int32, to: ElementCodec.Value.self) + return uncheckedScalarCast( + try context.buffer.readInt32() as Int32, to: ElementCodec.Value.self) case .varint32: - return uncheckedScalarCast(try context.buffer.readVarInt32() as Int32, to: ElementCodec.Value.self) + return uncheckedScalarCast( + try context.buffer.readVarInt32() as Int32, to: ElementCodec.Value.self) default: break } } - if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self || ElementCodec.self == Int64TaggedCodec.self { + if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self + || ElementCodec.self == Int64TaggedCodec.self + { switch 
remoteTypeID { case .int64: - return uncheckedScalarCast(try context.buffer.readInt64() as Int64, to: ElementCodec.Value.self) + return uncheckedScalarCast( + try context.buffer.readInt64() as Int64, to: ElementCodec.Value.self) case .varint64: - return uncheckedScalarCast(try context.buffer.readVarInt64() as Int64, to: ElementCodec.Value.self) + return uncheckedScalarCast( + try context.buffer.readVarInt64() as Int64, to: ElementCodec.Value.self) case .taggedInt64: - return uncheckedScalarCast(try context.buffer.readTaggedInt64() as Int64, to: ElementCodec.Value.self) + return uncheckedScalarCast( + try context.buffer.readTaggedInt64() as Int64, to: ElementCodec.Value.self) default: break } } - if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self || ElementCodec.self == IntTaggedCodec.self { + if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self + || ElementCodec.self == IntTaggedCodec.self + { switch remoteTypeID { case .int64: return uncheckedScalarCast(Int(try context.buffer.readInt64()), to: ElementCodec.Value.self) case .varint64: - return uncheckedScalarCast(Int(try context.buffer.readVarInt64()), to: ElementCodec.Value.self) + return uncheckedScalarCast( + Int(try context.buffer.readVarInt64()), to: ElementCodec.Value.self) case .taggedInt64: - return uncheckedScalarCast(Int(try context.buffer.readTaggedInt64()), to: ElementCodec.Value.self) + return uncheckedScalarCast( + Int(try context.buffer.readTaggedInt64()), to: ElementCodec.Value.self) default: break } @@ -1475,33 +1658,44 @@ private func readCompatibleElementPayload( if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { switch remoteTypeID { case .uint32: - return uncheckedScalarCast(try context.buffer.readUInt32() as UInt32, to: ElementCodec.Value.self) + return uncheckedScalarCast( + try context.buffer.readUInt32() as UInt32, to: ElementCodec.Value.self) case .varUInt32: - return 
uncheckedScalarCast(try context.buffer.readVarUInt32() as UInt32, to: ElementCodec.Value.self) + return uncheckedScalarCast( + try context.buffer.readVarUInt32() as UInt32, to: ElementCodec.Value.self) default: break } } - if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self || ElementCodec.self == UInt64TaggedCodec.self { + if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self + || ElementCodec.self == UInt64TaggedCodec.self + { switch remoteTypeID { case .uint64: - return uncheckedScalarCast(try context.buffer.readUInt64() as UInt64, to: ElementCodec.Value.self) + return uncheckedScalarCast( + try context.buffer.readUInt64() as UInt64, to: ElementCodec.Value.self) case .varUInt64: - return uncheckedScalarCast(try context.buffer.readVarUInt64() as UInt64, to: ElementCodec.Value.self) + return uncheckedScalarCast( + try context.buffer.readVarUInt64() as UInt64, to: ElementCodec.Value.self) case .taggedUInt64: - return uncheckedScalarCast(try context.buffer.readTaggedUInt64() as UInt64, to: ElementCodec.Value.self) + return uncheckedScalarCast( + try context.buffer.readTaggedUInt64() as UInt64, to: ElementCodec.Value.self) default: break } } - if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self || ElementCodec.self == UIntTaggedCodec.self { + if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self + || ElementCodec.self == UIntTaggedCodec.self + { switch remoteTypeID { case .uint64: return uncheckedScalarCast(UInt(try context.buffer.readUInt64()), to: ElementCodec.Value.self) case .varUInt64: - return uncheckedScalarCast(UInt(try context.buffer.readVarUInt64()), to: ElementCodec.Value.self) + return uncheckedScalarCast( + UInt(try context.buffer.readVarUInt64()), to: ElementCodec.Value.self) case .taggedUInt64: - return uncheckedScalarCast(UInt(try context.buffer.readTaggedUInt64()), to: ElementCodec.Value.self) + 
return uncheckedScalarCast( + UInt(try context.buffer.readTaggedUInt64()), to: ElementCodec.Value.self) default: break } @@ -1586,12 +1780,15 @@ private func writeCollectionPayload( private func readCollectionPayload( _ context: ReadContext, - elementCodec _: ElementCodec.Type + elementCodec _: ElementCodec.Type, + ownerBytes: Int ) throws -> [ElementCodec.Value] { let buffer = context.buffer let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { + try reserveFieldArrayStorage( + context, ElementCodec.self, ownerBytes: ownerBytes, count: length) return [] } @@ -1606,6 +1803,7 @@ private func readCollectionPayload( let sameType = (header & CollectionHeader.sameType) != 0 var result: [ElementCodec.Value] = [] + try reserveFieldArrayStorage(context, ElementCodec.self, ownerBytes: ownerBytes, count: length) try context.ensureRemainingBytes(length, label: "array") result.reserveCapacity(length) diff --git a/swift/Sources/Fory/FieldSkipper.swift b/swift/Sources/Fory/FieldSkipper.swift index 7fce18fe08..9183d2f58f 100644 --- a/swift/Sources/Fory/FieldSkipper.swift +++ b/swift/Sources/Fory/FieldSkipper.swift @@ -17,8 +17,8 @@ import Foundation -public extension ReadContext { - func skipFieldValue(_ fieldType: TypeMeta.FieldType) throws { +extension ReadContext { + public func skipFieldValue(_ fieldType: TypeMeta.FieldType) throws { _ = try readSkippedFieldValue( fieldType: fieldType, readTypeInfo: needsTypeInfoForSkippedField(fieldType.typeID) diff --git a/swift/Sources/Fory/Fory.swift b/swift/Sources/Fory/Fory.swift index 7b537e60ab..567aedc74f 100644 --- a/swift/Sources/Fory/Fory.swift +++ b/swift/Sources/Fory/Fory.swift @@ -22,6 +22,7 @@ public struct Config { public let compatible: Bool public let checkClassVersion: Bool public let maxDepth: Int + public let maxGraphMemoryBytes: Int64 public let maxTypeFields: Int public let maxTypeMetaBytes: Int public let maxSchemaVersionsPerType: Int @@ -32,6 +33,7 @@ 
public struct Config { compatible: Bool? = nil, checkClassVersion: Bool? = nil, maxDepth: Int = 5, + maxGraphMemoryBytes: Int64 = 128 * 1024 * 1024, maxTypeFields: Int = 512, maxTypeMetaBytes: Int = 4096, maxSchemaVersionsPerType: Int = 10, @@ -43,12 +45,16 @@ public struct Config { precondition( maxAverageSchemaVersionsPerType > 0, "maxAverageSchemaVersionsPerType must be positive") + precondition( + maxGraphMemoryBytes > 0 && maxGraphMemoryBytes <= Int64(Int.max), + "maxGraphMemoryBytes must be a positive byte limit") let effectiveCompatible = compatible ?? true let effectiveCheckClassVersion = checkClassVersion ?? !effectiveCompatible self.trackRef = trackRef self.compatible = effectiveCompatible self.checkClassVersion = effectiveCheckClassVersion self.maxDepth = maxDepth + self.maxGraphMemoryBytes = maxGraphMemoryBytes self.maxTypeFields = maxTypeFields self.maxTypeMetaBytes = maxTypeMetaBytes self.maxSchemaVersionsPerType = maxSchemaVersionsPerType @@ -72,6 +78,7 @@ public final class Fory { compatible: Bool? = nil, checkClassVersion: Bool? 
= nil, maxDepth: Int = 5, + maxGraphMemoryBytes: Int64 = 128 * 1024 * 1024, maxTypeFields: Int = 512, maxTypeMetaBytes: Int = 4096, maxSchemaVersionsPerType: Int = 10, @@ -83,6 +90,7 @@ public final class Fory { compatible: compatible, checkClassVersion: checkClassVersion, maxDepth: maxDepth, + maxGraphMemoryBytes: maxGraphMemoryBytes, maxTypeFields: maxTypeFields, maxTypeMetaBytes: maxTypeMetaBytes, maxSchemaVersionsPerType: maxSchemaVersionsPerType, @@ -486,8 +494,9 @@ public final class Fory { func withReusableReadContext( data: Data, _ body: (ReadContext) throws -> R - ) rethrows -> R { + ) throws -> R { readContext.buffer.replace(with: data) + readContext.remainingGraphMemoryBytes = Int(self.config.maxGraphMemoryBytes) defer { readContext.reset() } @@ -549,6 +558,7 @@ public final class Fory { ) throws -> R { try typeResolver.finishRegistration() readContext.buffer.swapState(with: buffer) + readContext.remainingGraphMemoryBytes = Int(self.config.maxGraphMemoryBytes) defer { readContext.buffer.swapState(with: buffer) readContext.reset() diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index 84505e738b..50f4aca171 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -35,6 +35,7 @@ public final class ReadContext { private var typeInfoScopeStack: [(typeKey: UInt64, previousTypeInfo: TypeInfo?)] = [] private var lastTypeInfo = TypeInfo.uncached private let config: Config + var remainingGraphMemoryBytes = 0 init( buffer: ByteBuffer, @@ -51,6 +52,30 @@ public final class ReadContext { self.refReader = RefReader() } + @inline(__always) + public func reserveGraphMemory(_ bytes: Int) throws { + if _slowPath(bytes < 0) { + try throwGraphMemoryOverflow() + } + if _slowPath(bytes > remainingGraphMemoryBytes) { + try throwGraphMemoryExceeded(bytes: bytes) + } + remainingGraphMemoryBytes -= bytes + } + + @inline(never) + private func throwGraphMemoryOverflow() throws -> Never { + throw 
ForyError.invalidData("graph memory estimate overflows") + } + + @inline(never) + private func throwGraphMemoryExceeded(bytes: Int) throws -> Never { + let message = + "estimated graph memory request \(bytes) bytes exceeds maxGraphMemoryBytes " + + "remaining budget \(remainingGraphMemoryBytes) bytes" + throw ForyError.invalidData(message) + } + @inline(__always) func enterDynamicAnyDepth() throws { if maxDepth < 0 { @@ -357,7 +382,8 @@ public final class ReadContext { for: localTypeInfo, wireTypeID: wireTypeID) compatibleTypeDefTypeInfos.push(cachedTypeInfo) - return try validateCompatibleTypeInfo(cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) + return try validateCompatibleTypeInfo( + cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) } @inline(__always) @@ -714,7 +740,46 @@ public final class ReadContext { } } +private let anyReferenceBytes = 4 +private let anyArrayOwnerBytes = max(1, MemoryLayout<[Any]>.stride) + +@inline(never) +private func throwAnyGraphMemoryOverflow() throws -> Never { + throw ForyError.invalidData("graph memory estimate overflows") +} + +@inline(__always) +private func reserveAnyReferenceArrayMemory(_ context: ReadContext, count: Int) throws { + let (slotBytes, overflow) = count.multipliedReportingOverflow(by: anyReferenceBytes) + if overflow { + try throwAnyGraphMemoryOverflow() + } + let (bytes, addOverflow) = anyArrayOwnerBytes.addingReportingOverflow(slotBytes) + if addOverflow { + try throwAnyGraphMemoryOverflow() + } + try context.reserveGraphMemory(bytes) +} + +@inline(__always) +private func reserveAnyReferenceMapMemory( + _ context: ReadContext, _ type: Map.Type, count: Int +) throws { + let (slotBytes, overflow) = count.multipliedReportingOverflow(by: 2 * anyReferenceBytes) + if overflow { + try throwAnyGraphMemoryOverflow() + } + let ownerBytes = max(1, MemoryLayout.stride) + let (bytes, addOverflow) = ownerBytes.addingReportingOverflow(slotBytes) + if addOverflow { + try throwAnyGraphMemoryOverflow() + } + 
try context.reserveGraphMemory(bytes) +} + extension ReadContext { + // Swift `Any` cannot conform to `Serializer`, so generated and hand-written dynamic-Any + // readers must enter through these context methods instead of trying `Any.foryRead(...)`. public func readAny( refMode: RefMode, readTypeInfo: Bool = true @@ -731,7 +796,11 @@ extension ReadContext { refMode: refMode, readTypeInfo: readTypeInfo ) - return wrapped?.map { $0.anyValueForCollection() } + guard let wrapped else { + return nil + } + try reserveAnyReferenceArrayMemory(self, count: wrapped.count) + return wrapped.map { $0.anyValueForCollection() } } public func readMapStringToAny( @@ -746,6 +815,7 @@ extension ReadContext { guard let wrapped else { return nil } + try reserveAnyReferenceMapMemory(self, [String: Any].self, count: wrapped.count) var map: [String: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { @@ -766,6 +836,7 @@ extension ReadContext { guard let wrapped else { return nil } + try reserveAnyReferenceMapMemory(self, [Int32: Any].self, count: wrapped.count) var map: [Int32: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { @@ -786,6 +857,7 @@ extension ReadContext { guard let wrapped else { return nil } + try reserveAnyReferenceMapMemory(self, [AnyHashable: Any].self, count: wrapped.count) var map: [AnyHashable: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { diff --git a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift index 0c37178078..495743f133 100644 --- a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift +++ b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift @@ -47,6 +47,24 @@ func buildReadCompatibleDataDecl( fields: fields, sortedFields: sortedFields, accessPrefix: accessPrefix) } +private func graphFieldBytesExpr(_ field: ParsedField) -> String { + if field.primitiveSize > 0 { + return "\(field.primitiveSize)" + } + return 
"(\(field.typeText).isRefType ? 4 : max(1, MemoryLayout<\(field.typeText)>.stride))" +} + +private func classGraphOwnerBytesExpr(_ fields: [ParsedField]) -> String { + if fields.isEmpty { + return "1" + } + return "max(1, 1 + " + fields.map(graphFieldBytesExpr).joined(separator: " + ") + ")" +} + +private func reserveClassGraphOwnerLine(fields: [ParsedField], indent: String) -> String { + "\(indent)try context.reserveGraphMemory(\(classGraphOwnerBytesExpr(fields)))" +} + func buildClassReadWrapperDecl(accessPrefix: String) -> String { """ @inline(__always) @@ -109,6 +127,7 @@ private func buildClassReadDataDecl( private static func __foryReadDataImpl(_ context: ReadContext, reservedRefID: UInt32?) throws -> Self { let __buffer = context.buffer \(schemaHashCheckExpr()) + \(reserveClassGraphOwnerLine(fields: sortedFields, indent: " ")) let value = Self.init() if let reservedRefID { context.refReader.storeRef(value, at: reservedRefID) @@ -130,6 +149,7 @@ private func buildEmptyStructReadDataDecl(accessPrefix: String) -> String { \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { let __buffer = context.buffer \(schemaHashCheckExpr()) + try context.reserveGraphMemory(1) return Self() } """ @@ -153,6 +173,7 @@ private func buildStructReadDataDecl( \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { let __buffer = context.buffer \(schemaHashCheckExpr()) + try context.reserveGraphMemory(MemoryLayout.stride) \(schemaReadBody) return Self( \(ctorArgs) @@ -183,7 +204,8 @@ private func buildClassReadCompatibleDataDecl( || compatibleCases.contains("__buffer")) ? "let __buffer = context.buffer\n " : "" let localFieldsBinding = compatibleCases.contains("__foryLocalFields") - ? "let __foryLocalFields = remoteTypeInfo.typeMeta?.fields ?? Self.foryFieldsInfo(trackRef: context.trackRef)\n " : "" + ? "let __foryLocalFields = remoteTypeInfo.typeMeta?.fields ?? 
Self.foryFieldsInfo(trackRef: context.trackRef)\n " + : "" return """ @inline(never) @@ -195,6 +217,7 @@ private func buildClassReadCompatibleDataDecl( \(bufferBinding)guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { throw ForyError.invalidData("compatible type metadata is required") } + \(reserveClassGraphOwnerLine(fields: sortedFields, indent: " ")) let value = Self.init() if let reservedRefID { context.refReader.storeRef(value, at: reservedRefID) @@ -236,6 +259,7 @@ private func buildEmptyStructReadCompatibleDataDecl(accessPrefix: String) -> Str guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { throw ForyError.invalidData("compatible type metadata is required") } + try context.reserveGraphMemory(1) if let localTypeMeta = remoteTypeInfo.typeMeta, let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, typeMeta.headerHash == localHeaderHash, @@ -290,6 +314,7 @@ private func buildStructReadCompatibleDataDecl( \(bufferBinding)guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { throw ForyError.invalidData("compatible type metadata is required") } + try context.reserveGraphMemory(MemoryLayout.stride) if let localTypeMeta = remoteTypeInfo.typeMeta, let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, typeMeta.headerHash == localHeaderHash, @@ -351,7 +376,8 @@ private func buildClassAssignBody( primitiveFastFields: [ParsedField], compatibleAligned: Bool ) -> String { - let remainingAssignLines = sortedFields.dropFirst(primitiveFastFields.count).map { field -> String in + let remainingAssignLines = sortedFields.dropFirst(primitiveFastFields.count).map { + field -> String in let valueExpr: String if compatibleAligned { valueExpr = compatibleSchemaReadFieldExpr(field) @@ -383,7 +409,8 @@ private func buildStructReadBody( primitiveFastFields: [ParsedField], compatibleAligned: Bool ) -> String { - let remainingReadLines = sortedFields.dropFirst(primitiveFastFields.count).map { field -> String in + let remainingReadLines = 
sortedFields.dropFirst(primitiveFastFields.count).map { + field -> String in let valueExpr = compatibleAligned ? compatibleSchemaReadFieldExpr(field) : schemaReadFieldExpr(field) return "let __\(field.name) = \(valueExpr)" @@ -655,7 +682,7 @@ private func dynamicAnyReadExpr( ? ", readTypeInfo: true" : "" return - "try castAnyDynamicValue(context.\(method)(refMode: \(refModeExpr)\(readTypeInfoExpr)), to: \(metatypeExpr))" + "try castAnyDynamicValue(\(method)(context: context, refMode: \(refModeExpr)\(readTypeInfoExpr)), to: \(metatypeExpr))" } private func compatibleDefaultDecl(_ field: ParsedField) -> String { diff --git a/swift/Tests/ForyTests/AnyTests.swift b/swift/Tests/ForyTests/AnyTests.swift index e7eb1c3e3f..0dec3f3a92 100644 --- a/swift/Tests/ForyTests/AnyTests.swift +++ b/swift/Tests/ForyTests/AnyTests.swift @@ -92,6 +92,90 @@ private func nestedDynamicAnyList(depth: Int) -> Any { return value } +private func readAnyRoot( + _ fory: Fory, + data: Data, + _ body: (ReadContext) throws -> R +) throws -> R { + try fory.withReusableReadContext(data: data) { context in + try fory.readHead(buffer: context.buffer) + let value = try body(context) + #expect(context.buffer.remaining == 0) + return value + } +} + +@Test +func contextAnyReadersRoundTrip() throws { + let fory = Fory(config: .init(trackRef: false, compatible: false)) + fory.register(AnyHashableDynamicKey.self, id: 510) + fory.register(AnyHashableDynamicValue.self, id: 511) + + let anyValue: Any = AnyHashableDynamicValue(label: "context-any", score: 1) + let anyDecoded = try readAnyRoot(fory, data: try fory.serialize(anyValue)) { context in + try context.readAny(refMode: .nullOnly, readTypeInfo: true) + } + #expect( + anyDecoded as? 
AnyHashableDynamicValue + == AnyHashableDynamicValue(label: "context-any", score: 1) + ) + + let listValue: [Any] = [ + Int32(2), + "context-list", + AnyHashableDynamicValue(label: "context-list-obj", score: 3) + ] + let listDecoded = try readAnyRoot(fory, data: try fory.serialize(listValue)) { context in + try context.readListOfAny(refMode: .nullOnly, readTypeInfo: true) + } + #expect(listDecoded?[0] as? Int32 == 2) + #expect(listDecoded?[1] as? String == "context-list") + #expect( + listDecoded?[2] as? AnyHashableDynamicValue + == AnyHashableDynamicValue(label: "context-list-obj", score: 3) + ) + + let stringMapValue: [String: Any] = [ + "a": Int32(4), + "b": AnyHashableDynamicValue(label: "context-string-map", score: 5) + ] + let stringMapDecoded = try readAnyRoot(fory, data: try fory.serialize(stringMapValue)) { context in + try context.readMapStringToAny(refMode: .nullOnly, readTypeInfo: true) + } + #expect(stringMapDecoded?["a"] as? Int32 == 4) + #expect( + stringMapDecoded?["b"] as? AnyHashableDynamicValue + == AnyHashableDynamicValue(label: "context-string-map", score: 5) + ) + + let int32MapValue: [Int32: Any] = [ + 6: "context-int-map", + 7: AnyHashableDynamicValue(label: "context-int-map-obj", score: 8) + ] + let int32MapDecoded = try readAnyRoot(fory, data: try fory.serialize(int32MapValue)) { context in + try context.readMapInt32ToAny(refMode: .nullOnly, readTypeInfo: true) + } + #expect(int32MapDecoded?[6] as? String == "context-int-map") + #expect( + int32MapDecoded?[7] as? 
AnyHashableDynamicValue + == AnyHashableDynamicValue(label: "context-int-map-obj", score: 8) + ) + + let anyHashableMapValue: [AnyHashable: Any] = [ + AnyHashable("x"): Int32(9), + AnyHashable(AnyHashableDynamicKey(id: 10)): + AnyHashableDynamicValue(label: "context-any-map", score: 11) + ] + let anyHashableMapDecoded = try readAnyRoot(fory, data: try fory.serialize(anyHashableMapValue)) { context in + try context.readMapAnyHashableToAny(refMode: .nullOnly, readTypeInfo: true) + } + #expect(anyHashableMapDecoded?[AnyHashable("x")] as? Int32 == 9) + #expect( + anyHashableMapDecoded?[AnyHashable(AnyHashableDynamicKey(id: 10))] as? AnyHashableDynamicValue + == AnyHashableDynamicValue(label: "context-any-map", score: 11) + ) +} + @Test func topLevelAnyHashableRoundTrip() throws { let fory = Fory() diff --git a/swift/Tests/ForyTests/ForySwiftTests.swift b/swift/Tests/ForyTests/ForySwiftTests.swift index 897c881968..ea26571f17 100644 --- a/swift/Tests/ForyTests/ForySwiftTests.swift +++ b/swift/Tests/ForyTests/ForySwiftTests.swift @@ -382,6 +382,7 @@ func namedInitializerBuildsConfig() { #expect(defaultConfig.config.compatible == true) #expect(defaultConfig.config.checkClassVersion == false) #expect(defaultConfig.config.maxDepth == 5) + #expect(defaultConfig.config.maxGraphMemoryBytes == 128 * 1024 * 1024) #expect(defaultConfig.config.maxTypeFields == 512) #expect(defaultConfig.config.maxTypeMetaBytes == 4096) #expect(defaultConfig.config.maxSchemaVersionsPerType == 10) @@ -391,6 +392,7 @@ func namedInitializerBuildsConfig() { ref: true, compatible: true, maxDepth: 7, + maxGraphMemoryBytes: 65_536, maxTypeFields: 31, maxTypeMetaBytes: 1234, maxSchemaVersionsPerType: 12, @@ -400,6 +402,7 @@ func namedInitializerBuildsConfig() { #expect(explicitConfig.config.compatible == true) #expect(explicitConfig.config.checkClassVersion == false) #expect(explicitConfig.config.maxDepth == 7) + #expect(explicitConfig.config.maxGraphMemoryBytes == 65_536) 
#expect(explicitConfig.config.maxTypeFields == 31) #expect(explicitConfig.config.maxTypeMetaBytes == 1234) #expect(explicitConfig.config.maxSchemaVersionsPerType == 12) @@ -410,6 +413,7 @@ func namedInitializerBuildsConfig() { trackRef: false, compatible: true, maxDepth: 9, + maxGraphMemoryBytes: 131_072, maxTypeFields: 41, maxTypeMetaBytes: 2048, maxSchemaVersionsPerType: 14, @@ -419,6 +423,7 @@ func namedInitializerBuildsConfig() { #expect(configInit.config.compatible == true) #expect(configInit.config.checkClassVersion == false) #expect(configInit.config.maxDepth == 9) + #expect(configInit.config.maxGraphMemoryBytes == 131_072) #expect(configInit.config.maxTypeFields == 41) #expect(configInit.config.maxTypeMetaBytes == 2048) #expect(configInit.config.maxSchemaVersionsPerType == 14) @@ -1285,7 +1290,8 @@ func macroFieldOrderFollowsForyRules() throws { let second = try buffer.readVarInt64() let third = try buffer.readVarInt32() - let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) + let tailContext = ReadContext( + buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) let fourth = try String.foryReadData(tailContext) #expect(first == value.shortValue) @@ -1313,7 +1319,8 @@ func macroTaggedFieldsKeepGroupedPayloadOrder() throws { _ = try buffer.readInt32() #expect(try buffer.readVarInt32() == value.intValue) - let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) + let tailContext = ReadContext( + buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) #expect(try String.foryReadData(tailContext) == value.textTail) } diff --git a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift new file mode 100644 index 0000000000..1d77b77dee --- /dev/null +++ b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift @@ -0,0 +1,456 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more 
contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import Foundation +import Testing + +@testable import Fory + +@ForyStruct +private final class BudgetNode { + var id: Int32 = 0 + + required init() {} + + init(id: Int32) { + self.id = id + } +} + +@ForyStruct +private struct BudgetSiblings { + var left: [BudgetNode] = [] + var right: [BudgetNode] = [] +} + +@ForyStruct +private struct BudgetDenseHolder: Equatable { + var text: String = "" + var data: Data = Data() + @ArrayField(element: .int32()) + var dense: [Int32] = [] +} + +@ForyStruct +private struct BudgetListDenseWriter { + var dense: [Int32] = [] +} + +@ForyStruct +private struct BudgetListDenseReader: Equatable { + @ArrayField(element: .int32()) + var dense: [Int32] = [] +} + +private let defaultGraphMemoryBytes: Int64 = 128 * 1024 * 1024 + +private func makeBudgetFory(maxGraphMemoryBytes: Int64 = defaultGraphMemoryBytes) -> Fory { + let fory = Fory( + config: .init( + trackRef: false, + compatible: false, + maxGraphMemoryBytes: maxGraphMemoryBytes + )) + fory.register(BudgetNode.self, id: 9801) + fory.register(BudgetSiblings.self, id: 9802) + fory.register(BudgetDenseHolder.self, id: 9803) + return fory +} + +private func makeCompatibleBudgetFory(maxGraphMemoryBytes: Int64 = defaultGraphMemoryBytes) -> Fory { + Fory( + config: 
.init( + trackRef: false, + compatible: true, + maxGraphMemoryBytes: maxGraphMemoryBytes + )) +} + +private let testReferenceBytes = 4 +private let budgetNodeGraphBytes = 1 + 4 + +private func elementBytes<Element: Serializer>(_ type: Element.Type) -> Int { + type.isRefType ? testReferenceBytes : max(1, MemoryLayout<Element>.stride) +} + +private func ownerBytes<T>(_ type: T.Type) -> Int { + max(1, MemoryLayout<T>.stride) +} + +private func arrayBudget<Element: Serializer>(_ type: Element.Type, count: Int) -> Int { + count * elementBytes(type) +} + +private func listBudget<Element: Serializer>( + _ type: Element.Type, + count: Int, + elementOwnerBytes: Int = 0 +) -> Int { + ownerBytes([Element].self) + arrayBudget(type, count: count) + count * elementOwnerBytes +} + +private func rootArrayBudget<Element: Serializer>( + _ type: Element.Type, + count: Int, + elementOwnerBytes: Int = 0 +) -> Int { + listBudget(type, count: count, elementOwnerBytes: elementOwnerBytes) +} + +private func mapBudget<Key: Serializer, Value: Serializer>( + key: Key.Type, + value: Value.Type, + count: Int +) -> Int { + count * (elementBytes(key) + elementBytes(value)) +} + +private func dictionaryBudget<Key: Serializer & Hashable, Value: Serializer>( + key: Key.Type, + value: Value.Type, + count: Int +) -> Int { + ownerBytes(Dictionary<Key, Value>.self) + mapBudget(key: key, value: value, count: count) +} + +private func rootMapBudget<Key: Serializer & Hashable, Value: Serializer>( + key: Key.Type, + value: Value.Type, + count: Int +) -> Int { + dictionaryBudget(key: key, value: value, count: count) +} + +private func expectInvalidData(_ body: () throws -> Void) { + do { + try body() + Issue.record("expected invalid data") + } catch ForyError.invalidData { + } catch { + Issue.record("expected invalid data, got \(error)") + } +} + +@Test +func fixedDefaultBudget() throws { + let fory = makeBudgetFory() + #expect(fory.config.maxGraphMemoryBytes == defaultGraphMemoryBytes) + let value = Array(repeating: [String](), count: 3) + #expect(try fory.deserialize(try fory.serialize(value)) == value) +} + +@Test +func byteBufferRootDefaultBudget() throws { + let count = 6 + let value = Array(repeating: [String](), count: count) + let bytes 
= try makeBudgetFory().serialize(value) + let buffer = ByteBuffer(data: bytes) + + let decoded: [[String]] = try makeBudgetFory().deserialize(from: buffer) + #expect(decoded.count == count) +} + +@Test +func explicitConfigOverridesDefault() throws { + let values = (0..<16).map { "value-\($0)" } + let bytes = try makeBudgetFory().serialize(values) + let required = rootArrayBudget(String.self, count: values.count) + + expectInvalidData { + let _: [String] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)).deserialize( + bytes) + } + let decoded: [String] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)).deserialize( + bytes) + #expect(decoded == values) +} + +@Test +func siblingContainersShareOneBudget() throws { + let value = BudgetSiblings( + left: (0..<16).map { BudgetNode(id: Int32($0)) }, + right: (16..<32).map { BudgetNode(id: Int32($0)) } + ) + let bytes = try makeBudgetFory().serialize(value) + let oneList = listBudget(BudgetNode.self, count: 16, elementOwnerBytes: budgetNodeGraphBytes) + let required = ownerBytes(BudgetSiblings.self) + oneList * 2 + + expectInvalidData { + let _: BudgetSiblings = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: BudgetSiblings = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded.left.count == 16) + #expect(decoded.right.count == 16) +} + +@Test +func nestedEmptyArraysChargeOwner() throws { + let count = 3 + let value = Array(repeating: [String](), count: count) + let bytes = try makeBudgetFory().serialize(value) + let required = listBudget([String].self, count: count) + count * ownerBytes([String].self) + + expectInvalidData { + let _: [[String]] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: [[String]] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded == value) +} + +@Test +func mapBudgetIsCharged() 
throws { + let value: [String: Int32] = ["a": 1, "b": 2, "c": 3] + let bytes = try makeBudgetFory().serialize(value) + let required = rootMapBudget(key: String.self, value: Int32.self, count: value.count) + + expectInvalidData { + let _: [String: Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: [String: Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded == value) +} + +@Test +func emptyTypedMapOwnerIsCharged() throws { + let value: [String: Int32] = [:] + let bytes = try makeBudgetFory().serialize(value) + let required = rootMapBudget(key: String.self, value: Int32.self, count: value.count) + + expectInvalidData { + let _: [String: Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: [String: Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded == value) +} + +@Test +func arrayInlineValueBudget() throws { + let nodes = (0..<4).map { BudgetNode(id: Int32($0)) } + let nodeBytes = try makeBudgetFory().serialize(nodes) + let nodeBudget = rootArrayBudget( + BudgetNode.self, + count: nodes.count, + elementOwnerBytes: budgetNodeGraphBytes + ) + expectInvalidData { + let _: [BudgetNode] = try makeBudgetFory(maxGraphMemoryBytes: Int64(nodeBudget - 1)) + .deserialize(nodeBytes) + } + let decodedNodes: [BudgetNode] = try makeBudgetFory(maxGraphMemoryBytes: Int64(nodeBudget)) + .deserialize(nodeBytes) + #expect(decodedNodes.count == nodes.count) + + let ints: [Int32] = [1, 2, 3, 4] + let intBytes = try makeBudgetFory().serialize(ints) + let intBudget = rootArrayBudget(Int32.self, count: ints.count) + expectInvalidData { + let _: [Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(intBudget - 1)) + .deserialize(intBytes) + } + #expect(try makeBudgetFory(maxGraphMemoryBytes: Int64(intBudget)).deserialize(intBytes) == ints) +} + +@Test +func 
setConversionOwnerChargedOnce() throws { + let values: Set<Int32> = [1, 2, 3] + let bytes = try makeBudgetFory().serialize(values) + let required = ownerBytes(Set<Int32>.self) + arrayBudget(Int32.self, count: values.count) + + expectInvalidData { + let _: Set<Int32> = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: Set<Int32> = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded == values) +} + +@Test +func denseLeafOwnersSkipped() throws { + let value = BudgetDenseHolder( + text: "budget", + data: Data([1, 2, 3]), + dense: [1, 2, 3] + ) + let bytes = try makeBudgetFory().serialize(value) + let required = ownerBytes(BudgetDenseHolder.self) + + expectInvalidData { + let _: BudgetDenseHolder = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: BudgetDenseHolder = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded == value) +} + +@Test +func dynamicAnyEmptyMapOwnerSelf() throws { + let value = [:] as [AnyHashable: Any] + let bytes = try makeBudgetFory().serialize(value as Any) + let required = + dictionaryBudget(key: AnyHashable.self, value: SerializableAny.self, count: value.count) + + ownerBytes(Dictionary.self) + + ownerBytes(Dictionary.self) + + expectInvalidData { + let _: Any = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: Any = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect((decoded as? 
[String: Any])?.isEmpty == true) +} + +@Test +func publicAnyArrayBudget() throws { + let value: [Any] = [Int32(1), Int32(2), Int32(3)] + let bytes = try makeBudgetFory().serialize(value) + let wrappedBudget = listBudget(SerializableAny.self, count: value.count) + let finalBudget = ownerBytes([Any].self) + value.count * testReferenceBytes + + expectInvalidData { + let _: [Any] = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget)) + .deserialize(bytes, as: [Any].self) + } + let decoded = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget + finalBudget)) + .deserialize(bytes, as: [Any].self) + #expect(decoded.count == value.count) +} + +@Test +func publicAnyMapBudget() throws { + let stringMap: [String: Any] = ["a": Int32(1), "b": Int32(2), "c": Int32(3)] + let stringBytes = try makeBudgetFory().serialize(stringMap) + let stringWrapped = dictionaryBudget( + key: String.self, + value: SerializableAny.self, + count: stringMap.count + ) + let stringFinal = + ownerBytes(Dictionary<String, Any>.self) + stringMap.count * 2 * testReferenceBytes + expectInvalidData { + let _: [String: Any] = try makeBudgetFory(maxGraphMemoryBytes: Int64(stringWrapped)) + .deserialize(stringBytes, as: [String: Any].self) + } + let decodedString = try makeBudgetFory(maxGraphMemoryBytes: Int64(stringWrapped + stringFinal)) + .deserialize(stringBytes, as: [String: Any].self) + #expect(decodedString.count == stringMap.count) + + let intMap: [Int32: Any] = [1: Int32(10), 2: Int32(20), 3: Int32(30)] + let intBytes = try makeBudgetFory().serialize(intMap) + let intWrapped = dictionaryBudget( + key: Int32.self, + value: SerializableAny.self, + count: intMap.count + ) + let intFinal = ownerBytes(Dictionary<Int32, Any>.self) + intMap.count * 2 * testReferenceBytes + expectInvalidData { + let _: [Int32: Any] = try makeBudgetFory(maxGraphMemoryBytes: Int64(intWrapped)) + .deserialize(intBytes, as: [Int32: Any].self) + } + let decodedInt = try makeBudgetFory(maxGraphMemoryBytes: Int64(intWrapped + intFinal)) + 
.deserialize(intBytes, as: [Int32: Any].self) + #expect(decodedInt.count == intMap.count) + + let anyHashableMap: [AnyHashable: Any] = [ + AnyHashable("a"): Int32(1), + AnyHashable(Int32(2)): Int32(2), + AnyHashable(true): Int32(3) + ] + let anyHashableBytes = try makeBudgetFory().serialize(anyHashableMap) + let anyHashableWrapped = dictionaryBudget( + key: AnyHashable.self, + value: SerializableAny.self, + count: anyHashableMap.count + ) + let anyHashableFinal = + ownerBytes(Dictionary<AnyHashable, Any>.self) + anyHashableMap.count * 2 * testReferenceBytes + expectInvalidData { + let _: [AnyHashable: Any] = try makeBudgetFory( + maxGraphMemoryBytes: Int64(anyHashableWrapped) + ).deserialize(anyHashableBytes, as: [AnyHashable: Any].self) + } + let decodedAnyHashable = try makeBudgetFory( + maxGraphMemoryBytes: Int64(anyHashableWrapped + anyHashableFinal) + ).deserialize(anyHashableBytes, as: [AnyHashable: Any].self) + #expect(decodedAnyHashable.count == anyHashableMap.count) +} + +@Test +func dynamicAnyArrayBudget() throws { + let list: [Any] = [Int32(1), "two", Int32(3)] + let value: Any = list + let bytes = try makeBudgetFory().serialize(value) + let count = list.count + let wrappedBudget = listBudget(SerializableAny.self, count: count) + let finalBudget = ownerBytes([Any].self) + count * testReferenceBytes + + expectInvalidData { + let _: Any = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget)) + .deserialize(bytes, as: Any.self) + } + let decoded = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget + finalBudget)) + .deserialize(bytes, as: Any.self) + #expect((decoded as? 
[Any])?.count == count) +} + +@Test +func compatibleDenseArraySkip() throws { + let writer = makeCompatibleBudgetFory() + writer.register(BudgetListDenseWriter.self, id: 9804) + let reader = makeCompatibleBudgetFory( + maxGraphMemoryBytes: Int64(ownerBytes(BudgetListDenseReader.self)) + ) + reader.register(BudgetListDenseReader.self, id: 9804) + let bytes = try writer.serialize(BudgetListDenseWriter(dense: [1, 2, 3])) + + expectInvalidData { + let failingReader = makeCompatibleBudgetFory( + maxGraphMemoryBytes: Int64(ownerBytes(BudgetListDenseReader.self) - 1) + ) + failingReader.register(BudgetListDenseReader.self, id: 9804) + let _: BudgetListDenseReader = try failingReader.deserialize(bytes) + } + let decoded: BudgetListDenseReader = try reader.deserialize(bytes) + #expect(decoded.dense == [1, 2, 3]) +} + +@Test +func byteCheckRejectsLargeLength() throws { + let buffer = ByteBuffer() + buffer.writeVarUInt32(64) + buffer.writeUInt8(CollectionHeader.sameType | CollectionHeader.declaredElementType) + let config = Config(trackRef: false, compatible: false) + let context = ReadContext( + buffer: buffer, + typeResolver: TypeResolver(config: config), + config: config + ) + + expectInvalidData { + let _: [String] = try [String].foryReadData(context) + } +}