From a01e79babfab51643fac2c467f047ef039547807 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Sat, 27 Jun 2026 01:45:02 +0800 Subject: [PATCH 01/54] feat: add container memory budget --- .agents/languages/cpp.md | 7 + .agents/languages/csharp.md | 2 + .agents/languages/dart.md | 1 + .agents/languages/go.md | 8 + .agents/languages/java.md | 9 + .agents/languages/javascript.md | 7 + .agents/languages/python.md | 6 + .agents/languages/rust.md | 10 + .agents/languages/swift.md | 2 + AGENTS.md | 16 +- cpp/fory/serialization/BUILD | 10 + cpp/fory/serialization/CMakeLists.txt | 5 + .../serialization/collection_serializer.h | 221 +- cpp/fory/serialization/config.h | 4 + .../container_memory_budget_test.cc | 265 ++ cpp/fory/serialization/context.cc | 40 + cpp/fory/serialization/context.h | 108 + cpp/fory/serialization/fory.h | 70 +- cpp/fory/serialization/map_serializer.h | 31 + cpp/fory/serialization/struct_serializer.h | 20 +- cpp/fory/serialization/union_serializer.h | 12 +- .../src/Fory.Generator/ForyModelGenerator.cs | 2 + csharp/src/Fory/CollectionSerializers.cs | 21 +- csharp/src/Fory/Config.cs | 32 + csharp/src/Fory/DictionarySerializers.cs | 2 + csharp/src/Fory/Fory.cs | 3 + csharp/src/Fory/NullableKeyDictionary.cs | 2 + .../Fory/PrimitiveDictionarySerializers.cs | 2 + csharp/src/Fory/ReadContext.cs | 140 + .../Fory.Tests/ContainerMemoryBudgetTests.cs | 188 ++ dart/packages/fory/lib/src/config.dart | 11 + .../fory/lib/src/context/read_context.dart | 95 + dart/packages/fory/lib/src/fory.dart | 10 + .../serializer/collection_serializers.dart | 134 +- .../lib/src/serializer/map_serializers.dart | 118 +- .../test/container_memory_budget_test.dart | 280 ++ docs/guide/cpp/configuration.md | 26 + docs/guide/csharp/configuration.md | 17 + docs/guide/dart/configuration.md | 25 + docs/guide/go/configuration.md | 23 + docs/guide/java/configuration.md | 5 + docs/guide/javascript/configuration.md | 24 + docs/guide/python/configuration.md | 8 + docs/guide/rust/configuration.md | 32 + docs/guide/swift/configuration.md | 14 +- docs/security/deserialization.md | 48 +- .../xlang_implementation_guide.md | 25 +- go/fory/README.md | 4 + go/fory/array.go | 5 +- go/fory/codegen/decoder.go | 37 + go/fory/codegen/generator.go | 31 + go/fory/container_memory_budget_test.go | 207 ++ go/fory/field_serializer.go | 5 +- go/fory/fory.go | 22 + go/fory/map.go | 3 + go/fory/map_primitive.go | 160 +- go/fory/reader.go | 268 +- go/fory/set.go | 6 + go/fory/slice.go | 8 + go/fory/slice_dyn.go | 54 +- go/fory/slice_primitive.go | 3 + go/fory/slice_primitive_list.go | 55 +- go/fory/stream.go | 11 + go/fory/tests/structs_fory_gen.go | 72 +- go/fory/type_resolver.go | 4 +- .../src/main/java/org/apache/fory/Fory.java | 27 +- .../java/org/apache/fory/config/Config.java | 9 + .../org/apache/fory/config/ForyBuilder.java | 17 + .../org/apache/fory/context/ReadContext.java | 85 +- .../org/apache/fory/memory/MemoryBuffer.java | 7 + .../fory/serializer/ArraySerializers.java | 40 +- .../CompatibleCollectionArrayReader.java | 28 +- .../collection/ChildContainerSerializers.java | 6 +- .../collection/CollectionLikeSerializer.java | 6 +- .../collection/CollectionSerializers.java | 37 +- .../GuavaCollectionSerializers.java | 14 +- .../ImmutableCollectionSerializers.java | 6 +- .../collection/MapLikeSerializer.java | 6 +- .../serializer/collection/MapSerializers.java | 16 +- .../collection/SubListSerializers.java | 2 +- .../org/apache/fory/memory/MemoryBuffer.java | 5 + .../java/org/apache/fory/ForyTestBase.java | 2 +- .../fory/io/MemoryBufferObjectInputTest.java | 2 +- .../fory/io/MemoryBufferObjectOutputTest.java | 2 +- .../fory/resolver/ClassResolverTest.java | 6 +- .../fory/serializer/ArraySerializersTest.java | 16 +- .../serializer/CompatibleSerializerTest.java | 20 +- .../serializer/ContainerMemoryBudgetTest.java | 318 +++ .../serializer/ExceptionSerializersTest.java | 4 +- .../serializer/PrimitiveSerializersTest.java | 4 +- .../ChildContainerSerializersTest.java | 2 +- .../collection/CollectionSerializersTest.java | 2 +- javascript/packages/core/lib/context.ts | 457 ++- javascript/packages/core/lib/fory.ts | 12 + .../packages/core/lib/gen/collection.ts | 30 + javascript/packages/core/lib/gen/map.ts | 2 + javascript/packages/core/lib/type.ts | 1 + javascript/test/containerMemoryBudget.test.ts | 225 ++ .../serializer/kotlin/CollectionSerializer.kt | 10 +- .../kotlin/CollectionSerializerTest.kt | 18 + python/pyfory/_fory.py | 14 + python/pyfory/collection.pxi | 64 +- python/pyfory/collection.py | 6 + python/pyfory/context.pxi | 81 + python/pyfory/context.py | 63 + python/pyfory/serialization.pyx | 36 + python/pyfory/serializer.py | 1 + .../tests/test_container_memory_budget.py | 220 ++ rust/fory-core/src/config.rs | 10 + rust/fory-core/src/context.rs | 143 + rust/fory-core/src/fory.rs | 29 +- rust/fory-core/src/serializer/codec.rs | 7 +- rust/fory-core/src/serializer/collection.rs | 18 +- rust/fory-core/src/serializer/map.rs | 12 +- rust/tests/tests/mod.rs | 1 + .../tests/test_container_memory_budget.rs | 244 ++ .../scala/CollectionSerializer.scala | 7 +- .../fory/serializer/scala/MapSerializer.scala | 7 +- .../scala/XlangCollectionSerializer.scala | 12 +- .../scala/CollectionSerializerTest.scala | 32 + .../scala/ScalaXlangSerializerTest.scala | 20 + swift/Sources/Fory/AnySerializer.swift | 12 +- .../Sources/Fory/CollectionSerializers.swift | 56 +- swift/Sources/Fory/FieldCodecs.swift | 52 +- swift/Sources/Fory/Fory.swift | 1043 ++++--- swift/Sources/Fory/ReadContext.swift | 1592 ++++++----- .../ContainerMemoryBudgetTests.swift | 232 ++ swift/Tests/ForyTests/ForySwiftTests.swift | 2513 +++++++++-------- 128 files changed, 7918 insertions(+), 3149 deletions(-) create mode 100644 cpp/fory/serialization/container_memory_budget_test.cc create mode 100644 csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs create mode 100644 dart/packages/fory/test/container_memory_budget_test.dart create mode 100644 go/fory/container_memory_budget_test.go create mode 100644 java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java create mode 100644 javascript/test/containerMemoryBudget.test.ts create mode 100644 python/pyfory/tests/test_container_memory_budget.py create mode 100644 rust/tests/tests/test_container_memory_budget.rs create mode 100644 swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift diff --git a/.agents/languages/cpp.md b/.agents/languages/cpp.md index 8a28fe0d03..f640aa8552 100644 --- a/.agents/languages/cpp.md +++ b/.agents/languages/cpp.md @@ -17,6 +17,13 @@ Load this file when changing `cpp/`, Cython build plumbing, or C++ xlang behavio - Do not redesign alias-based or low-level public type shapes to add convenience methods unless the user explicitly asks for that API change. - For cross-language feature ports, match protocol behavior but use idiomatic C++ ownership and layering instead of mirroring Java structure literally. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. +- Root deserialization container budgets are owned by `ReadContext` and initialized by the root + `Fory::deserialize` overload. Keep `max_container_memory_bytes` as `-1 / auto` or a positive + explicit limit; known byte roots use `inputBytes * 8 + 64 KiB`, while stream roots use fixed + `128 MiB`. Reserve estimated container-owned memory before allocation but preserve existing + byte-availability checks and their non-empty metadata ordering. Skip only dedicated string, + binary, primitive vector, and primitive dense-array owners; general `std::vector` for + non-primitive `T` is inline container storage and must be charged. ## Key Paths diff --git a/.agents/languages/csharp.md b/.agents/languages/csharp.md index 098f9a50fe..8785b440e1 100644 --- a/.agents/languages/csharp.md +++ b/.agents/languages/csharp.md @@ -12,6 +12,8 @@ Load this file when changing `csharp/` or C# xlang behavior. - Generated C# gRPC service companions are compiler-owned files that depend on application-provided gRPC packages, not `csharp/src/Fory`. Keep gRPC package references out of the Fory runtime package. - C# generated schema modules are source-file owners. Service companions must use that module's `ThreadSafeFory` and must not introduce namespace-owned aliases or duplicate serializer registration paths. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. +- Root deserialization container memory budget state belongs to `ReadContext`. C# public roots are memory-backed today, so auto uses known input length; generated serializers may call `ReadContext`'s generated-code reservation helpers, but should not expose or depend on serializer helper classes such as `CollectionCodec`. +- For C# container budget formulas, distinguish inline value storage from reference storage: use cheap value-type size for `List`/`T[]` value paths and the 4-byte reference fallback for reference paths. Dedicated string, binary, and primitive dense-array serializers stay skipped and rely on byte availability checks. - When extending C# tests from Java references, prioritize xlang spec behavior and the public C# contract before adding complex Java-specific parity cases. ## Commands diff --git a/.agents/languages/dart.md b/.agents/languages/dart.md index d4f6bebbbc..9db2504693 100644 --- a/.agents/languages/dart.md +++ b/.agents/languages/dart.md @@ -14,6 +14,7 @@ Load this file when changing `dart/`. - Keep root numeric wrapper defaults separate from generated field metadata. Root wrapper resolution belongs in the builtin resolver, while annotations and generated metadata choose fixed, tagged, or declared-field encodings. - Dart 64-bit carriers are optimized for each platform. Do not replace native extension-type wrappers with allocation-heavy classes or route web/native hot paths through `BigInt` unless the user approves a representation change. - In `Buffer`, cursor, serializer, and generated-code hot paths, prefer direct byte/local integer operations and conditional import/export files over callbacks, records, holder objects, wrapper round-trips, or runtime platform branches. +- Root deserialization container memory budgets are owned by `ReadContext`; `maxContainerMemoryBytes` defaults to `-1 / auto`, positive explicit values win, and Dart auto uses `buffer.readableBytes * 8 + 64 KiB` because roots are memory-backed. Charge Dart lists, sets, maps, object/reference arrays, compatible list-to-array inline storage, and compatible array-to-list materialization before allocation. Skip only dedicated string, binary, `BoolList`, and typed-array dense owner paths with byte checks. Do not add stream bytes-read accounting, per-element accounting, or extra hot-path allocations for this budget. - Do not add parallel header-low/header-high slot caches or multi-slot recent caches in TypeMeta hot paths to chase benchmark gaps. Header-cache hits must use the concrete checked cache owner directly; if a hit hint is needed, cache one TypeInfo/TypeMeta object and compare the validated header identity on that object, not separate low/high header fields or benchmark-pattern state. - If Dart TypeMeta cache ownership changes, keep the invariant in a source comment near the hit path: a checked metadata-cache hit skips the body and must not grow low-bit sentinels, accepted-header fields, parallel header slots, or benchmark-pattern state. - Dart expected-type TypeDef reads should compare the expected `TypeInfo` object's cached local TypeDef header before consulting the parsed-metadata map. A match is a direct local-schema hit: skip the remote body, add the expected type to the per-read shared type table, and do not publish to `ParsedTypeMetaCache`, record a remote schema version, or parse/hash the body. diff --git a/.agents/languages/go.md b/.agents/languages/go.md index 94d47fe94c..949dd5030e 100644 --- a/.agents/languages/go.md +++ b/.agents/languages/go.md @@ -7,6 +7,14 @@ Load this file when changing `go/fory/` or Go xlang behavior. - Run Go commands from within `go/fory/`. - Changes under `go/` must pass formatting and tests. - The Go implementation focuses on reflection-based and codegen-based serialization. +- Root deserialization container memory budgets are owned by `ReadContext`. + `WithMaxContainerMemoryBytes` defaults to `-1 / auto`; byte-slice roots use + `inputBytes * 8 + 64 KiB`, and `DeserializeFromReader`/`DeserializeFromStream` + use fixed `128 MiB`. Charge Go slices, maps, map-backed sets, LIST-encoded + inline/value slices, and generated container reads before allocation. Fixed + arrays are caller-owned and normally not charged; `arrayDynSerializer` charges + its temporary slice. Skip only dedicated string, binary, BufferObject, + primitive ARRAY slice, and primitive array owners with byte checks. - Set `FORY_PANIC_ON_ERROR=1` when debugging a failing Go test so you get the full call stack. - Do not set `FORY_PANIC_ON_ERROR=1` when running the full Go test suite, because some tests assert on error contents. diff --git a/.agents/languages/java.md b/.agents/languages/java.md index 41b19b206d..f60cfb249b 100644 --- a/.agents/languages/java.md +++ b/.agents/languages/java.md @@ -14,6 +14,15 @@ Load this file when changing anything under `java/` or when Java drives a cross- values; use qualified names only when a real name conflict requires it. - If you run temporary tests with `java -cp`, run `mvn -T16 install -DskipTests` first so local Fory jars are current. - `WriteContext`, `ReadContext`, and `CopyContext` must stay explicit. Do not reintroduce `ThreadLocal` or ambient runtime-context patterns. +- Java root deserialization container memory budgeting belongs to `ReadContext` + and is initialized by `Fory` root APIs. Public config is + `maxContainerMemoryBytes` with `-1` auto, positive explicit override, + known-length auto `inputBytes * 8 + 64 KiB`, and stream/unknown auto + `128 MiB`. Collection/map/object-array serializers should charge estimated + container-owned memory before allocation while preserving existing + `checkReadableBytes` guards before backing allocation or capacity + reservation. Do not add nested serializer-path `try/finally`, per-element + work, or dynamic stream bytes-read accounting for this budget. - Generated serializers must not retain runtime context fields. `Fory` should stay a root-operation facade rather than accumulating serializer or convenience state. - When the serializer class and constructor shape are known at the call site, prefer direct constructor lambdas or direct instantiation over reflective `Serializers.newSerializer(...)`. - For GraalVM, use `fory codegen` to generate serializers when building native images. Do not add reflection configuration except for JDK `proxy`. diff --git a/.agents/languages/javascript.md b/.agents/languages/javascript.md index e6d62fa494..4781b5ece2 100644 --- a/.agents/languages/javascript.md +++ b/.agents/languages/javascript.md @@ -14,6 +14,13 @@ Load this file when changing `javascript/`. - Runtime value carriers such as decimal or reduced-precision numeric types belong under the core `types/` ownership boundary, with imports, exports, and codegen externals updated together. - Keep `TypeInfo` as schema metadata. Compatibility-sensitive decisions belong on `TypeResolver` or explicit operations, not as retained resolver state on metadata objects. - Normalize optional boolean config values at config construction; do not carry `null` through runtime paths when it means `false`. +- JavaScript root deserialization container memory budgeting belongs to `ReadContext`. + `maxContainerMemoryBytes` uses `-1` auto, positive explicit limits, and known + `Uint8Array` root length as `inputBytes * 8 + 64 KiB`. Generated and dynamic + list/set/map readers must reserve before allocation while preserving existing + byte checks. Keep dedicated string, binary, and dense typed-array owners out of + this budget; compatible list-to-typed-array reads must charge typed inline + storage. - Regenerated compatible read serializers are remote-schema-specific. After classification marks a field as direct, compatible scalar, or skip, generated JavaScript should emit straight-line remote-field-order code. Do not add an outer matched-id switch unless the current regenerated shape cannot preserve those semantics. - Compatible scalar codegen must decide the exact remote/local scalar pair before emitting source. Generate the concrete `reader.readXxx()` call plus inline trivial conversions such as boolean-to-string or numeric widening, and keep helpers only for semantic validation such as range checks, exactness checks, decimal parsing/formatting, and string-to-bool. Do not call a generic hot-path converter that redispatches on `remoteTypeId`, `localTypeId`, field descriptors, or field names. - Compatible scalar conversion is immediate-field-only. Recursive schema comparison for collection elements, array elements, map keys, and map values must reject scalar mismatches instead of applying the top-level scalar conversion matrix. diff --git a/.agents/languages/python.md b/.agents/languages/python.md index 7365632c37..3ed69c6eb7 100644 --- a/.agents/languages/python.md +++ b/.agents/languages/python.md @@ -13,6 +13,12 @@ Load this file when changing `python/`, Cython serialization, or Python xlang be - Cython mode owns the hot runtime path. Do not duplicate core runtime types between Python and Cython, tunnel Python facade methods into hidden Cython internals, or keep dead shims unless the user explicitly needs a compatibility module path. - Use explicit Cython fields and methods for fixed hot-path shapes. Avoid `__getattr__`, generic `object` fields, public bridge internals, or `Fory` backreferences where ownership can stay explicit. - Keep Python and Cython context/ref-tracking branch conditions and stack mutations semantically aligned unless a documented intentional difference exists. +- Root deserialization container memory budgets are owned by pure-Python and Cython `ReadContext`. + Keep `max_container_memory_bytes` public on `pyfory.Fory`/`Config`; `-1` uses known-length + `inputBytes * 8 + 64 KiB` or fixed `128 MiB` for stream roots. Reserve fixed container cost plus + reference slots for list/tuple/set, map object/table/entry estimates for dict, and object-dtype + ndarray item storage. Keep string, bytes, `array.array`, primitive dense array, and primitive + ndarray owners skipped, and preserve byte-availability checks after budget reservation. - Public value constructors should accept normal Python values. Raw-bit, raw-buffer, and memoryview entry points should be explicit low-level APIs, and packed carriers should expose the buffer protocol from the actual storage owner when appropriate. - When debugging runtime or benchmark behavior, install the local package into the exact interpreter under test instead of relying on mixed `PYTHONPATH` state. - For wheel or extension pipeline changes, derive extension-module paths from current build targets, packaging config, or wheel payload discovery rather than historical module names. diff --git a/.agents/languages/rust.md b/.agents/languages/rust.md index ffe5648330..a31126586c 100644 --- a/.agents/languages/rust.md +++ b/.agents/languages/rust.md @@ -18,6 +18,16 @@ Load this file when changing `rust/` or Rust xlang behavior. - If breakage is explicitly acceptable during a Rust module refactor, rewire macros, tests, and sibling crates directly to the new boundaries instead of adding compatibility re-exports. - For panic-safety in hot paths, preserve TLS context reuse. Add scoped guards or owned fallbacks rather than per-call context allocation, and reset reused contexts at entry and successful exit. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Keep recursive matched-field shape classification owned by `fory-core/src/meta/type_meta.rs`; collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. +- Root deserialization container memory budget state belongs to `ReadContext` and is initialized by + the root `Fory` read methods before the header is consumed. Rust roots are byte-slice/`Reader` + backed, so auto budget uses `inputBytes * 8 + 64 KiB`; do not add dynamic bytes-read accounting. +- Rust `Vec` stores inline element storage, so general LIST paths charge fixed `Vec` cost plus + `len * size_of::()`, including `Vec` and `Vec`. Dedicated primitive dense + ARRAY `Vec` readers, strings, binary, and primitive fixed-array owners stay skipped and keep + their byte checks. +- Direct `Serializer` collection/map paths and derive `Codec` collection/map paths are separate + allocation owners. Keep reservations in both before `Vec::with_capacity`, + `HashMap::with_capacity`, or collection materialization; charge zero-size containers. ## Key Paths diff --git a/.agents/languages/swift.md b/.agents/languages/swift.md index ec493ea1ac..0e2607ac59 100644 --- a/.agents/languages/swift.md +++ b/.agents/languages/swift.md @@ -15,6 +15,8 @@ Load this file when changing `swift/` or Swift xlang behavior. - Preserve distinct temporal semantics. Timestamp values and day-only local dates should have protocol-accurate helper names and no stale aliases after a refactor. - When temporal or public-type refactors touch generated Swift code, sweep message fields, union payloads, macros, xlang harnesses, and integration fixtures together. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. +- Root deserialization container memory budget state belongs to `ReadContext`. Swift public roots are `Data` and `ByteBuffer`, so auto uses known root bytes; do not add stream bytes-read accounting or serializer-local budget state. +- For Swift container budget formulas, distinguish inline/value storage from reference storage: use `MemoryLayout.stride` for value arrays/lists/maps and the 4-byte reference fallback for `Serializer.isRefType` / `FieldCodec.isRefType` paths. Dedicated `String`, `Data`/binary, and primitive packed-array owners stay skipped, except compatible packed-array-to-list reads must charge the target list materialization before allocation. ## Commands diff --git a/AGENTS.md b/AGENTS.md index 5cd2581d5c..c8346e0f4f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -32,7 +32,8 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Respect ownership. Keep logic, state, and helpers in their natural owner, and do not move serializer-local, context-local, runtime-type-local, or protocol-local problems into global utilities. - Check the spec before implementation. For wire behavior and xlang mapping, use the specs as the source of truth and never copy one runtime's bug into another runtime just to make tests pass. - Do not make assumptions about runtime behavior, ownership, registration, metadata construction, protocol semantics, or test coverage. Read the current code, owning docs/specs, and relevant tests before making a design judgment or implementation decision. If the evidence is incomplete, inspect more or state the uncertainty explicitly instead of filling gaps from memory or analogy with another runtime. -- For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. +- For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve backing/output capacity from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. Container memory-budget reservation is accounting only and may happen before that byte check, but it must not replace the byte check. +- Root deserialization container memory budgets are estimated container-owned memory, not exact heap accounting and not raw slots. Positive `maxContainerMemoryBytes` wins; `-1` auto uses `inputBytes * 8 + 64 KiB` for known-length roots and fixed `128 MiB` for true stream or unknown-length roots. Charge fixed container cost, backing/reference/inline storage, map table and entry overhead, and zero-size containers; skip only dedicated string, binary, primitive array, and primitive dense-array owners. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. - For remote TypeDef/TypeMeta reads, the checked metadata cache is the only owner of remote "already validated" state. Cache hit means the header was previously parsed, body/hash-validated, policy-checked, and published by that cache, so the hot path must skip the body and use cached metadata without extra validation, hashing, limit checks, exact-local checks, allocation, or policy work. A known expected local TypeDef/TypeMeta header/hash match is a local-schema hit, not a remote cache miss: it may skip the body and use the local TypeInfo/TypeMeta without schema-version counting or cache publish. Cache miss is the only path that parses and validates non-local metadata, enforces limits, performs exact-local byte comparison when needed, and publishes remote metadata to the cache. Do not add nullable accepted-header fields, sentinel headers, per-TypeInfo markers, pending metadata state, parallel header-low/header-high slots, or parallel acceptance state for this decision. If a runtime needs a metadata hit hint, cache the concrete checked metadata owner object, such as the TypeInfo, TypeDef, or TypeMeta used by that runtime, and compare its validated header identity directly. - When a user corrects a non-obvious invariant, encode it in the nearest source comment before continuing, and also update `AGENTS.md`, `.agents/**`, docs, or specs when the rule is reusable beyond one file. Do not rely only on chat history, task notes, commit messages, or benchmark logs for corrections that protect security, protocol behavior, ownership, naming, or hot-path performance. - Reject semantic hacks. Do not bypass broken semantics by deleting cases, simplifying callers, adding coercion hooks, or using workaround fallbacks; fix the underlying bug and prove it with focused tests. @@ -111,6 +112,19 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - `docs/DEVELOPMENT.md` plus updates under `docs/guide/` and `docs/benchmarks/` are synced to `apache/fory-site`; other website content belongs there. - When benchmark logic, scripts, configuration, or compared serializers change, rerun the relevant benchmarks and refresh the artifacts under `docs/benchmarks/**`. +## Network Error Command Log + +- 2026-06-26: `cargo check` from `rust/` failed while updating crates.io through + `127.0.0.1:7890`; retried as + `env -u all_proxy -u http_proxy -u https_proxy -u ALL_PROXY -u HTTP_PROXY -u HTTPS_PROXY cargo check`, + which still used the configured proxy. `cargo check --offline` succeeded using the local Cargo + cache. +- 2026-06-26: `cmake -S . -B ../tasks/cpp-cmake-build -DFORY_BUILD_TESTS=ON +-DFORY_BUILD_SHARED=OFF -DFORY_BUILD_STATIC=ON` from `cpp/` failed while FetchContent tried to + clone googletest through `127.0.0.1:7890`; retried as + `env -u all_proxy -u http_proxy -u https_proxy -u ALL_PROXY -u HTTP_PROXY -u HTTPS_PROXY cmake -S . -B ../tasks/cpp-cmake-build -DFORY_BUILD_TESTS=ON -DFORY_BUILD_SHARED=OFF -DFORY_BUILD_STATIC=ON`, + which still used the configured proxy in the nested clone. + ## Shared Engineering Expectations - Favor zero-copy techniques, JIT or codegen opportunities, and cache-friendly memory access patterns in performance-critical paths. diff --git a/cpp/fory/serialization/BUILD b/cpp/fory/serialization/BUILD index b74c356a2b..1102e53f1c 100644 --- a/cpp/fory/serialization/BUILD +++ b/cpp/fory/serialization/BUILD @@ -109,6 +109,16 @@ cc_test( ], ) +cc_test( + name = "container_memory_budget_test", + srcs = ["container_memory_budget_test.cc"], + deps = [ + ":fory_serialization", + "@googletest//:gtest", + "@googletest//:gtest_main", + ], +) + cc_test( name = "variant_serializer_test", srcs = ["variant_serializer_test.cc"], diff --git a/cpp/fory/serialization/CMakeLists.txt b/cpp/fory/serialization/CMakeLists.txt index 0b88f0e5e3..7ff18d0320 100644 --- a/cpp/fory/serialization/CMakeLists.txt +++ b/cpp/fory/serialization/CMakeLists.txt @@ -102,6 +102,11 @@ if(FORY_BUILD_TESTS) target_link_libraries(fory_serialization_map_test fory_serialization GTest::gtest GTest::gtest_main) gtest_discover_tests(fory_serialization_map_test) + add_executable(fory_serialization_container_memory_budget_test container_memory_budget_test.cc) + fory_configure_target(fory_serialization_container_memory_budget_test) + target_link_libraries(fory_serialization_container_memory_budget_test fory_serialization GTest::gtest GTest::gtest_main) + gtest_discover_tests(fory_serialization_container_memory_budget_test) + add_executable(fory_serialization_variant_test variant_serializer_test.cc) fory_configure_target(fory_serialization_variant_test) target_link_libraries(fory_serialization_variant_test fory_serialization GTest::gtest GTest::gtest_main) diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index 473f9d6950..7c89e1a265 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -380,6 +380,34 @@ struct has_reserve inline constexpr bool has_reserve_v = has_reserve::value; +constexpr size_t kContainerEntryOverheadBytes = 16; +constexpr size_t kContainerReferenceBytes = sizeof(void *); + +template +struct is_std_vector_container : std::false_type {}; + +template +struct is_std_vector_container> : std::true_type {}; + +template +inline constexpr bool is_std_vector_container_v = + is_std_vector_container::value; + +template +constexpr size_t collection_element_memory_bytes() { + using Elem = typename Container::value_type; + if constexpr (is_std_vector_container_v) { + return sizeof(Elem); + } else { + static_assert(sizeof(Elem) <= std::numeric_limits::max() - + kContainerEntryOverheadBytes - + kContainerReferenceBytes * 2, + "container element memory estimate overflows"); + return sizeof(Elem) + kContainerEntryOverheadBytes + + kContainerReferenceBytes * 2; + } +} + template inline bool reserve_collection(Container &result, ReadContext &ctx, uint32_t length) { @@ -388,6 +416,12 @@ inline bool reserve_collection(Container &result, ReadContext &ctx, if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } + constexpr size_t fixed_bytes = sizeof(Container); + constexpr size_t elem_bytes = collection_element_memory_bytes(); + if (FORY_PREDICT_FALSE((!ctx.template reserve_counted_container_memory< + fixed_bytes, elem_bytes>(length)))) { + return false; + } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { return false; } @@ -397,6 +431,14 @@ inline bool reserve_collection(Container &result, ReadContext &ctx, return true; } +template +inline bool reserve_empty_collection(ReadContext &ctx) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return false; + } + return ctx.reserve_container_memory(sizeof(Container)); +} + // Helper to insert element into container (vector or set) template inline void collection_insert(Container &result, T &&elem) { @@ -412,9 +454,9 @@ template inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { Container result; if (length == 0) { - return result; - } - if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + if (FORY_PREDICT_FALSE(!reserve_empty_collection(ctx))) { + return result; + } return result; } @@ -443,6 +485,10 @@ inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { } } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } + // Read elements if (is_same_type) { if (track_ref) { @@ -922,16 +968,20 @@ struct Serializer< return std::vector(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (length == 0) { - return std::vector(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, length); } else { + std::vector result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (length == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Elements header bitmap (CollectionFlags) @@ -961,7 +1011,6 @@ struct Serializer< } } - std::vector result; if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { return result; } @@ -1058,6 +1107,10 @@ struct Serializer< std::vector result; if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } return result; } if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { @@ -1217,16 +1270,20 @@ template struct Serializer> { return std::list(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (length == 0) { - return std::list(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, length); } else { + std::list result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (length == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Elements header bitmap (CollectionFlags) @@ -1256,7 +1313,9 @@ template struct Serializer> { } } - std::list result; + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } // Fast path: no tracking, no nulls, elements have declared type if (!track_ref && !has_null && is_same_type) { @@ -1349,6 +1408,16 @@ template struct Serializer> { } std::list result; + if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; @@ -1409,16 +1478,20 @@ template struct Serializer> { return std::deque(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (length == 0) { - return std::deque(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, length); } else { + std::deque result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (length == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Elements header bitmap (CollectionFlags) @@ -1448,7 +1521,9 @@ template struct Serializer> { } } - std::deque result; + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } // Fast path: no tracking, no nulls, elements have declared type if (!track_ref && !has_null && is_same_type) { @@ -1541,6 +1616,16 @@ template struct Serializer> { } std::deque result; + if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; @@ -1602,9 +1687,14 @@ struct Serializer> { return std::forward_list(); } + std::forward_list result; // Per xlang spec: header and type_info are omitted when length is 0 if (length == 0) { - return std::forward_list(); + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; } // Dispatch to slow path for polymorphic/shared-ref elements @@ -1620,7 +1710,7 @@ struct Serializer> { // Elements header bitmap (CollectionFlags) uint8_t bitmap = ctx.read_uint8(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::forward_list(); + return result; } bool track_ref = (bitmap & COLL_TRACKING_REF) != 0; bool has_null = (bitmap & COLL_HAS_NULL) != 0; @@ -1632,7 +1722,7 @@ struct Serializer> { if (is_same_type && !is_decl_type) { const TypeInfo *elem_type_info = ctx.read_any_type_info(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::forward_list(); + return result; } using ElemType = nullable_element_t; uint32_t expected = @@ -1644,8 +1734,12 @@ struct Serializer> { } } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(temp, ctx, length))) { - return std::forward_list(); + return result; } // Fast path: no tracking, no nulls, elements have declared type if (!track_ref && !has_null && is_same_type) { @@ -1688,7 +1782,6 @@ struct Serializer> { } // Build forward_list in reverse order using push_front - std::forward_list result; for (auto it = temp.rbegin(); it != temp.rend(); ++it) { result.push_front(std::move(*it)); } @@ -1968,9 +2061,20 @@ struct Serializer> { return std::forward_list(); } + std::forward_list result; + if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } std::vector temp; if (FORY_PREDICT_FALSE(!reserve_collection(temp, ctx, size))) { - return std::forward_list(); + return result; } for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { @@ -1980,7 +2084,6 @@ struct Serializer> { temp.push_back(std::move(elem)); } // Build forward_list in reverse order - std::forward_list result; for (auto it = temp.rbegin(); it != temp.rend(); ++it) { result.push_front(std::move(*it)); } @@ -2069,16 +2172,20 @@ struct Serializer> { return std::set(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (size == 0) { - return std::set(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, size); } else { + std::set result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Read elements header bitmap (CollectionFlags) in xlang mode @@ -2094,17 +2201,20 @@ struct Serializer> { if (is_same_type && !is_decl_type) { const TypeInfo *elem_type_info = ctx.read_any_type_info(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::set(); + return result; } uint32_t expected = static_cast(Serializer::type_id); if (!type_id_matches(elem_type_info->type_id, expected)) { ctx.set_error( Error::type_mismatch(elem_type_info->type_id, expected)); - return std::set(); + return result; } } - std::set result; + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } + if (!track_ref && !has_null && is_same_type) { for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { @@ -2151,6 +2261,16 @@ struct Serializer> { } std::set result; + if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; @@ -2244,17 +2364,22 @@ struct Serializer> { return std::unordered_set(); } - // Per xlang spec: header and type_info are omitted when length is 0 - if (size == 0) { - return std::unordered_set(); - } - // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { return read_collection_data_slow>(ctx, size); } else { + std::unordered_set result; + // Per xlang spec: header and type_info are omitted when length is 0 + if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>( + ctx)))) { + return result; + } + return result; + } // Fast path for non-polymorphic, non-shared-ref elements // Read elements header bitmap (CollectionFlags) in xlang mode @@ -2270,20 +2395,20 @@ struct Serializer> { if (is_same_type && !is_decl_type) { const TypeInfo *elem_type_info = ctx.read_any_type_info(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::unordered_set(); + return result; } uint32_t expected = static_cast(Serializer::type_id); if (!type_id_matches(elem_type_info->type_id, expected)) { ctx.set_error( Error::type_mismatch(elem_type_info->type_id, expected)); - return std::unordered_set(); + return result; } } - std::unordered_set result; if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { return result; } + if (!track_ref && !has_null && is_same_type) { for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { @@ -2330,6 +2455,14 @@ struct Serializer> { } std::unordered_set result; + if (size == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>( + ctx)))) { + return result; + } + return result; + } if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { return result; } diff --git a/cpp/fory/serialization/config.h b/cpp/fory/serialization/config.h index eac9c14436..a59b575f71 100644 --- a/cpp/fory/serialization/config.h +++ b/cpp/fory/serialization/config.h @@ -52,6 +52,10 @@ struct Config { /// When enabled, avoids duplicating shared objects and handles cycles. bool track_ref = true; + /// Maximum estimated container-owned memory accepted during one root + /// deserialization. `-1` selects the automatic input-shaped limit. + int64_t max_container_memory_bytes = -1; + /// Maximum accepted field count in one received struct TypeMeta. uint32_t max_type_fields = 512; diff --git a/cpp/fory/serialization/container_memory_budget_test.cc b/cpp/fory/serialization/container_memory_budget_test.cc new file mode 100644 index 0000000000..781e9312bc --- /dev/null +++ b/cpp/fory/serialization/container_memory_budget_test.cc @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "fory/serialization/fory.h" +#include "gtest/gtest.h" +#include +#include +#include +#include +#include +#include + +namespace fory { +namespace serialization { +namespace { + +constexpr size_t kKnownBudgetSlack = 64 * 1024; + +struct BudgetItem { + int32_t id = 0; + std::string name; + + bool operator==(const BudgetItem &other) const { + return id == other.id && name == other.name; + } + + FORY_STRUCT(BudgetItem, id, name); +}; + +struct BudgetSiblings { + std::vector left; + std::vector right; + + bool operator==(const BudgetSiblings &other) const { + return left == other.left && right == other.right; + } + + FORY_STRUCT(BudgetSiblings, left, right); +}; + +template +auto with_fory(int64_t max_container_memory_bytes, Fn &&fn) { + auto fory = Fory::builder() + .xlang(true) + .compatible(false) + .track_ref(false) + .max_container_memory_bytes(max_container_memory_bytes) + .build(); + fory.register_struct(1); + fory.register_struct(2); + return std::forward(fn)(fory); +} + +template std::vector serialize_value(const T &value) { + auto bytes = with_fory(-1, [&](Fory &fory) { return fory.serialize(value); }); + EXPECT_TRUE(bytes.ok()) << bytes.error().to_string(); + return std::move(bytes).value(); +} + +size_t nested_empty_budget(size_t count) { + using Inner = std::vector; + using Outer = std::vector; + return sizeof(Outer) + count * sizeof(Inner) + count * sizeof(Inner); +} + +TEST(ContainerMemoryBudgetTest, KnownLengthAutoBudget) { + constexpr size_t count = 3000; + std::vector> value(count); + auto bytes = serialize_value(value); + const size_t auto_limit = bytes.size() * 8 + kKnownBudgetSlack; + const size_t required = nested_empty_budget(count); + ASSERT_GT(required, auto_limit); + + auto default_result = with_fory(-1, [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_FALSE(default_result.ok()); + EXPECT_EQ(default_result.error().code(), ErrorCode::InvalidData); + + auto explicit_auto_result = + with_fory(static_cast(auto_limit), [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_FALSE(explicit_auto_result.ok()); + EXPECT_EQ(explicit_auto_result.error().code(), ErrorCode::InvalidData); + + auto explicit_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_TRUE(explicit_result.ok()) << explicit_result.error().to_string(); + EXPECT_EQ(explicit_result.value(), value); +} + +TEST(ContainerMemoryBudgetTest, StreamAutoBudget) { + constexpr size_t count = 10000; + std::vector> value(count); + auto bytes = serialize_value(value); + const size_t known_limit = bytes.size() * 8 + kKnownBudgetSlack; + ASSERT_GT(nested_empty_budget(count), known_limit); + + auto known_result = with_fory(-1, [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_FALSE(known_result.ok()); + EXPECT_EQ(known_result.error().code(), ErrorCode::InvalidData); + + std::string input(reinterpret_cast(bytes.data()), bytes.size()); + std::istringstream source(input); + StdInputStream stream(source, 8); + auto stream_result = with_fory(-1, [&](Fory &fory) { + return fory.deserialize>>(stream); + }); + ASSERT_TRUE(stream_result.ok()) << stream_result.error().to_string(); + EXPECT_EQ(stream_result.value(), value); +} + +TEST(ContainerMemoryBudgetTest, ExplicitOverride) { + std::vector value(8); + auto bytes = serialize_value(value); + const size_t required = + sizeof(std::vector) + value.size() * sizeof(BudgetItem); + + auto small_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto exact_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + EXPECT_EQ(exact_result.value(), value); +} + +TEST(ContainerMemoryBudgetTest, EmptyContainersChargeFixedCost) { + std::vector> value(1); + auto bytes = serialize_value(value); + const size_t required = nested_empty_budget(1); + + auto small_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto exact_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + EXPECT_EQ(exact_result.value(), value); +} + +TEST(ContainerMemoryBudgetTest, SiblingCumulativeBudget) { + BudgetSiblings value; + value.left.resize(16); + value.right.resize(16); + auto bytes = serialize_value(value); + const size_t one_vector = + sizeof(std::vector) + value.left.size() * sizeof(BudgetItem); + + auto small_result = + with_fory(static_cast(one_vector), [&](Fory &fory) { + return fory.deserialize(bytes); + }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto enough_result = + with_fory(static_cast(one_vector * 2), [&](Fory &fory) { + return fory.deserialize(bytes); + }); + ASSERT_TRUE(enough_result.ok()) << enough_result.error().to_string(); + EXPECT_EQ(enough_result.value(), value); +} + +TEST(ContainerMemoryBudgetTest, MapBudget) { + std::map value{{"a", 1}, {"b", 2}, {"c", 3}}; + auto bytes = serialize_value(value); + const size_t entry_bytes = + sizeof(std::string) + sizeof(int32_t) + 16 + sizeof(void *) * 3; + const size_t required = sizeof(value) + value.size() * entry_bytes; + + auto small_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto exact_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + EXPECT_EQ(exact_result.value(), value); +} + +TEST(ContainerMemoryBudgetTest, DensePathsSkipped) { + { + std::string value = "container-budget-string"; + auto bytes = serialize_value(value); + auto result = with_fory( + 1, [&](Fory &fory) { return fory.deserialize(bytes); }); + ASSERT_TRUE(result.ok()) << result.error().to_string(); + EXPECT_EQ(result.value(), value); + } + { + std::vector value(256, 7); + auto bytes = serialize_value(value); + auto result = with_fory(1, [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_TRUE(result.ok()) << result.error().to_string(); + EXPECT_EQ(result.value(), value); + } + { + std::vector value(256, 42); + auto bytes = serialize_value(value); + auto result = with_fory(1, [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_TRUE(result.ok()) << result.error().to_string(); + EXPECT_EQ(result.value(), value); + } +} + +TEST(ContainerMemoryBudgetTest, ByteCheckStillRejectsLargeLength) { + Config config; + auto resolver = std::make_unique(); + ReadContext ctx(config, std::move(resolver)); + std::vector bytes{64}; + Buffer buffer(bytes.data(), static_cast(bytes.size()), false); + ctx.attach(buffer); + + auto result = Serializer>::read_data(ctx); + EXPECT_TRUE(result.empty()); + ASSERT_TRUE(ctx.has_error()); + EXPECT_EQ(ctx.error().code(), ErrorCode::BufferOutOfBound); +} + +} // namespace +} // namespace serialization +} // namespace fory diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index deff5ee16c..686db558ad 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -739,6 +739,43 @@ const TypeInfo *ReadContext::read_any_type_info(Error &error) { return result.value(); } +bool ReadContext::reserve_counted_container_checked(uint32_t length, + size_t fixed_bytes, + size_t elem_bytes) { + if (FORY_PREDICT_FALSE( + elem_bytes != 0 && + static_cast(length) > + (std::numeric_limits::max() - fixed_bytes) / + elem_bytes)) { + return set_container_memory_overflow(length, elem_bytes); + } + return reserve_container_memory(static_cast(length) * elem_bytes + + fixed_bytes); +} + +bool ReadContext::set_container_memory_error(const std::string &message) { + set_error(Error::invalid_data(message)); + return false; +} + +bool ReadContext::set_container_memory_overflow(uint32_t length, + size_t elem_bytes) { + set_error(Error::invalid_data( + "container memory estimate overflows: length=" + std::to_string(length) + + " elementBytes=" + std::to_string(elem_bytes))); + return false; +} + +bool ReadContext::set_container_memory_exceeded(size_t bytes, + size_t remaining) { + set_error(Error::invalid_data( + "estimated container memory request " + std::to_string(bytes) + + " bytes exceeds max_container_memory_bytes remaining budget " + + std::to_string(remaining) + " bytes out of effective limit " + + std::to_string(container_memory_limit_bytes_) + " bytes")); + return false; +} + void ReadContext::reset() { // Clear error state first error_ = Error(); @@ -747,6 +784,9 @@ void ReadContext::reset() { } reading_type_infos_.clear(); current_dyn_depth_ = 0; + // Root deserialization initializes the container budget before reading the + // header; direct ReadContext users start with the unlimited sentinel fields. + // Leave those fields untouched here so root guard cleanup stays store-light. if (meta_string_table_active_) { meta_string_table_.reset(); meta_string_table_active_ = false; diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index 6af99c4ccc..5d2bbc3c60 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -32,6 +32,7 @@ #include "fory/util/result.h" #include +#include #include #include @@ -504,6 +505,98 @@ class ReadContext { } } + FORY_ALWAYS_INLINE bool init_container_budget_known(size_t root_bytes) { + size_t limit = 0; + if (config_->max_container_memory_bytes > 0) { + const uint64_t configured = + static_cast(config_->max_container_memory_bytes); + if constexpr (sizeof(size_t) < sizeof(uint64_t)) { + if (FORY_PREDICT_FALSE( + configured > + static_cast(std::numeric_limits::max()))) { + return set_container_memory_error( + "max_container_memory_bytes does not fit size_t"); + } + } + limit = static_cast(configured); + } else { + constexpr size_t max_root_bytes = (std::numeric_limits::max() - + kKnownContainerBudgetSlackBytes) / + kKnownContainerBudgetMultiplier; + if (FORY_PREDICT_FALSE(root_bytes > max_root_bytes)) { + return set_container_memory_error( + "root input size overflows automatic container memory budget"); + } + limit = root_bytes * kKnownContainerBudgetMultiplier + + kKnownContainerBudgetSlackBytes; + } + container_memory_limit_bytes_ = limit; + remaining_container_memory_bytes_ = limit; + return true; + } + + FORY_ALWAYS_INLINE bool init_container_budget_unknown() { + size_t limit = 0; + if (config_->max_container_memory_bytes > 0) { + const uint64_t configured = + static_cast(config_->max_container_memory_bytes); + if constexpr (sizeof(size_t) < sizeof(uint64_t)) { + if (FORY_PREDICT_FALSE( + configured > + static_cast(std::numeric_limits::max()))) { + return set_container_memory_error( + "max_container_memory_bytes does not fit size_t"); + } + } + limit = static_cast(configured); + } else { + limit = kUnknownContainerBudgetBytes; + } + container_memory_limit_bytes_ = limit; + remaining_container_memory_bytes_ = limit; + return true; + } + + FORY_ALWAYS_INLINE bool reserve_container_memory(size_t bytes) { + const size_t remaining = remaining_container_memory_bytes_; + if (FORY_PREDICT_FALSE(bytes > remaining)) { + return set_container_memory_exceeded(bytes, remaining); + } + remaining_container_memory_bytes_ = remaining - bytes; + return true; + } + + FORY_ALWAYS_INLINE bool reserve_counted_container_memory(uint32_t length, + size_t fixed_bytes, + size_t elem_bytes) { + if (length == 0) { + return reserve_container_memory(fixed_bytes); + } + constexpr size_t kMaxLength = + static_cast(std::numeric_limits::max()); + if (FORY_PREDICT_TRUE(elem_bytes <= + (std::numeric_limits::max() - fixed_bytes) / + kMaxLength)) { + return reserve_container_memory(static_cast(length) * elem_bytes + + fixed_bytes); + } + return reserve_counted_container_checked(length, fixed_bytes, elem_bytes); + } + + template + FORY_ALWAYS_INLINE bool reserve_counted_container_memory(uint32_t length) { + constexpr size_t kMaxLength = + static_cast(std::numeric_limits::max()); + if constexpr (elem_bytes <= + (std::numeric_limits::max() - fixed_bytes) / + kMaxLength) { + return reserve_container_memory(static_cast(length) * elem_bytes + + fixed_bytes); + } else { + return reserve_counted_container_checked(length, fixed_bytes, elem_bytes); + } + } + // =========================================================================== // Read methods with Error& parameter // All methods accept Error& as parameter for reduced overhead. @@ -659,9 +752,22 @@ class ReadContext { inline const Config &config() const { return *config_; } private: + static constexpr size_t kKnownContainerBudgetMultiplier = 8; + static constexpr size_t kKnownContainerBudgetSlackBytes = 64 * 1024; + static constexpr size_t kUnknownContainerBudgetBytes = + 128ULL * 1024ULL * 1024ULL; + FORY_NOINLINE Result check_remote_type_meta_limit(const TypeMeta &type_meta); void record_remote_type_meta(const std::string &type_key); + FORY_NOINLINE bool reserve_counted_container_checked(uint32_t length, + size_t fixed_bytes, + size_t elem_bytes); + FORY_NOINLINE bool set_container_memory_error(const std::string &message); + FORY_NOINLINE bool set_container_memory_overflow(uint32_t length, + size_t elem_bytes); + FORY_NOINLINE bool set_container_memory_exceeded(size_t bytes, + size_t remaining); // Error state - accumulated during deserialization, checked at the end Error error_; @@ -671,6 +777,8 @@ class ReadContext { std::unique_ptr type_resolver_; RefReader ref_reader_; uint32_t current_dyn_depth_; + size_t container_memory_limit_bytes_ = std::numeric_limits::max(); + size_t remaining_container_memory_bytes_ = std::numeric_limits::max(); // Meta sharing state (for compatible mode) // Persistent cache storage for TypeInfo objects keyed by meta header. diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index 36ef992d17..6d26c3bfa7 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -109,6 +109,16 @@ class ForyBuilder { return *this; } + /// Set maximum estimated container-owned memory for one root deserialization. + /// + /// Use `-1` for automatic limits. Positive values are explicit byte limits. + ForyBuilder &max_container_memory_bytes(int64_t max_bytes) { + FORY_CHECK(max_bytes == -1 || max_bytes > 0) + << "max_container_memory_bytes must be positive or -1 for auto"; + config_.max_container_memory_bytes = max_bytes; + return *this; + } + /// Set maximum accepted field count in one received struct TypeMeta. ForyBuilder &max_type_fields(uint32_t max_fields) { FORY_CHECK(max_fields > 0) << "max_type_fields must be positive"; @@ -673,19 +683,7 @@ class Fory : public BaseFory { Buffer buffer(const_cast(data), static_cast(size), false); - - Error header_error; - const uint8_t header = buffer.read_uint8(header_error); - if (FORY_PREDICT_FALSE(!header_error.ok())) { - return Unexpected(std::move(header_error)); - } - if (FORY_PREDICT_FALSE(header != precomputed_header_)) { - return Unexpected(invalid_root_header(header)); - } - - read_ctx_->attach(buffer); - ReadContextGuard guard(*read_ctx_); - return deserialize_impl(buffer); + return deserialize_buffer(buffer); } /// Deserialize an object from a byte vector. @@ -711,18 +709,7 @@ class Fory : public BaseFory { if (FORY_PREDICT_FALSE(!finalized_)) { ensure_finalized(); } - Error header_error; - const uint8_t header = buffer.read_uint8(header_error); - if (FORY_PREDICT_FALSE(!header_error.ok())) { - return Unexpected(std::move(header_error)); - } - if (FORY_PREDICT_FALSE(header != precomputed_header_)) { - return Unexpected(invalid_root_header(header)); - } - - read_ctx_->attach(buffer); - ReadContextGuard guard(*read_ctx_); - return deserialize_impl(buffer); + return deserialize_buffer(buffer); } /// Deserialize an object from an input stream. @@ -745,7 +732,10 @@ class Fory : public BaseFory { }; StreamShrinkGuard shrink_guard{&input_stream}; Buffer &buffer = input_stream.get_buffer(); - return deserialize(buffer); + if (FORY_PREDICT_FALSE(!finalized_)) { + ensure_finalized(); + } + return deserialize_buffer(buffer); } /// Deserialize an object from StdInputStream. @@ -883,6 +873,34 @@ class Fory : public BaseFory { return result; } + template + FORY_ALWAYS_INLINE Result deserialize_buffer(Buffer &buffer) { + const bool budget_ok = + unknown_root + ? read_ctx_->init_container_budget_unknown() + : read_ctx_->init_container_budget_known(buffer.remaining_size()); + if (FORY_PREDICT_FALSE(!budget_ok)) { + Error error = read_ctx_->take_error(); + read_ctx_->reset(); + return Unexpected(std::move(error)); + } + + Error header_error; + const uint8_t header = buffer.read_uint8(header_error); + if (FORY_PREDICT_FALSE(!header_error.ok())) { + read_ctx_->reset(); + return Unexpected(std::move(header_error)); + } + if (FORY_PREDICT_FALSE(header != precomputed_header_)) { + read_ctx_->reset(); + return Unexpected(invalid_root_header(header)); + } + + read_ctx_->attach(buffer); + ReadContextGuard guard(*read_ctx_); + return deserialize_impl(buffer); + } + template Result cached_write_root_type_info() { constexpr uint64_t ctid = type_index(); diff --git a/cpp/fory/serialization/map_serializer.h b/cpp/fory/serialization/map_serializer.h index 830e5fbae5..a7a3bc615d 100644 --- a/cpp/fory/serialization/map_serializer.h +++ b/cpp/fory/serialization/map_serializer.h @@ -21,6 +21,7 @@ #include "fory/serialization/serializer.h" #include +#include #include #include #include @@ -81,6 +82,9 @@ struct MapReserver inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { // Lazy error propagation may continue into later readers; do not let that @@ -88,6 +92,20 @@ inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } + using Key = typename MapType::key_type; + using Value = typename MapType::mapped_type; + static_assert(sizeof(Key) <= std::numeric_limits::max() - + sizeof(Value) - kMapEntryBudgetBytes - + kMapReferenceBudgetBytes * 3, + "map entry memory estimate overflows"); + constexpr size_t fixed_bytes = sizeof(MapType); + constexpr size_t elem_bytes = sizeof(Key) + sizeof(Value) + + kMapEntryBudgetBytes + + kMapReferenceBudgetBytes * 3; + if (FORY_PREDICT_FALSE((!ctx.template reserve_counted_container_memory< + fixed_bytes, elem_bytes>(length)))) { + return false; + } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { return false; } @@ -95,6 +113,13 @@ inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { return true; } +template inline bool reserve_empty_map(ReadContext &ctx) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return false; + } + return ctx.reserve_container_memory(sizeof(MapType)); +} + /// write chunk size at header offset inline void write_chunk_size(WriteContext &ctx, size_t header_offset, uint8_t size) { @@ -567,6 +592,9 @@ inline MapType read_map_data_fast(ReadContext &ctx, uint32_t length) { MapType result; if (length == 0) { + if (FORY_PREDICT_FALSE(!reserve_empty_map(ctx))) { + return result; + } return result; } if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { @@ -699,6 +727,9 @@ template inline MapType read_map_data_slow(ReadContext &ctx, uint32_t length) { MapType result; if (length == 0) { + if (FORY_PREDICT_FALSE(!reserve_empty_map(ctx))) { + return result; + } return result; } if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { diff --git a/cpp/fory/serialization/struct_serializer.h b/cpp/fory/serialization/struct_serializer.h index 4da4f7751c..00acf71b2a 100644 --- a/cpp/fory/serialization/struct_serializer.h +++ b/cpp/fory/serialization/struct_serializer.h @@ -897,9 +897,9 @@ Container read_configured_list_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); Container result; if (length == 0) { - return result; - } - if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + if (FORY_PREDICT_FALSE(!reserve_empty_collection(ctx))) { + return result; + } return result; } uint8_t bitmap = ctx.read_uint8(ctx.error()); @@ -916,6 +916,9 @@ Container read_configured_list_data(ReadContext &ctx) { return result; } } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } const RefMode elem_ref_mode = track_ref ? RefMode::Tracking : (has_null ? RefMode::NullOnly : RefMode::None); @@ -939,7 +942,13 @@ FORY_NOINLINE Container read_configured_list_data_as_array_field( using Elem = element_type_t; uint32_t length = ctx.read_var_uint32(ctx.error()); Container result; - if (FORY_PREDICT_FALSE(ctx.has_error()) || length == 0) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return result; + } + if (length == 0) { + if (FORY_PREDICT_FALSE(!reserve_empty_collection(ctx))) { + return result; + } return result; } uint8_t bitmap = ctx.read_uint8(ctx.error()); @@ -1051,6 +1060,9 @@ MapType read_configured_map_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); MapType result; if (length == 0) { + if (FORY_PREDICT_FALSE(!reserve_empty_map(ctx))) { + return result; + } return result; } if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { diff --git a/cpp/fory/serialization/union_serializer.h b/cpp/fory/serialization/union_serializer.h index d5247d431f..8a8bc99fe3 100644 --- a/cpp/fory/serialization/union_serializer.h +++ b/cpp/fory/serialization/union_serializer.h @@ -466,9 +466,9 @@ Container read_union_configured_list_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); Container result; if (length == 0) { - return result; - } - if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + if (FORY_PREDICT_FALSE(!reserve_empty_collection(ctx))) { + return result; + } return result; } uint8_t bitmap = ctx.read_uint8(ctx.error()); @@ -483,6 +483,9 @@ Container read_union_configured_list_data(ReadContext &ctx) { return result; } } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } for (uint32_t i = 0; i < length; ++i) { if constexpr (ElemNode >= 0) { auto elem = read_union_configured_value( @@ -553,6 +556,9 @@ MapType read_union_configured_map_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); MapType result; if (length == 0) { + if (FORY_PREDICT_FALSE(!reserve_empty_map(ctx))) { + return result; + } return result; } if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { diff --git a/csharp/src/Fory.Generator/ForyModelGenerator.cs b/csharp/src/Fory.Generator/ForyModelGenerator.cs index 8e051da478..50c4682515 100644 --- a/csharp/src/Fory.Generator/ForyModelGenerator.cs +++ b/csharp/src/Fory.Generator/ForyModelGenerator.cs @@ -1163,10 +1163,12 @@ private static void EmitReadCompatibleListArrayPayload( uint elementTypeId = PackedArrayElementTypeId(codec.TypeId); if (codec.CarrierKind == CarrierKind.Array) { + sb.AppendLine($"{indent}context.ReserveArrayMemory<{elementTypeName}>({lengthVar});"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new {ElementTypeName(codec.TypeName)}[{lengthVar}];"); } else { + sb.AppendLine($"{indent}context.ReserveListMemory<{elementTypeName}>({lengthVar});"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new({lengthVar});"); } diff --git a/csharp/src/Fory/CollectionSerializers.cs b/csharp/src/Fory/CollectionSerializers.cs index c407153fd5..2e5a610d00 100644 --- a/csharp/src/Fory/CollectionSerializers.cs +++ b/csharp/src/Fory/CollectionSerializers.cs @@ -201,6 +201,7 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea int length = checked((int)context.Reader.ReadVarUInt32()); if (length == 0) { + context.ReserveListMemory(length); return []; } @@ -213,6 +214,7 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea bool hasNull = (header & CollectionBits.HasNull) != 0; bool declared = (header & CollectionBits.DeclaredElementType) != 0; bool sameType = (header & CollectionBits.SameType) != 0; + context.ReserveNonEmptyListMemory(length); context.Reader.CheckBound(length); List values = new(length); if (!sameType) @@ -522,6 +524,7 @@ public override void WriteData(WriteContext context, in T[] value, bool hasGener public override T[] ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveArrayMemory(values.Count); return values.ToArray(); } } @@ -554,7 +557,9 @@ public override void WriteData(WriteContext context, in HashSet value, bool h public override HashSet ReadData(ReadContext context) { - return [.. CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context)]; + List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveLinkedCollectionMemory(values.Count); + return [.. values]; } } @@ -570,7 +575,9 @@ public override void WriteData(WriteContext context, in SortedSet value, bool public override SortedSet ReadData(ReadContext context) { - return [.. CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context)]; + List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveLinkedCollectionMemory(values.Count); + return [.. values]; } } @@ -586,7 +593,9 @@ public override void WriteData(WriteContext context, in ImmutableHashSet valu public override ImmutableHashSet ReadData(ReadContext context) { - return ImmutableHashSet.CreateRange(CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context)); + List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveLinkedCollectionMemory(values.Count); + return ImmutableHashSet.CreateRange(values); } } @@ -602,7 +611,9 @@ public override void WriteData(WriteContext context, in LinkedList value, boo public override LinkedList ReadData(ReadContext context) { - return new LinkedList(CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context)); + List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveLinkedCollectionMemory(values.Count); + return new LinkedList(values); } } @@ -619,6 +630,7 @@ public override void WriteData(WriteContext context, in Queue value, bool has public override Queue ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveLinkedCollectionMemory(values.Count); Queue queue = new(values.Count); for (int i = 0; i < values.Count; i++) { @@ -655,6 +667,7 @@ public override void WriteData(WriteContext context, in Stack value, bool has public override Stack ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + context.ReserveLinkedCollectionMemory(values.Count); Stack stack = new(values.Count); for (int i = 0; i < values.Count; i++) { diff --git a/csharp/src/Fory/Config.cs b/csharp/src/Fory/Config.cs index 438039d2c8..1947bac29c 100644 --- a/csharp/src/Fory/Config.cs +++ b/csharp/src/Fory/Config.cs @@ -28,6 +28,7 @@ internal Config( bool compatible, bool checkStructVersion, int maxDepth, + long maxContainerMemoryBytes, int maxTypeFields, int maxTypeMetaBytes, int maxSchemaVersionsPerType, @@ -37,6 +38,12 @@ internal Config( { throw new ArgumentOutOfRangeException(nameof(maxDepth), "MaxDepth must be greater than 0."); } + if (maxContainerMemoryBytes != -1 && maxContainerMemoryBytes <= 0) + { + throw new ArgumentOutOfRangeException( + nameof(maxContainerMemoryBytes), + "MaxContainerMemoryBytes must be positive or -1 for auto."); + } if (maxTypeFields <= 0) { throw new ArgumentOutOfRangeException(nameof(maxTypeFields), "MaxTypeFields must be greater than 0."); @@ -58,6 +65,7 @@ internal Config( Compatible = compatible; CheckStructVersion = checkStructVersion; MaxDepth = maxDepth; + MaxContainerMemoryBytes = maxContainerMemoryBytes; MaxTypeFields = maxTypeFields; MaxTypeMetaBytes = maxTypeMetaBytes; MaxSchemaVersionsPerType = maxSchemaVersionsPerType; @@ -84,6 +92,11 @@ internal Config( /// public int MaxDepth { get; } + /// + /// Gets the maximum estimated container-owned memory accepted during one root deserialization. + /// + public long MaxContainerMemoryBytes { get; } + /// /// Gets the maximum accepted field count in one received struct TypeMeta. /// @@ -114,6 +127,7 @@ public sealed class ForyBuilder private bool? _compatible; private bool _checkStructVersion; private int _maxDepth = 20; + private long _maxContainerMemoryBytes = -1; private int _maxTypeFields = 512; private int _maxTypeMetaBytes = 4096; private int _maxSchemaVersionsPerType = 10; @@ -169,6 +183,23 @@ public ForyBuilder MaxDepth(int value) return this; } + /// + /// Sets the maximum estimated container-owned memory accepted during one root deserialization. + /// Use -1 for the automatic root-size-based limit, or a positive byte limit. + /// + public ForyBuilder MaxContainerMemoryBytes(long value) + { + if (value != -1 && value <= 0) + { + throw new ArgumentOutOfRangeException( + nameof(value), + "MaxContainerMemoryBytes must be positive or -1 for auto."); + } + + _maxContainerMemoryBytes = value; + return this; + } + /// /// Sets the maximum accepted field count in one received struct TypeMeta. /// @@ -235,6 +266,7 @@ private Config BuildConfig() compatible: compatible, checkStructVersion: compatible ? false : _checkStructVersion, maxDepth: _maxDepth, + maxContainerMemoryBytes: _maxContainerMemoryBytes, maxTypeFields: _maxTypeFields, maxTypeMetaBytes: _maxTypeMetaBytes, maxSchemaVersionsPerType: _maxSchemaVersionsPerType, diff --git a/csharp/src/Fory/DictionarySerializers.cs b/csharp/src/Fory/DictionarySerializers.cs index 5aa49dfa75..bdcec3222a 100644 --- a/csharp/src/Fory/DictionarySerializers.cs +++ b/csharp/src/Fory/DictionarySerializers.cs @@ -214,9 +214,11 @@ public override TDictionary ReadData(ReadContext context) int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { + context.ReserveMapMemory(totalLength); return CreateMap(0); } + context.ReserveNonEmptyMapMemory(totalLength); context.Reader.CheckBound(totalLength); TDictionary map = CreateMap(totalLength); bool keyDynamicType = keyTypeInfo.IsDynamicType; diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs index 9bbafd1775..edcfdb13b2 100644 --- a/csharp/src/Fory/Fory.cs +++ b/csharp/src/Fory/Fory.cs @@ -190,6 +190,7 @@ public T Deserialize(ReadOnlySpan payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); + _readContext.InitContainerBudgetKnown(payload.Length); T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -210,6 +211,7 @@ public T Deserialize(byte[] payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); + _readContext.InitContainerBudgetKnown(payload.Length); T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -230,6 +232,7 @@ public T Deserialize(ref ReadOnlySequence payload) byte[] bytes = payload.ToArray(); ByteReader reader = _readContext.Reader; reader.Reset(bytes); + _readContext.InitContainerBudgetKnown(bytes.Length); T value = DeserializeFromReader(reader); payload = payload.Slice(reader.Cursor); return value; diff --git a/csharp/src/Fory/NullableKeyDictionary.cs b/csharp/src/Fory/NullableKeyDictionary.cs index d6c8caab47..fe573cae02 100644 --- a/csharp/src/Fory/NullableKeyDictionary.cs +++ b/csharp/src/Fory/NullableKeyDictionary.cs @@ -537,9 +537,11 @@ public override NullableKeyDictionary ReadData(ReadContext context int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { + context.ReserveMapMemory(totalLength); return new NullableKeyDictionary(); } + context.ReserveNonEmptyMapMemory(totalLength); context.Reader.CheckBound(totalLength); NullableKeyDictionary map = new(totalLength); bool keyDynamicType = keyTypeInfo.IsDynamicType; diff --git a/csharp/src/Fory/PrimitiveDictionarySerializers.cs b/csharp/src/Fory/PrimitiveDictionarySerializers.cs index a136bd57bd..e753280f72 100644 --- a/csharp/src/Fory/PrimitiveDictionarySerializers.cs +++ b/csharp/src/Fory/PrimitiveDictionarySerializers.cs @@ -672,9 +672,11 @@ public static TMap ReadMap( int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { + context.ReserveMapMemory(totalLength); return TMapOps.Create(0); } + context.ReserveNonEmptyMapMemory(totalLength); context.Reader.CheckBound(totalLength); TMap map = TMapOps.Create(totalLength); TypeId keyTypeId = TKeyCodec.WireTypeId; diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index f83ac0e99e..31cc878714 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -15,11 +15,21 @@ // specific language governing permissions and limitations // under the License. +using System.ComponentModel; +using System.Runtime.CompilerServices; + namespace Apache.Fory; public sealed class ReadContext { private const int MinRemoteTypeMetaLimit = 8192; + internal const long KnownContainerBudgetSlackBytes = 64 * 1024; + internal const long UnknownContainerBudgetBytes = 128L * 1024 * 1024; + internal const int ContainerFixedBytes = 32; + internal const int ArrayHeaderBytes = 24; + internal const int ReferenceBytes = 4; + internal const int CollectionEntryOverheadBytes = 16; + internal const int MapEntryOverheadBytes = 24; private readonly ReusableArray _typeMetaRefs = new(); private readonly UInt64Map _typeMetasByHeader = new(); @@ -40,6 +50,8 @@ public sealed class ReadContext private readonly Dictionary _remoteSchemaVersionsByType = []; private readonly Config _config; private int _totalAcceptedSchemaVersions; + private long _containerMemoryLimitBytes = long.MaxValue; + private long _remainingContainerMemoryBytes = long.MaxValue; public ReadContext( ByteReader reader, @@ -70,6 +82,134 @@ public ReadContext( internal RefReader RefReader { get; } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static int ElementBytes() => ContainerElementBytes.Value; + + private static class ContainerElementBytes + { + internal static readonly int Value = typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + } + + private static class MapElementBytes + { + internal static readonly int Value = + ElementBytes() + ElementBytes() + MapEntryOverheadBytes + ReferenceBytes; + } + + /// + /// Reserves estimated list-owned memory for generated serializer code. + /// Configure instead of calling this directly. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void ReserveListMemory(int length) + { + if (length == 0) + { + ReserveContainerMemory(ContainerFixedBytes); + return; + } + + ReserveNonEmptyListMemory(length); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReserveNonEmptyListMemory(int length) + { + ReserveContainerMemory((long)(uint)length * ElementBytes() + ContainerFixedBytes + ArrayHeaderBytes); + } + + /// + /// Reserves estimated array-owned memory for generated serializer code. + /// Configure instead of calling this directly. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void ReserveArrayMemory(int length) + { + ReserveCountedContainerMemory( + length, + ArrayHeaderBytes, + ElementBytes()); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReserveLinkedCollectionMemory(int length) + { + if (length == 0) + { + ReserveContainerMemory(ContainerFixedBytes); + return; + } + + ReserveContainerMemory( + (long)(uint)length * (ElementBytes() + CollectionEntryOverheadBytes + ReferenceBytes * 2) + + ContainerFixedBytes); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReserveMapMemory(int length) + { + if (length == 0) + { + ReserveContainerMemory(ContainerFixedBytes); + return; + } + + ReserveNonEmptyMapMemory(length); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReserveNonEmptyMapMemory(int length) + { + ReserveContainerMemory( + (long)(uint)length * MapElementBytes.Value + ContainerFixedBytes + ArrayHeaderBytes * 2); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void InitContainerBudgetKnown(int rootBytes) + { + long limit = _config.MaxContainerMemoryBytes; + if (limit < 0) + { + limit = (long)rootBytes * 8 + KnownContainerBudgetSlackBytes; + } + + _containerMemoryLimitBytes = limit; + _remainingContainerMemoryBytes = limit; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReserveContainerMemory(long bytes) + { + long remaining = _remainingContainerMemoryBytes; + if ((ulong)bytes > (ulong)remaining) + { + ThrowContainerBudgetExceeded(bytes, remaining, _containerMemoryLimitBytes); + } + + _remainingContainerMemoryBytes = remaining - bytes; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ReserveCountedContainerMemory(int count, int fixedBytes, int elementBytes) + { + ReserveContainerMemory((long)(uint)count * elementBytes + fixedBytes); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowContainerBudgetOverflow() + { + throw new InvalidDataException("container memory estimate overflows"); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static void ThrowContainerBudgetExceeded(long bytes, long remaining, long limit) + { + throw new InvalidDataException( + $"estimated container memory request {bytes} bytes exceeds MaxContainerMemoryBytes remaining budget {remaining} bytes out of effective limit {limit} bytes"); + } + internal void ResetFor(ByteReader reader) { Reader = reader; diff --git a/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs new file mode 100644 index 0000000000..dabaa03b28 --- /dev/null +++ b/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +using System.Buffers; +using Apache.Fory; +using ForyRuntime = Apache.Fory.Fory; + +namespace Apache.Fory.Tests; + +[ForyStruct] +public sealed class BudgetItem +{ + public int Id { get; set; } + public string Name { get; set; } = string.Empty; +} + +[ForyStruct] +public sealed class BudgetSiblings +{ + public List Left { get; set; } = []; + public List Right { get; set; } = []; +} + +[ForyStruct] +public sealed class BudgetArrayHolder +{ + public BudgetItem[] Values { get; set; } = []; +} + +public sealed class ContainerMemoryBudgetTests +{ + private static ForyRuntime NewFory(long maxContainerMemoryBytes = -1) + { + return ForyRuntime.Builder() + .Compatible(false) + .TrackRef(false) + .MaxContainerMemoryBytes(maxContainerMemoryBytes) + .Build() + .Register(1001) + .Register(1002) + .Register(1003); + } + + private static byte[] Serialize(T value) + { + return NewFory().Serialize(value); + } + + private static long ListBudget(int count) + { + return count == 0 + ? ReadContext.ContainerFixedBytes + : ReadContext.ContainerFixedBytes + ReadContext.ArrayHeaderBytes + + (long)count * ReadContext.ElementBytes(); + } + + private static long ArrayBudget(int count) + { + return ReadContext.ArrayHeaderBytes + (long)count * ReadContext.ElementBytes(); + } + + private static long MapBudget(int count) + { + return count == 0 + ? ReadContext.ContainerFixedBytes + : ReadContext.ContainerFixedBytes + ReadContext.ArrayHeaderBytes * 2 + + (long)count * (ReadContext.ElementBytes() + ReadContext.ElementBytes() + + ReadContext.MapEntryOverheadBytes + ReadContext.ReferenceBytes); + } + + [Fact] + public void KnownLengthAutoBudgetRejectsLargeNestedEmpties() + { + const int count = 3000; + List> value = Enumerable.Range(0, count).Select(_ => new List()).ToList(); + byte[] bytes = Serialize(value); + long autoLimit = bytes.LongLength * 8 + ReadContext.KnownContainerBudgetSlackBytes; + long required = ListBudget>(count) + count * ListBudget(0); + Assert.True(required > autoLimit); + + Assert.Throws(() => NewFory().Deserialize>>(bytes)); + + List> result = NewFory(required).Deserialize>>(bytes); + Assert.Equal(count, result.Count); + } + + [Fact] + public void ReadOnlySequenceUsesKnownLengthAutoBudget() + { + const int count = 3000; + List> value = Enumerable.Range(0, count).Select(_ => new List()).ToList(); + byte[] bytes = Serialize(value); + ReadOnlySequence sequence = new(bytes); + + Assert.Throws(() => NewFory().Deserialize>>(ref sequence)); + } + + [Fact] + public void ExplicitConfigOverridesAutoBudget() + { + List value = Enumerable.Range(0, 8).Select(i => new BudgetItem { Id = i }).ToList(); + byte[] bytes = Serialize(value); + long required = ListBudget(value.Count); + + Assert.Throws(() => NewFory(required - 1).Deserialize>(bytes)); + List result = NewFory(required).Deserialize>(bytes); + Assert.Equal(value.Count, result.Count); + } + + [Fact] + public void SiblingContainersShareOneBudget() + { + BudgetSiblings value = new() + { + Left = Enumerable.Range(0, 16).Select(i => new BudgetItem { Id = i }).ToList(), + Right = Enumerable.Range(0, 16).Select(i => new BudgetItem { Id = i }).ToList(), + }; + byte[] bytes = Serialize(value); + long oneList = ListBudget(16); + + Assert.Throws(() => NewFory(oneList).Deserialize(bytes)); + BudgetSiblings result = NewFory(oneList * 2).Deserialize(bytes); + Assert.Equal(16, result.Left.Count); + Assert.Equal(16, result.Right.Count); + } + + [Fact] + public void MapBudgetIsCharged() + { + Dictionary value = new() { ["a"] = 1, ["b"] = 2, ["c"] = 3 }; + byte[] bytes = Serialize(value); + long required = MapBudget(value.Count); + + Assert.Throws(() => NewFory(required - 1).Deserialize>(bytes)); + Dictionary result = NewFory(required).Deserialize>(bytes); + Assert.Equal(value, result); + } + + [Fact] + public void ReferenceArrayAndInlineValueListAreCharged() + { + BudgetArrayHolder holder = new() + { + Values = Enumerable.Range(0, 4).Select(i => new BudgetItem { Id = i }).ToArray(), + }; + byte[] holderBytes = Serialize(holder); + long holderRequired = ListBudget(4) + ArrayBudget(4); + Assert.Throws(() => NewFory(holderRequired - 1).Deserialize(holderBytes)); + Assert.Equal(4, NewFory(holderRequired).Deserialize(holderBytes).Values.Length); + + List ints = [1, 2, 3, 4]; + byte[] intBytes = Serialize(ints); + long listRequired = ListBudget(ints.Count); + Assert.Throws(() => NewFory(listRequired - 1).Deserialize>(intBytes)); + Assert.Equal(ints, NewFory(listRequired).Deserialize>(intBytes)); + } + + [Fact] + public void DenseStringBinaryAndPrimitiveArraysAreSkipped() + { + Assert.Equal("budget", NewFory(1).Deserialize(Serialize("budget"))); + Assert.Equal(new byte[] { 1, 2, 3 }, NewFory(1).Deserialize(Serialize(new byte[] { 1, 2, 3 }))); + Assert.Equal(new[] { 1, 2, 3 }, NewFory(1).Deserialize(Serialize(new[] { 1, 2, 3 }))); + } + + [Fact] + public void ByteAvailabilityCheckStillRejectsLargeLength() + { + byte[] bytes = [64, 0]; + ReadContext context = new(new ByteReader(bytes), new TypeResolver(), NewFory().Config); + + Assert.Throws(() => new ListSerializer().ReadData(context)); + } +} diff --git a/dart/packages/fory/lib/src/config.dart b/dart/packages/fory/lib/src/config.dart index d5d248cd36..6f529ecc3c 100644 --- a/dart/packages/fory/lib/src/config.dart +++ b/dart/packages/fory/lib/src/config.dart @@ -28,6 +28,7 @@ final class Config { static const int defaultMaxTypeMetaBytes = 4096; static const int defaultMaxSchemaVersionsPerType = 10; static const int defaultMaxAverageSchemaVersionsPerType = 3; + static const int defaultMaxContainerMemoryBytes = -1; /// Enables compatible struct encoding and decoding. /// @@ -56,6 +57,11 @@ final class Config { /// types. final int maxAverageSchemaVersionsPerType; + /// Maximum estimated container-owned memory per root deserialization. + /// + /// `-1` means auto. Positive values are explicit byte limits. + final int maxContainerMemoryBytes; + /// Creates an immutable configuration object. /// /// Invalid numeric limits fail fast. When [compatible] is `true`, @@ -69,6 +75,7 @@ final class Config { this.maxSchemaVersionsPerType = defaultMaxSchemaVersionsPerType, this.maxAverageSchemaVersionsPerType = defaultMaxAverageSchemaVersionsPerType, + this.maxContainerMemoryBytes = defaultMaxContainerMemoryBytes, }) : checkStructVersion = compatible ? false : checkStructVersion, assert(maxDepth > 0, 'maxDepth must be positive'), assert(maxTypeFields > 0, 'maxTypeFields must be positive'), @@ -80,5 +87,9 @@ final class Config { assert( maxAverageSchemaVersionsPerType > 0, 'maxAverageSchemaVersionsPerType must be positive', + ), + assert( + maxContainerMemoryBytes == -1 || maxContainerMemoryBytes > 0, + 'maxContainerMemoryBytes must be -1 or positive', ); } diff --git a/dart/packages/fory/lib/src/context/read_context.dart b/dart/packages/fory/lib/src/context/read_context.dart index faa8191aba..1acf28c0d6 100644 --- a/dart/packages/fory/lib/src/context/read_context.dart +++ b/dart/packages/fory/lib/src/context/read_context.dart @@ -45,6 +45,15 @@ import 'package:fory/src/types/uint64.dart'; /// deserialization operation. Application code normally interacts with [Fory] /// instead of preparing contexts directly. final class ReadContext { + static const int _knownRootBudgetMultiplier = 8; + static const int _knownRootBudgetSlackBytes = 64 * 1024; + static const int _collectionObjectBytes = 24; + static const int _mapObjectBytes = 48; + static const int _arrayHeaderBytes = 16; + static const int _mapEntryBytes = 32; + static const int _referenceBytes = 4; + static const int _maxSafeBudgetBytes = 9007199254740991; + /// Effective runtime configuration for the active operation. final Config config; final TypeResolver _typeResolver; @@ -54,6 +63,8 @@ final class ReadContext { late Buffer _buffer; final List _sharedTypes = []; int _depth = 0; + int _effectiveContainerMemoryBytes = 0; + int _remainingContainerMemoryBytes = 0; @internal ReadContext( @@ -64,8 +75,20 @@ final class ReadContext { ); @internal + @pragma('vm:prefer-inline') void prepare(Buffer buffer) { _buffer = buffer; + final configured = config.maxContainerMemoryBytes; + final limit = + configured > 0 + ? configured + : buffer.readableBytes * _knownRootBudgetMultiplier + + _knownRootBudgetSlackBytes; + if (limit > _maxSafeBudgetBytes) { + _throwContainerMemoryOverflow(limit); + } + _effectiveContainerMemoryBytes = limit; + _remainingContainerMemoryBytes = limit; } @internal @@ -74,6 +97,8 @@ final class ReadContext { _refReader.reset(); _metaStringReader.reset(); _depth = 0; + _effectiveContainerMemoryBytes = 0; + _remainingContainerMemoryBytes = 0; } /// The active input buffer for the current operation. @@ -85,6 +110,76 @@ final class ReadContext { @internal RefReader get refReader => _refReader; + @internal + int get effectiveContainerMemoryBytes => _effectiveContainerMemoryBytes; + + @internal + int get remainingContainerMemoryBytes => _remainingContainerMemoryBytes; + + @internal + @pragma('vm:prefer-inline') + void reserveCollectionMemory(int numElements) { + final bytes = _collectionObjectBytes + numElements * _referenceBytes; + final remaining = _remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + _throwContainerMemoryExceeded(bytes); + } + _remainingContainerMemoryBytes = remaining; + } + + @internal + @pragma('vm:prefer-inline') + void reserveMapMemory(int numElements) { + final bytes = + _mapObjectBytes + + numElements * + (_referenceBytes * 2 + _mapEntryBytes + _referenceBytes * 3); + final remaining = _remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + _throwContainerMemoryExceeded(bytes); + } + _remainingContainerMemoryBytes = remaining; + } + + @internal + @pragma('vm:prefer-inline') + void reserveTypedArrayMemory(int numElements, int elementBytes) { + final bytes = _arrayHeaderBytes + numElements * elementBytes; + final remaining = _remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + _throwContainerMemoryExceeded(bytes); + } + _remainingContainerMemoryBytes = remaining; + } + + @internal + void reserveContainerMemory(int bytes) { + if (bytes < 0 || bytes > _maxSafeBudgetBytes) { + _throwContainerMemoryOverflow(bytes); + } + final remaining = _remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + _throwContainerMemoryExceeded(bytes); + } + _remainingContainerMemoryBytes = remaining; + } + + @pragma('vm:never-inline') + Never _throwContainerMemoryOverflow(int bytes) { + throw StateError( + 'maxContainerMemoryBytes overflow: requested $bytes estimated container bytes.', + ); + } + + @pragma('vm:never-inline') + Never _throwContainerMemoryExceeded(int bytes) { + throw StateError( + 'maxContainerMemoryBytes exceeded: requested $bytes estimated container bytes, ' + '$_remainingContainerMemoryBytes remaining, effective limit ' + '$_effectiveContainerMemoryBytes.', + ); + } + @internal @pragma('vm:prefer-inline') TypeInfo readTypeMetaValue([TypeInfo? expectedNamedType]) => diff --git a/dart/packages/fory/lib/src/fory.dart b/dart/packages/fory/lib/src/fory.dart index adc6091a8d..48a5f9b133 100644 --- a/dart/packages/fory/lib/src/fory.dart +++ b/dart/packages/fory/lib/src/fory.dart @@ -62,7 +62,16 @@ final class Fory { int maxSchemaVersionsPerType = Config.defaultMaxSchemaVersionsPerType, int maxAverageSchemaVersionsPerType = Config.defaultMaxAverageSchemaVersionsPerType, + int maxContainerMemoryBytes = Config.defaultMaxContainerMemoryBytes, }) { + if (maxContainerMemoryBytes != Config.defaultMaxContainerMemoryBytes && + maxContainerMemoryBytes <= 0) { + throw ArgumentError.value( + maxContainerMemoryBytes, + 'maxContainerMemoryBytes', + 'must be -1 or positive', + ); + } final config = Config( compatible: compatible, checkStructVersion: checkStructVersion, @@ -71,6 +80,7 @@ final class Fory { maxTypeMetaBytes: maxTypeMetaBytes, maxSchemaVersionsPerType: maxSchemaVersionsPerType, maxAverageSchemaVersionsPerType: maxAverageSchemaVersionsPerType, + maxContainerMemoryBytes: maxContainerMemoryBytes, ); _readBuffer = Buffer(); _writeBuffer = Buffer(); diff --git a/dart/packages/fory/lib/src/serializer/collection_serializers.dart b/dart/packages/fory/lib/src/serializer/collection_serializers.dart index 4e2a8050c0..b80839f2d6 100644 --- a/dart/packages/fory/lib/src/serializer/collection_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/collection_serializers.dart @@ -270,10 +270,10 @@ final class ListSerializer extends Serializer { } final declaredTypeInfo = elementFieldType == null || - elementFieldType.isDynamic || - elementFieldType.typeId == TypeIds.unknown - ? null - : context.typeResolver.resolveFieldType(elementFieldType); + elementFieldType.isDynamic || + elementFieldType.typeId == TypeIds.unknown + ? null + : context.typeResolver.resolveFieldType(elementFieldType); final usesDeclaredType = declaredTypeInfo != null && usesDeclaredTypeInfo( @@ -296,8 +296,9 @@ final class ListSerializer extends Serializer { sameType: analysis.sameType, ); context.buffer.writeUint8(header); - final sameTypeInfo = - !usesDeclaredType && analysis.sameType ? analysis.sameTypeInfo : null; + final sameTypeInfo = !usesDeclaredType && analysis.sameType + ? analysis.sameTypeInfo + : null; if (!usesDeclaredType && sameTypeInfo != null && analysis.firstNonNull != null) { @@ -378,13 +379,13 @@ final class SetSerializer extends Serializer { FieldType? elementFieldType, { bool hasPreservedRef = false, }) { - return Set.of( - ListSerializer.readPayload( - context, - elementFieldType, - hasPreservedRef: hasPreservedRef, - ), + final values = ListSerializer.readPayload( + context, + elementFieldType, + hasPreservedRef: hasPreservedRef, ); + context.reserveCollectionMemory(values.length); + return Set.of(values); } } @@ -401,8 +402,9 @@ Object? readCompatibleMatchedCollectionArrayField( final remoteType = remoteField.fieldType; if (isCompatibleArrayType(localType.typeId) && remoteType.typeId == TypeIds.list) { - final elementType = - remoteType.arguments.isEmpty ? null : remoteType.arguments.single; + final elementType = remoteType.arguments.isEmpty + ? null + : remoteType.arguments.single; if (elementType == null || _arrayElementTypeId(localType.typeId) != _compatibleArrayElementTypeId(elementType.typeId)) { @@ -419,8 +421,9 @@ Object? readCompatibleMatchedCollectionArrayField( } if (localType.typeId == TypeIds.list && isCompatibleArrayType(remoteType.typeId)) { - final localElementType = - localType.arguments.isEmpty ? null : localType.arguments.single; + final localElementType = localType.arguments.isEmpty + ? null + : localType.arguments.single; if (localElementType == null || _arrayElementTypeId(remoteType.typeId) != _compatibleArrayElementTypeId(localElementType.typeId)) { @@ -429,7 +432,7 @@ Object? readCompatibleMatchedCollectionArrayField( ); } final raw = readCompatibleField(context, remoteField); - return _arrayToListValue(raw); + return _arrayToListValue(context, raw); } return readFieldValue(context, localField); } @@ -490,8 +493,9 @@ bool _listElementMatchesArray( int arrayTypeId, { required bool requireUnframedElement, }) { - final elementType = - listType.arguments.isEmpty ? null : listType.arguments.single; + final elementType = listType.arguments.isEmpty + ? null + : listType.arguments.single; // Nullable element schema is allowed for list -> array; actual // null payload elements fail in the dense-array reader. Ref-tracked // element framing is rejected here because this path stays primitive-only. @@ -508,6 +512,7 @@ Object _readCompatibleListAsArrayField( String fieldName, ) { final size = context.buffer.readVarUint32(); + context.reserveTypedArrayMemory(size, _arrayElementBytes(arrayTypeId)); if (size == 0) { return _newArrayValue(arrayTypeId, 0); } @@ -570,6 +575,21 @@ int _compatibleArrayElementTypeId(int typeId) { }; } +int _arrayElementBytes(int arrayTypeId) { + return switch (arrayTypeId) { + TypeIds.boolArray || TypeIds.int8Array || TypeIds.uint8Array => 1, + TypeIds.int16Array || + TypeIds.uint16Array || + TypeIds.float16Array || + TypeIds.bfloat16Array => 2, + TypeIds.int32Array || TypeIds.uint32Array || TypeIds.float32Array => 4, + TypeIds.int64Array || TypeIds.uint64Array || TypeIds.float64Array => 8, + _ => throw StateError( + 'Unsupported compatible array field type $arrayTypeId.', + ), + }; +} + Object _newArrayValue(int arrayTypeId, int length) { return switch (arrayTypeId) { TypeIds.boolArray => BoolList(length), @@ -585,8 +605,9 @@ Object _newArrayValue(int arrayTypeId, int length) { TypeIds.bfloat16Array => Bfloat16List(length), TypeIds.float32Array => Float32List(length), TypeIds.float64Array => Float64List(length), - _ => - throw StateError('Unsupported compatible array field type $arrayTypeId.'), + _ => throw StateError( + 'Unsupported compatible array field type $arrayTypeId.', + ), }; } @@ -601,8 +622,9 @@ void _setArrayValue(Object target, int arrayTypeId, int index, Object? value) { case TypeIds.int32Array: (target as Int32List)[index] = value as int; case TypeIds.int64Array: - (target as Int64List)[index] = - value is int ? Int64(value) : value as Int64; + (target as Int64List)[index] = value is int + ? Int64(value) + : value as Int64; case TypeIds.uint8Array: (target as Uint8List)[index] = value as int; case TypeIds.uint16Array: @@ -610,8 +632,9 @@ void _setArrayValue(Object target, int arrayTypeId, int index, Object? value) { case TypeIds.uint32Array: (target as Uint32List)[index] = value as int; case TypeIds.uint64Array: - (target as Uint64List)[index] = - value is int ? Uint64(value) : value as Uint64; + (target as Uint64List)[index] = value is int + ? Uint64(value) + : value as Uint64; case TypeIds.float16Array: (target as Float16List)[index] = value as double; case TypeIds.bfloat16Array: @@ -625,11 +648,13 @@ void _setArrayValue(Object target, int arrayTypeId, int index, Object? value) { } } -Object _arrayToListValue(Object? raw) { +Object _arrayToListValue(ReadContext context, Object? raw) { if (raw is BoolList) { + context.reserveCollectionMemory(raw.length); return raw.toList(); } if (raw is Iterable) { + context.reserveCollectionMemory(raw.length); return raw.toList(); } throw StateError('Expected compatible array payload.'); @@ -650,29 +675,29 @@ List readTypedListPayload( } final directTypeInfo = state.declaredTypeInfo ?? state.sameTypeInfo; if (directTypeInfo != null && !state.trackRef && !state.hasNull) { - final directFieldType = - state.declaredTypeInfo != null ? state.elementFieldType : null; + final directFieldType = state.declaredTypeInfo != null + ? state.elementFieldType + : null; if (directTypeInfo.type == T && directTypeInfo.kind == RegistrationKind.struct) { final structSerializer = directTypeInfo.structSerializer!; context.buffer.checkReadableBytes(state.size); - final result = - directTypeInfo.remoteTypeDef == null - ? List.generate( - state.size, - (_) => structSerializer.readValue(context, directTypeInfo) as T, - growable: false, - ) - : List.generate( - state.size, - (_) => - structSerializer.readGeneratedCompatibleValue( - context, - directTypeInfo, - ) - as T, - growable: false, - ); + final result = directTypeInfo.remoteTypeDef == null + ? List.generate( + state.size, + (_) => structSerializer.readValue(context, directTypeInfo) as T, + growable: false, + ) + : List.generate( + state.size, + (_) => + structSerializer.readGeneratedCompatibleValue( + context, + directTypeInfo, + ) + as T, + growable: false, + ); if (state.tracksDepth) { context.decreaseDepth(); } @@ -719,7 +744,9 @@ Set readTypedSetPayload( FieldType? elementFieldType, T Function(Object? value) convert, ) { - return Set.of(readTypedListPayload(context, elementFieldType, convert)); + final values = readTypedListPayload(context, elementFieldType, convert); + context.reserveCollectionMemory(values.length); + return Set.of(values); } void writeTypedListPayload( @@ -910,6 +937,7 @@ _PreparedListRead _prepareListRead( FieldType? elementFieldType, ) { final size = context.buffer.readVarUint32(); + context.reserveCollectionMemory(size); if (size == 0) { return _PreparedListRead( size: 0, @@ -936,15 +964,13 @@ _PreparedListRead _prepareListRead( elementFieldType != null && (usesDeclaredType || (sameType && TypeIds.isUserType(elementFieldType.typeId))); - final expectedElementTypeInfo = - needsExpectedElementType - ? context.typeResolver.tryResolveFieldType(elementFieldType) - : null; + final expectedElementTypeInfo = needsExpectedElementType + ? context.typeResolver.tryResolveFieldType(elementFieldType) + : null; final declaredTypeInfo = usesDeclaredType ? expectedElementTypeInfo : null; - final sameTypeInfo = - (!usesDeclaredType && sameType) - ? context.readTypeMetaValue(expectedElementTypeInfo) - : null; + final sameTypeInfo = (!usesDeclaredType && sameType) + ? context.readTypeMetaValue(expectedElementTypeInfo) + : null; final tracksDepth = (declaredTypeInfo != null && tracksNestedPayloadDepth(declaredTypeInfo)) || diff --git a/dart/packages/fory/lib/src/serializer/map_serializers.dart b/dart/packages/fory/lib/src/serializer/map_serializers.dart index 051454c3d6..0391699b23 100644 --- a/dart/packages/fory/lib/src/serializer/map_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/map_serializers.dart @@ -56,14 +56,13 @@ final class MapSerializer extends Serializer { required bool trackRef, }) { context.buffer.writeVarUint32(values.length); - final declaredKeyTypeInfo = - keyFieldType == null || keyFieldType.isDynamic - ? null - : context.typeResolver.resolveFieldType(keyFieldType); + final declaredKeyTypeInfo = keyFieldType == null || keyFieldType.isDynamic + ? null + : context.typeResolver.resolveFieldType(keyFieldType); final declaredValueTypeInfo = valueFieldType == null || valueFieldType.isDynamic - ? null - : context.typeResolver.resolveFieldType(valueFieldType); + ? null + : context.typeResolver.resolveFieldType(valueFieldType); final keyDeclared = declaredKeyTypeInfo != null && usesDeclaredTypeInfo( @@ -106,17 +105,17 @@ final class MapSerializer extends Serializer { (keyDeclared ? declaredKeyTypeInfo.supportsRef : (key == null || - context.typeResolver - .resolveValue(key as Object) - .supportsRef)); + context.typeResolver + .resolveValue(key as Object) + .supportsRef)); final valueTrackRef = valueRequestedRef && (valueDeclared ? declaredValueTypeInfo.supportsRef : (value == null || - context.typeResolver - .resolveValue(value as Object) - .supportsRef)); + context.typeResolver + .resolveValue(value as Object) + .supportsRef)); _writeNullChunk( context, key, @@ -132,14 +131,12 @@ final class MapSerializer extends Serializer { ); continue; } - final chunkKeyTypeInfo = - keyDeclared - ? declaredKeyTypeInfo - : context.typeResolver.resolveValue(key as Object); - final chunkValueTypeInfo = - valueDeclared - ? declaredValueTypeInfo - : context.typeResolver.resolveValue(value as Object); + final chunkKeyTypeInfo = keyDeclared + ? declaredKeyTypeInfo + : context.typeResolver.resolveValue(key as Object); + final chunkValueTypeInfo = valueDeclared + ? declaredValueTypeInfo + : context.typeResolver.resolveValue(value as Object); final chunkKeyTrackRef = keyRequestedRef && chunkKeyTypeInfo.supportsRef; final chunkValueTrackRef = valueRequestedRef && chunkValueTypeInfo.supportsRef; @@ -189,14 +186,12 @@ final class MapSerializer extends Serializer { pendingEntry = nextEntry; break; } - final nextKeyTypeInfo = - keyDeclared - ? declaredKeyTypeInfo - : context.typeResolver.resolveValue(nextKey as Object); - final nextValueTypeInfo = - valueDeclared - ? declaredValueTypeInfo - : context.typeResolver.resolveValue(nextValue as Object); + final nextKeyTypeInfo = keyDeclared + ? declaredKeyTypeInfo + : context.typeResolver.resolveValue(nextKey as Object); + final nextValueTypeInfo = valueDeclared + ? declaredValueTypeInfo + : context.typeResolver.resolveValue(nextValue as Object); final nextKeyTrackRef = keyRequestedRef && nextKeyTypeInfo.supportsRef; final nextValueTrackRef = valueRequestedRef && nextValueTypeInfo.supportsRef; @@ -257,14 +252,15 @@ Map readTypedMapPayload( bool hasPreservedRef = false, }) { var remaining = context.buffer.readVarUint32(); - final declaredKeyTypeInfo = - keyFieldType == null || keyFieldType.isDynamic - ? null - : context.typeResolver.resolveFieldType(keyFieldType); + context.reserveMapMemory(remaining); + context.buffer.checkReadableBytes(remaining); + final declaredKeyTypeInfo = keyFieldType == null || keyFieldType.isDynamic + ? null + : context.typeResolver.resolveFieldType(keyFieldType); final declaredValueTypeInfo = valueFieldType == null || valueFieldType.isDynamic - ? null - : context.typeResolver.resolveFieldType(valueFieldType); + ? null + : context.typeResolver.resolveFieldType(valueFieldType); final result = {}; if (hasPreservedRef) { context.reference(result); @@ -312,34 +308,32 @@ Map readTypedMapPayload( context.increaseDepth(); } for (var index = 0; index < chunkSize; index += 1) { - final key = - keyDeclared - ? _readDeclaredMapValue( - context, - keyFieldType!, - declaredKeyTypeInfo!, - trackRef: keyTrackRef, - ) - : _readResolvedMapValue( - context, - keyTypeInfo!, - null, - trackRef: keyTrackRef, - ); - final value = - valueDeclared - ? _readDeclaredMapValue( - context, - valueFieldType!, - declaredValueTypeInfo!, - trackRef: valueTrackRef, - ) - : _readResolvedMapValue( - context, - valueTypeInfo!, - null, - trackRef: valueTrackRef, - ); + final key = keyDeclared + ? _readDeclaredMapValue( + context, + keyFieldType!, + declaredKeyTypeInfo!, + trackRef: keyTrackRef, + ) + : _readResolvedMapValue( + context, + keyTypeInfo!, + null, + trackRef: keyTrackRef, + ); + final value = valueDeclared + ? _readDeclaredMapValue( + context, + valueFieldType!, + declaredValueTypeInfo!, + trackRef: valueTrackRef, + ) + : _readResolvedMapValue( + context, + valueTypeInfo!, + null, + trackRef: valueTrackRef, + ); result[convertKey(key)] = convertValue(value); } if (tracksDepth) { diff --git a/dart/packages/fory/test/container_memory_budget_test.dart b/dart/packages/fory/test/container_memory_budget_test.dart new file mode 100644 index 0000000000..61d6970300 --- /dev/null +++ b/dart/packages/fory/test/container_memory_budget_test.dart @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import 'dart:typed_data'; + +import 'package:fory/fory.dart'; +import 'package:fory/src/context/meta_string_reader.dart'; +import 'package:fory/src/context/ref_reader.dart'; +import 'package:fory/src/resolver/type_resolver.dart'; +import 'package:fory/src/serializer/collection_serializers.dart'; +import 'package:fory/src/serializer/map_serializers.dart'; +import 'package:test/test.dart'; + +part 'container_memory_budget_test.fory.dart'; + +const Matcher _throwsContainerBudget = ThrowsContainerBudget(); + +@ForyStruct() +class BudgetGeneratedEnvelope { + BudgetGeneratedEnvelope(); + + @ListField(element: Int32Type(encoding: Encoding.fixed)) + List ids = []; + + @SetField(element: StringType()) + Set tags = {}; + + @MapField( + key: StringType(), + value: Int32Type(encoding: Encoding.fixed), + ) + Map counts = {}; +} + +@ForyStruct() +class BudgetCompatibleListEnvelope { + BudgetCompatibleListEnvelope(); + + @ListField(element: Int32Type(encoding: Encoding.fixed)) + List values = []; +} + +@ForyStruct() +class BudgetCompatibleArrayEnvelope { + BudgetCompatibleArrayEnvelope(); + + @ArrayField(element: Int32Type()) + Int32List values = Int32List(0); +} + +final class ThrowsContainerBudget extends Matcher { + const ThrowsContainerBudget(); + + @override + Description describe(Description description) { + return description.add('throws a maxContainerMemoryBytes StateError'); + } + + @override + bool matches(Object? item, Map matchState) { + if (item is! Function) { + return false; + } + try { + item(); + } on StateError catch (error) { + return error.message.contains('maxContainerMemoryBytes'); + } + return false; + } +} + +void _registerGenerated(Fory fory) { + ContainerMemoryBudgetTestForyModule.register( + fory, + BudgetGeneratedEnvelope, + name: 'test.BudgetGeneratedEnvelope', + ); +} + +void _registerCompatibleList(Fory fory) { + ContainerMemoryBudgetTestForyModule.register( + fory, + BudgetCompatibleListEnvelope, + name: 'test.BudgetCompatibleEnvelope', + ); +} + +void _registerCompatibleArray(Fory fory) { + ContainerMemoryBudgetTestForyModule.register( + fory, + BudgetCompatibleArrayEnvelope, + name: 'test.BudgetCompatibleEnvelope', + ); +} + +ReadContext _readContext(Buffer buffer, {int maxContainerMemoryBytes = -1}) { + final config = Config(maxContainerMemoryBytes: maxContainerMemoryBytes); + final resolver = TypeResolver(config); + return ReadContext(config, resolver, RefReader(), MetaStringReader(resolver)) + ..prepare(buffer); +} + +Uint8List _serialize(Object? value) => Fory().serialize(value); + +Object? _readWithBudget(Object? value, int budget) { + return Fory( + maxContainerMemoryBytes: budget, + ).deserialize(_serialize(value)); +} + +void main() { + group('container memory budget', () { + test('known length auto derives from input bytes', () { + final buffer = Buffer.wrap(Uint8List(17)); + final context = _readContext(buffer); + + expect(context.effectiveContainerMemoryBytes, equals(17 * 8 + 64 * 1024)); + expect( + () => context.reserveContainerMemory(17 * 8 + 64 * 1024), + returnsNormally, + ); + expect(() => context.reserveContainerMemory(1), _throwsContainerBudget); + }); + + test('explicit config overrides auto', () { + final buffer = Buffer.wrap(Uint8List(4096)); + final context = _readContext(buffer, maxContainerMemoryBytes: 31); + + expect(context.effectiveContainerMemoryBytes, equals(31)); + expect(() => context.reserveContainerMemory(31), returnsNormally); + expect(() => context.reserveContainerMemory(1), _throwsContainerBudget); + expect(() => Fory(maxContainerMemoryBytes: 0), throwsArgumentError); + expect(() => Fory(maxContainerMemoryBytes: -2), throwsArgumentError); + }); + + test('charges nested empty containers', () { + final value = [[]]; + + expect(() => _readWithBudget(value, 51), _throwsContainerBudget); + expect(_readWithBudget(value, 52), equals(value)); + }); + + test('charges sibling containers cumulatively', () { + final value = [[], [], []]; + + expect(() => _readWithBudget(value, 107), _throwsContainerBudget); + expect(_readWithBudget(value, 108), equals(value)); + }); + + test('charges map table and entries', () { + final value = {'a': 1}; + + expect(() => _readWithBudget(value, 99), _throwsContainerBudget); + expect(_readWithBudget(value, 100), equals(value)); + }); + + test('charges generated list set and map reads', () { + final writer = Fory(); + _registerGenerated(writer); + final bytes = writer.serialize( + BudgetGeneratedEnvelope() + ..ids = [1] + ..tags = {'x'} + ..counts = {'one': 1}, + ); + + final failingReader = Fory(maxContainerMemoryBytes: 183); + _registerGenerated(failingReader); + expect( + () => failingReader.deserialize(bytes), + _throwsContainerBudget, + ); + + final passingReader = Fory(maxContainerMemoryBytes: 184); + _registerGenerated(passingReader); + final roundTrip = passingReader.deserialize( + bytes, + ); + expect(roundTrip.ids, equals([1])); + expect(roundTrip.tags, equals({'x'})); + expect(roundTrip.counts, equals({'one': 1})); + }); + + test('charges compatible list array materialization', () { + final listWriter = Fory(); + _registerCompatibleList(listWriter); + final listBytes = listWriter.serialize( + BudgetCompatibleListEnvelope()..values = [1, 2, 3], + ); + + final arrayFail = Fory(maxContainerMemoryBytes: 27); + _registerCompatibleArray(arrayFail); + expect( + () => arrayFail.deserialize(listBytes), + _throwsContainerBudget, + ); + + final arrayPass = Fory(maxContainerMemoryBytes: 28); + _registerCompatibleArray(arrayPass); + expect( + arrayPass + .deserialize(listBytes) + .values + .toList(), + equals([1, 2, 3]), + ); + + final arrayWriter = Fory(); + _registerCompatibleArray(arrayWriter); + final arrayBytes = arrayWriter.serialize( + BudgetCompatibleArrayEnvelope() + ..values = Int32List.fromList([1, 2, 3]), + ); + + final listFail = Fory(maxContainerMemoryBytes: 35); + _registerCompatibleList(listFail); + expect( + () => listFail.deserialize(arrayBytes), + _throwsContainerBudget, + ); + + final listPass = Fory(maxContainerMemoryBytes: 36); + _registerCompatibleList(listPass); + expect( + listPass.deserialize(arrayBytes).values, + equals([1, 2, 3]), + ); + }); + + test('skips strings binary and dense typed arrays', () { + final fory = Fory(maxContainerMemoryBytes: 1); + final text = List.filled(128, 'x').join(); + + expect(fory.deserialize(Fory().serialize(text)), hasLength(128)); + expect( + fory.deserialize(Fory().serialize(Uint8List(128))).length, + equals(128), + ); + expect( + fory.deserialize(Fory().serialize(Int32List(32))).length, + equals(32), + ); + }); + + test('keeps byte availability checks before allocation', () { + final listBuffer = Buffer() + ..writeVarUint32(64) + ..writeUint8(0); + final listContext = _readContext(listBuffer); + expect( + () => ListSerializer.readPayload(listContext, null), + throwsStateError, + ); + + final mapBuffer = Buffer()..writeVarUint32(64); + final mapContext = _readContext(mapBuffer); + expect( + () => MapSerializer.readPayload(mapContext, null, null), + throwsStateError, + ); + }); + }); +} diff --git a/docs/guide/cpp/configuration.md b/docs/guide/cpp/configuration.md index d617450041..aee6e633d5 100644 --- a/docs/guide/cpp/configuration.md +++ b/docs/guide/cpp/configuration.md @@ -96,6 +96,29 @@ When enabled, avoids duplicating shared objects and handles cycles. **Default:** `true` +### max_container_memory_bytes(int64_t) + +Set the maximum estimated memory that container objects may reserve during one +root deserialization. + +```cpp +auto fory = Fory::builder() + .max_container_memory_bytes(64 * 1024 * 1024) + .build(); +``` + +Use `-1` for the automatic limit. For byte-array and `Buffer` roots, the +automatic limit is the root input size multiplied by `8`, plus `64 KiB`. For +stream roots, the automatic limit is `128 MiB` because the full root size is not +known up front. Positive values always override the automatic limit. + +This budget is an estimate for container-owned memory such as collection +objects, backing storage, map entries, and object/reference arrays. It is not an +exact process heap limit. Dedicated string, binary, and primitive dense-array +payloads continue to rely on their byte-availability checks instead. + +**Default:** `-1` + ### max_dyn_depth(uint32_t) Set maximum allowed nesting depth for dynamically-typed objects. @@ -205,6 +228,7 @@ auto fory = Fory::builder().build_thread_safe(); // Returns ThreadSafeFory | `xlang(bool)` | Use xlang mode | `true` | | `compatible(bool)` | Enable schema evolution | `true` | | `track_ref(bool)` | Enable reference tracking | `true` | +| `max_container_memory_bytes(int64_t)` | Max estimated container memory per root read | `-1` | | `max_dyn_depth(uint32_t)` | Maximum nesting depth for dynamic types | `5` | | `max_type_fields(uint32_t)` | Max fields in one received struct metadata body | `512` | | `max_type_meta_bytes(uint32_t)` | Max encoded bytes in one received metadata body | `4096` | @@ -218,6 +242,8 @@ Security-related configuration: - Register all structs and polymorphic implementations before deserializing untrusted payloads. - Use `check_struct_version(true)` with `compatible(false)` for intentional same-schema payloads. +- Leave `max_container_memory_bytes(-1)` enabled for automatic root-size-based container limits, or + set a positive value for a stricter trusted-workload envelope. - Keep `max_dyn_depth(...)` as low as your model permits to reject unexpectedly deep polymorphic graphs. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a diff --git a/docs/guide/csharp/configuration.md b/docs/guide/csharp/configuration.md index e7c0c24d42..c9e8e80cf6 100644 --- a/docs/guide/csharp/configuration.md +++ b/docs/guide/csharp/configuration.md @@ -41,6 +41,7 @@ ThreadSafeFory threadSafe = Fory.Builder().BuildThreadSafe(); | `Compatible` | `true` | Compatible schema-evolution metadata enabled | | `CheckStructVersion` | `false` | Struct schema hash checks disabled | | `MaxDepth` | `20` | Max dynamic nesting depth | +| `MaxContainerMemoryBytes` | `-1` | Auto container memory budget | | `MaxTypeFields` | `512` | Max fields in one received struct metadata body | | `MaxTypeMetaBytes` | `4096` | Max encoded bytes in one received metadata body | | `MaxSchemaVersionsPerType` | `10` | Max remote metadata versions for one logical type | @@ -96,6 +97,20 @@ Fory fory = Fory.Builder() `value` must be greater than `0`. +### `MaxContainerMemoryBytes(long value)` + +Sets the maximum estimated container-owned memory accepted during one root deserialization. + +```csharp +Fory fory = Fory.Builder() + .MaxContainerMemoryBytes(64L * 1024 * 1024) + .Build(); +``` + +Use `-1` for the default automatic limit. For current C# inputs, auto uses the root input byte +length times `8`, plus `64 KiB`. A positive value overrides the automatic limit. `0` and negative +values other than `-1` are rejected. + ### `MaxTypeFields(int value)` Sets the maximum fields accepted in one received remote struct metadata body. @@ -173,6 +188,8 @@ Security-related configuration: - Register only the expected types before deserializing untrusted payloads. - Use `CheckStructVersion(true)` with `Compatible(false)` for intentional same-schema payloads. - Set `MaxDepth(...)` to reject unexpectedly deep dynamic object graphs. +- Set `MaxContainerMemoryBytes(...)` to cap estimated list, array, set, and map memory during one + root deserialization. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer generated or registered concrete models over broad dynamic fields for untrusted input. diff --git a/docs/guide/dart/configuration.md b/docs/guide/dart/configuration.md index 6a4c640f6a..c7f851e253 100644 --- a/docs/guide/dart/configuration.md +++ b/docs/guide/dart/configuration.md @@ -38,6 +38,7 @@ final fory = Fory( maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, maxAverageSchemaVersionsPerType: 3, + maxContainerMemoryBytes: 64 * 1024 * 1024, ); ``` @@ -107,6 +108,27 @@ final fory = Fory( - `maxAverageSchemaVersionsPerType` limits the average across accepted remote types. The effective global floor is `8192` schemas. +### `maxContainerMemoryBytes` + +Limits estimated container-owned memory for one root deserialization. The budget covers Dart lists, +sets, maps, object/reference arrays, and compatible list/array materialization. It does not count +strings, binary values, or dense typed-array payloads, which are protected by byte-availability +checks. + +The default is `-1`, which means auto. Dart root inputs are memory-backed, so auto derives from the +root input size: + +```text +inputBytes * 8 + 64 KiB +``` + +Set a positive value when a trusted workload legitimately contains compact, container-heavy +payloads: + +```dart +final fory = Fory(maxContainerMemoryBytes: 256 * 1024 * 1024); +``` + ## Defaults | Option | Default | @@ -118,6 +140,7 @@ final fory = Fory( | `maxTypeMetaBytes` | 4096 | | `maxSchemaVersionsPerType` | 10 | | `maxAverageSchemaVersionsPerType` | 3 | +| `maxContainerMemoryBytes` | -1 | ## Xlang Notes @@ -134,6 +157,8 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkStructVersion: true` with `compatible: false` for intentional same-schema payloads. - Set `maxDepth` to reject unexpectedly deep payload shapes. +- Keep `maxContainerMemoryBytes` at the auto default for most inputs, or set an explicit positive + byte limit for known trusted container-heavy payloads. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer generated schemas and explicit field metadata over broad dynamic fields for untrusted input. diff --git a/docs/guide/go/configuration.md b/docs/guide/go/configuration.md index 20d9012aee..d1cb294c9c 100644 --- a/docs/guide/go/configuration.md +++ b/docs/guide/go/configuration.md @@ -39,6 +39,7 @@ Default settings: | MaxDepth | 20 | Maximum nesting depth | | IsXlang | true | Xlang mode enabled | | Compatible | true | Compatible schema-evolution metadata enabled | +| MaxContainerMemoryBytes | -1 | Automatic container memory limit per root read | | MaxTypeFields | 512 | Max fields in one received struct metadata body | | MaxTypeMetaBytes | 4096 | Max encoded bytes in one received metadata body | | MaxSchemaVersionsPerType | 10 | Max remote metadata versions for one logical type | @@ -51,6 +52,7 @@ f := fory.New( fory.WithXlang(true), fory.WithTrackRef(true), fory.WithMaxDepth(10), + fory.WithMaxContainerMemoryBytes(-1), fory.WithMaxTypeFields(512), fory.WithMaxTypeMetaBytes(4096), fory.WithMaxSchemaVersionsPerType(10), @@ -127,6 +129,27 @@ f := fory.New(fory.WithMaxDepth(30)) - Protects against deeply nested, recursive structures or malicious data - Serialization fails with error when exceeded +### WithMaxContainerMemoryBytes + +Limit estimated container-owned memory accepted during one root deserialization: + +```go +f := fory.New(fory.WithMaxContainerMemoryBytes(256 * 1024 * 1024)) +``` + +The default `-1` selects an automatic limit. Byte-slice roots use: + +```text +inputBytes * 8 + 64 KiB +``` + +`DeserializeFromReader` and `DeserializeFromStream` use `128 MiB` because the +full root length is unknown. The budget covers Go slices, maps, sets, and +generated container reads. Strings, binary blobs, and primitive dense array +owners keep their byte-availability checks and are not charged to this budget. +Set a positive value when a service needs a stricter or larger limit for trusted +data. + ### WithMaxTypeFields Set the maximum fields accepted in one received remote struct metadata body: diff --git a/docs/guide/java/configuration.md b/docs/guide/java/configuration.md index 7b3ce60bf6..4e46d512a7 100644 --- a/docs/guide/java/configuration.md +++ b/docs/guide/java/configuration.md @@ -38,6 +38,7 @@ This page documents all configuration options available through `ForyBuilder`. | `registerGuavaTypes` | Whether to pre-register Guava types such as `RegularImmutableMap`/`RegularImmutableList`. These types are not public API, but seem pretty stable. | `true` | | `requireClassRegistration` | Disabling may allow unknown classes to be deserialized, potentially causing security risks. | `true` | | `maxDepth` | Set max depth for deserialization, when depth exceeds, an exception will be thrown. This can be used to refuse deserialization DDOS attack. | `50` | +| `maxContainerMemoryBytes` | Maximum estimated container-owned memory accepted during one root deserialization. `-1` derives an automatic limit from the input shape: known-length inputs use `inputBytes * 8 + 64 KiB`, and stream or unknown-length inputs use `128 MiB`. Positive values set an explicit byte limit. | `-1` | | `maxTypeFields` | Maximum fields accepted in one received remote struct metadata body. | `512` | | `maxTypeMetaBytes` | Maximum encoded body bytes accepted for one received TypeDef or TypeMeta body, excluding the 8-byte header and any extended-size varint. | `4096` | | `maxSchemaVersionsPerType` | Maximum accepted remote metadata versions for one logical type. | `10` | @@ -90,6 +91,7 @@ Keep class registration enabled for production and any untrusted payload source: Fory fory = Fory.builder() .requireClassRegistration(true) .withMaxDepth(50) + .withMaxContainerMemoryBytes(-1) .build(); ``` @@ -97,6 +99,9 @@ Security-related options: - `requireClassRegistration(true)` restricts deserialization to registered classes. - `withMaxDepth(...)` rejects unexpectedly deep object graphs. +- `withMaxContainerMemoryBytes(...)` bounds estimated container-owned memory during one root + deserialization. Keep `-1` for the automatic input-shaped default, or set a positive byte limit + when trusted payloads need a larger or smaller limit. - `withMaxTypeFields(...)` and `withMaxTypeMetaBytes(...)` bound the field count and encoded body size of one received remote metadata body. - `withMaxSchemaVersionsPerType(...)` and diff --git a/docs/guide/javascript/configuration.md b/docs/guide/javascript/configuration.md index 058bccf4b3..71175c301e 100644 --- a/docs/guide/javascript/configuration.md +++ b/docs/guide/javascript/configuration.md @@ -43,6 +43,7 @@ const fory = new Fory({ ref: true, compatible: true, maxDepth: 100, + maxContainerMemoryBytes: -1, maxTypeFields: 512, maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, @@ -56,6 +57,7 @@ const fory = new Fory({ | `ref` | `false` | Enable reference tracking for shared or circular object graphs | | `compatible` | `true` | Allow field additions/removals without breaking existing messages | | `maxDepth` | `50` | Maximum nesting depth. Must be `>= 2`. Increase for deeply nested structures | +| `maxContainerMemoryBytes` | `-1` | Maximum estimated container-owned memory accepted during one root deserialization | | `maxTypeFields` | `512` | Maximum fields accepted in one received remote struct metadata body | | `maxTypeMetaBytes` | `4096` | Maximum encoded body bytes accepted for one received TypeMeta body | | `maxSchemaVersionsPerType` | `10` | Maximum accepted remote metadata versions for one logical type | @@ -92,6 +94,26 @@ to that struct. For cross-language payloads, set `compatible: false` only after verifying that every language uses the same schema, or when native types are generated from Fory schema IDL. See [Schema Evolution](schema-evolution.md). +## Container Memory Budget + +`maxContainerMemoryBytes` limits estimated memory committed by arrays, sets, +maps, and container backing storage during one root deserialization. The default +`-1` derives an automatic limit from the input bytes. JavaScript deserializes +from `Uint8Array` roots, so the automatic limit is `inputBytes * 8 + 64 KiB`. + +Use a positive byte value to set an explicit lower or higher limit: + +```ts +const fory = new Fory({ + maxContainerMemoryBytes: 32 * 1024 * 1024, +}); +``` + +String, binary, and dedicated dense primitive array payloads keep their normal +byte-size checks and do not consume this container budget. Raise the limit only +for trusted workloads that legitimately contain very compact, container-heavy +graphs. + ## Optional HPS String Path `@apache-fory/hps` provides an optional Node.js string fast path: @@ -110,6 +132,8 @@ Security-related configuration: - Register only the expected schemas before deserializing untrusted payloads. - Set `maxDepth` for the maximum nesting depth your service accepts. +- Set `maxContainerMemoryBytes` for the maximum container memory your service + accepts from one root payload. - Keep `maxTypeFields` and `maxTypeMetaBytes` at their defaults unless the data is not malicious and a trusted peer sends larger remote metadata. - Keep `maxSchemaVersionsPerType` and diff --git a/docs/guide/python/configuration.md b/docs/guide/python/configuration.md index fdd6459fea..26cb42e50f 100644 --- a/docs/guide/python/configuration.md +++ b/docs/guide/python/configuration.md @@ -40,6 +40,7 @@ class Fory: max_type_meta_bytes: int = 4096, max_schema_versions_per_type: int = 10, max_average_schema_versions_per_type: int = 3, + max_container_memory_bytes: int = -1, policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, @@ -70,6 +71,7 @@ class ThreadSafeFory: | `max_type_meta_bytes` | `int` | `4096` | Maximum encoded body bytes accepted for one received TypeDef body, excluding the 8-byte header and any extended-size varint. | | `max_schema_versions_per_type` | `int` | `10` | Maximum accepted remote metadata versions for one logical type. | | `max_average_schema_versions_per_type` | `int` | `3` | Average accepted remote metadata versions across accepted remote types. The effective global floor is `8192` schemas. | +| `max_container_memory_bytes` | `int` | `-1` | Maximum estimated container-owned memory for one root deserialization. `-1` selects the automatic limit. | | `policy` | `DeserializationPolicy \| None` | `None` | Deserialization policy used for security checks. Strongly recommended when `strict=False`. | | `field_nullable` | `bool` | `False` | Treat dataclass fields as nullable by default. | | `meta_compressor` | `Any` | `None` | Optional metadata compressor used for compatible-mode metadata encoding. | @@ -197,6 +199,7 @@ fory = pyfory.Fory( max_type_meta_bytes=4096, max_schema_versions_per_type=10, max_average_schema_versions_per_type=3, + max_container_memory_bytes=-1, ) fory.register(UserModel, name="example.User") @@ -222,6 +225,10 @@ Received remote metadata is also limited: - `max_type_meta_bytes` limits the encoded body bytes accepted for one received TypeDef body. - `max_schema_versions_per_type` limits accepted remote metadata versions for one logical type. - `max_average_schema_versions_per_type` limits the average across accepted remote types. +- `max_container_memory_bytes` limits estimated list, tuple, set, dict, and object-array storage + created during one root deserialization. The default `-1` uses `input_bytes * 8 + 64 KiB` for + known-length inputs and `128 MiB` for stream inputs. Set a positive byte value for trusted + payloads that legitimately contain larger container graphs. These limits do not change `strict`, `policy`, dynamic loading, unknown-class handling, or schema-evolution semantics. @@ -278,6 +285,7 @@ unchanged. - Register all expected application types before deserialization. - Use `DeserializationPolicy` when `strict=False` is necessary. - Keep `max_depth` low enough to reject unexpectedly deep payloads. +- Keep `max_container_memory_bytes=-1` unless a trusted workload needs a higher explicit limit. - Do not treat xlang/native mode choice as a security control. ## Related Topics diff --git a/docs/guide/rust/configuration.md b/docs/guide/rust/configuration.md index 58bd070567..6e04126bf0 100644 --- a/docs/guide/rust/configuration.md +++ b/docs/guide/rust/configuration.md @@ -110,6 +110,30 @@ let fory = Fory::builder() - `max_average_schema_versions_per_type` defaults to `3` and limits the average across accepted remote types. The effective global floor is `8192` schemas. +### Container Memory Budget + +`max_container_memory_bytes(...)` limits the estimated memory that deserialization may allocate for +containers such as lists, sets, and maps during one root read. The default is `-1`, which selects an +automatic limit based on the input size: + +```rust +let fory = Fory::builder().max_container_memory_bytes(-1).build(); +``` + +For byte-slice and `Reader` roots, the automatic limit is: + +```text +input bytes * 8 + 64 KiB +``` + +Set a positive byte value when trusted payloads need a larger or smaller limit: + +```rust +let fory = Fory::builder() + .max_container_memory_bytes(256 * 1024 * 1024) + .build(); +``` + ### Explicit Xlang Examples Set `.xlang(true)` explicitly for xlang serialization examples: @@ -135,6 +159,11 @@ let fory = Fory::builder().xlang(false).compatible(false).build(); // Custom depth limit let fory = Fory::builder().max_dyn_depth(10).build(); +// Custom container memory budget +let fory = Fory::builder() + .max_container_memory_bytes(256 * 1024 * 1024) + .build(); + // Combined configuration let fory = Fory::builder() .xlang(false) @@ -149,6 +178,7 @@ let fory = Fory::builder() | `compatible(bool)` | Enable schema evolution | `true` | | `xlang(bool)` | Use xlang mode | `true` | | `max_dyn_depth(u32)` | Maximum nesting depth for dynamic types | `5` | +| `max_container_memory_bytes(i64)` | Estimated container memory per root read | `-1` | | `max_type_fields(usize)` | Max fields in one received struct metadata body | `512` | | `max_type_meta_bytes(usize)` | Max encoded bytes in one received metadata body | `4096` | | `max_schema_versions_per_type(usize)` | Max remote metadata versions for one logical type | `10` | @@ -169,6 +199,8 @@ Security-related configuration: - Register application structs and trait-object implementations before deserializing untrusted payloads. - Use `max_dyn_depth(...)` to reject unexpectedly deep dynamic object graphs. +- Keep `max_container_memory_bytes(-1)` for the default input-shaped container budget, or set a + positive byte limit for trusted workloads with larger legitimate containers. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer concrete typed fields over `dyn Any` or broad trait-object fields for untrusted input. diff --git a/docs/guide/swift/configuration.md b/docs/guide/swift/configuration.md index e3a0478817..40b80744fb 100644 --- a/docs/guide/swift/configuration.md +++ b/docs/guide/swift/configuration.md @@ -31,6 +31,7 @@ public struct Config { public let compatible: Bool public let checkClassVersion: Bool public let maxDepth: Int + public let maxContainerMemoryBytes: Int64 public let maxTypeFields: Int public let maxTypeMetaBytes: Int public let maxSchemaVersionsPerType: Int @@ -90,8 +91,14 @@ let fory = Fory(compatible: false, checkClassVersion: true) ### Size and Depth Limits -`maxDepth` bounds decoded payload nesting depth. Compatible-mode remote metadata -is also limited: +`maxDepth` bounds decoded payload nesting depth. + +`maxContainerMemoryBytes` bounds the estimated container-owned memory accepted during one root +deserialization. Use `-1` for the default automatic limit. Swift roots are currently `Data` or +`ByteBuffer`, so auto uses the root input byte length times `8`, plus `64 KiB`. A positive value +overrides the automatic limit. `0` and negative values other than `-1` are rejected. + +Compatible-mode remote metadata is also limited: - `maxTypeFields` defaults to `512` and limits fields in one received struct metadata body. - `maxTypeMetaBytes` defaults to `4096` and limits encoded body bytes in one received TypeMeta body, @@ -104,6 +111,7 @@ is also limited: ```swift let fory = Fory( maxDepth: 5, + maxContainerMemoryBytes: -1, maxTypeFields: 512, maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, @@ -140,5 +148,7 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkClassVersion` with `compatible: false` for intentional same-schema payloads. - Set `maxDepth` for the largest nesting depth your service accepts. +- Set `maxContainerMemoryBytes` to cap estimated list, set, array, and map memory during one root + deserialization. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. diff --git a/docs/security/deserialization.md b/docs/security/deserialization.md index 33dddd2886..109fb225c2 100644 --- a/docs/security/deserialization.md +++ b/docs/security/deserialization.md @@ -149,11 +149,13 @@ For buffer-backed input: comparison. - Multi-byte element arrays should compute the required byte size with overflow checks before allocation. -- Container readers that allocate, reserve, or size-hint from a declared - logical element count should first call the byte owner's readability check for - that count. This is not a full container-body validation; it is the allocation - proof that the sender has supplied at least proportional input bytes before - the reader preallocates from the count. +- Container readers that allocate backing storage or size-hint from a declared + logical element count should call the byte owner's readability check for that + count before that backing allocation or capacity reservation. This is not a + full container-body validation; it is the allocation proof that the sender has + supplied at least proportional input bytes before the reader preallocates from + the count. Estimated memory-budget accounting may reserve budget before this + byte check because it does not allocate backing storage. For stream-backed input: @@ -203,6 +205,42 @@ validation can cause a no-progress loop, unbounded resource growth, retained state, or success across a Fory policy boundary. Protocol-allowed chunk segmentation is normal input and is not a security issue by itself. +## Container Memory Budget + +Runtimes should enforce a root-deserialization budget for estimated +container-owned memory. This is cumulative accounting for containers created by +one root read; it is not exact heap measurement and it is not a raw element-slot +limit. + +The public configuration should be named around `maxContainerMemoryBytes`. +`-1` means automatic input-shaped budgeting. Positive user configuration always +wins. For known-length root input, the automatic budget is +`inputBytes * 8 + 64 KiB`. For true stream or otherwise unknown-length root +input, the automatic budget is fixed at `128 MiB`. Stream budgeting should not +depend on dynamic bytes-read accounting. + +Container budget accounting should: + +- happen in root-operation read state, with cleanup owned by the root + deserialization `finally`; +- reject arithmetic overflow before comparing budget or allocating; +- charge fixed container object cost, backing capacity, map table and entry + overhead, reference arrays, and inline or value storage where a runtime stores + elements inline; +- charge fixed cost even for zero-size containers; +- preserve existing byte-availability checks before backing allocation or + capacity reservation; +- skip dedicated string, binary, primitive array, and primitive dense-array + owner paths. + +Each runtime must inspect the concrete container path before choosing formulas. +Reference-backed containers should charge reference storage, using a 4-byte +reference slot when the actual reference slot size is not cheap or reliable to +query. Inline/value containers such as a value-type vector or list must charge +the inline element storage instead of treating those elements as references. +General inline-value containers must not be skipped just because dedicated +primitive dense arrays are skipped. + ## Skip Semantics Skipping unknown or incompatible data is classified by concrete impact, not by diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index d8fb205012..53f41887fd 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -388,9 +388,9 @@ chunk, nullability, reference, and type-dispatch semantics. It is still the right allocation proof for count-based preallocation: after validating a non-empty count and reading any serializer-owned header or type metadata that precedes allocation, call `checkReadableBytes(logicalCount)` before allocating, -reserving, or size-hinting from that count. The byte owner handles buffer versus -stream readiness; the container serializer then allocates with the declared -count and reads elements through its normal owner path. +reserving backing capacity, or size-hinting from that count. The byte owner +handles buffer versus stream readiness; the container serializer then allocates +with the declared count and reads elements through its normal owner path. This check is not a full container-body validation. It only prevents a small or truncated input from causing a large count-based preallocation. Chunk sizes, @@ -398,6 +398,25 @@ duplicate keys, element value semantics, and protocol strictness remain owned by the container/map serializer and should be validated only when they protect a real owner invariant. +Container readers should also charge a root-operation estimated container memory +budget before allocation or size hinting. The budget belongs to `ReadContext` or +the equivalent root read state, not to serializers and not to ambient +thread-local state. Positive `maxContainerMemoryBytes` configuration wins; auto +configuration uses `inputBytes * 8 + 64 KiB` for known-length root input and +fixed `128 MiB` for true stream or unknown-length root input. Do not add dynamic +stream bytes-read accounting for this budget. + +The budget estimates container-owned memory, not exact heap bytes. Charge fixed +container object cost, backing capacity, map table and entry overhead, +reference arrays, and inline/value element storage where the runtime stores +container elements inline. Charge zero-size containers for their fixed cost. +Skip dedicated string, binary, primitive array, and primitive dense-array owners, +but do not skip general inline-value containers such as vectors or lists of +value objects. If reference slot size is not cheap or reliable to query, use a +4-byte reference slot. Reject arithmetic overflow before budget comparison or +allocation, and keep the existing `checkReadableBytes` proof before backing +allocation or capacity reservation. + For TypeDef or TypeMeta bodies, first prove that the encoded metadata body bytes are readable through the byte owner. Field-list allocation should happen after that body readability check and should not use a separate small initial-capacity diff --git a/go/fory/README.md b/go/fory/README.md index 8b5ec3392c..e9c633ad8c 100644 --- a/go/fory/README.md +++ b/go/fory/README.md @@ -93,11 +93,15 @@ f := fory.New(fory.WithXlang(false), fory.WithCompatible(false)) // Set maximum nesting depth f := fory.New(fory.WithMaxDepth(20)) +// Set maximum estimated container memory for one root read +f := fory.New(fory.WithMaxContainerMemoryBytes(256 * 1024 * 1024)) + // Combine multiple options f := fory.New( fory.WithXlang(true), fory.WithTrackRef(true), fory.WithMaxDepth(20), + fory.WithMaxContainerMemoryBytes(-1), ) ``` diff --git a/go/fory/array.go b/go/fory/array.go index f99f6ff39f..93b81a85c2 100644 --- a/go/fory/array.go +++ b/go/fory/array.go @@ -290,7 +290,7 @@ func (s *arrayConcreteValueSerializer) ReadWithTypeInfo(ctx *ReadContext, refMod // arrayDynSerializer wraps sliceDynSerializer for arrays with interface element types. // It converts arrays to slices and delegates to sliceDynSerializer. type arrayDynSerializer struct { - sliceSerializer sliceDynSerializer + sliceSerializer *sliceDynSerializer } func newArrayDynSerializer(elemType reflect.Type) (arrayDynSerializer, error) { @@ -318,6 +318,9 @@ func (s arrayDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType func (s arrayDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { // Create a temp slice to read into, then copy back to array sliceType := reflect.SliceOf(value.Type().Elem()) + if !ctx.reserveSliceTypeMemory(value.Len(), value.Type().Elem()) { + return + } tempSlice := reflect.MakeSlice(sliceType, value.Len(), value.Len()) s.sliceSerializer.readData(ctx, tempSlice, value.Len()) if ctx.HasError() { diff --git a/go/fory/codegen/decoder.go b/go/fory/codegen/decoder.go index 0e57021343..3a25284909 100644 --- a/go/fory/codegen/decoder.go +++ b/go/fory/codegen/decoder.go @@ -172,6 +172,9 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make([]any, 0)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -200,6 +203,9 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make([]any, 0)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -501,6 +507,12 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\tif isXlang {\n") fmt.Fprintf(buf, "\t\t\t// xlang mode: slices are not nullable, read directly without null flag\n") fmt.Fprintf(buf, "\t\t\tsliceLen := ctx.ReadCollectionLength()\n") + fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") + fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make(%s, 0)\n", fieldAccess, sliceType.String()) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -519,6 +531,9 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s, 0)\n", fieldAccess, sliceType.String()) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -545,6 +560,9 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) + fmt.Fprintf(buf, "%sif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make([]any, 0)\n", indent, fieldAccess) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -568,6 +586,9 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) + fmt.Fprintf(buf, "%sif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make(%s, 0)\n", indent, fieldAccess, sliceType.String()) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -831,6 +852,12 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\tif isXlang {\n") fmt.Fprintf(buf, "\t\t\t// xlang mode: maps are not nullable, read directly without null flag\n") fmt.Fprintf(buf, "\t\t\tmapLen := ctx.ReadCollectionLength()\n") + fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") + fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveMapMemory(mapLen, %s, %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) + fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif mapLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make(%s)\n", fieldAccess, mapType.String()) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -849,6 +876,9 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveMapMemory(mapLen, %s, %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) + fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif mapLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s)\n", fieldAccess, mapType.String()) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -884,6 +914,9 @@ func generateMapReadInlineNoNull(buf *bytes.Buffer, mapType *types.Map, fieldAcc fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) + fmt.Fprintf(buf, "%sif !ctx.ReserveMapMemory(mapLen, %s, %s) {\n", indent, unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) + fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif mapLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make(%s)\n", indent, fieldAccess, mapType.String()) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -978,6 +1011,10 @@ func getGoTypeString(t types.Type) string { return t.String() } +func unsafeSizeExpr(t types.Type) string { + return fmt.Sprintf("int64(unsafe.Sizeof(*new(%s)))", getGoTypeString(t)) +} + // generateMapKeyRead generates code to read a map key // Uses error-aware methods for deferred error checking func generateMapKeyRead(buf *bytes.Buffer, keyType types.Type, varName string) error { diff --git a/go/fory/codegen/generator.go b/go/fory/codegen/generator.go index dbe7842da4..30fcbd7fb3 100644 --- a/go/fory/codegen/generator.go +++ b/go/fory/codegen/generator.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "go/format" + "go/types" "io/ioutil" "log" "os" @@ -33,6 +34,16 @@ import ( var logger = log.New(os.Stdout, "", 0) +func typeNeedsContainerReservation(t types.Type) bool { + if _, ok := t.(*types.Slice); ok { + return true + } + if _, ok := t.(*types.Map); ok { + return true + } + return false +} + // GeneratorOptions contains configuration for the code generator type GeneratorOptions struct { TypeList string // comma-separated list of types to generate code for @@ -286,6 +297,7 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil needsTime := false needsReflect := false needsOptional := false + needsUnsafe := false for _, s := range structs { for _, field := range s.Fields { @@ -295,6 +307,12 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil } if field.IsOptional { needsOptional = true + if field.OptionalElem != nil && typeNeedsContainerReservation(field.OptionalElem) { + needsUnsafe = true + } + } + if typeNeedsContainerReservation(field.Type) { + needsUnsafe = true } // We need reflect for the interface compatibility methods needsReflect = true @@ -310,6 +328,9 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil if needsTime { fmt.Fprintf(&buf, "\t\"time\"\n") } + if needsUnsafe { + fmt.Fprintf(&buf, "\t\"unsafe\"\n") + } fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory\"\n") if needsOptional { fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory/optional\"\n") @@ -551,6 +572,7 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { needsTime := false needsReflect := false needsOptional := false + needsUnsafe := false for _, s := range structs { for _, field := range s.Fields { @@ -560,6 +582,12 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { } if field.IsOptional { needsOptional = true + if field.OptionalElem != nil && typeNeedsContainerReservation(field.OptionalElem) { + needsUnsafe = true + } + } + if typeNeedsContainerReservation(field.Type) { + needsUnsafe = true } // We need reflect for the interface compatibility methods needsReflect = true @@ -575,6 +603,9 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { if needsTime { fmt.Fprintf(&buf, "\t\"time\"\n") } + if needsUnsafe { + fmt.Fprintf(&buf, "\t\"unsafe\"\n") + } fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory\"\n") if needsOptional { fmt.Fprintf(&buf, "\t\"github.com/apache/fory/go/fory/optional\"\n") diff --git a/go/fory/container_memory_budget_test.go b/go/fory/container_memory_budget_test.go new file mode 100644 index 0000000000..16959b3d0a --- /dev/null +++ b/go/fory/container_memory_budget_test.go @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "bytes" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type budgetItem struct { + A int32 +} + +type budgetSiblings struct { + A []string + B []string +} + +func TestContainerMemoryBudgetConfig(t *testing.T) { + require.Equal(t, int64(-1), New().config.MaxContainerMemoryBytes) + require.Equal(t, int64(123), New(WithMaxContainerMemoryBytes(123)).config.MaxContainerMemoryBytes) + require.Panics(t, func() { New(WithMaxContainerMemoryBytes(0)) }) + require.Panics(t, func() { New(WithMaxContainerMemoryBytes(-2)) }) +} + +func TestContainerMemoryBudgetAutoLimits(t *testing.T) { + ctx := NewReadContext(false) + ctx.initContainerMemoryBudget(10, false) + require.False(t, ctx.HasError()) + require.Equal(t, int64(10)*knownRootBudgetMultiplier+knownRootBudgetSlackBytes, ctx.containerMemoryLimitBytes) + require.True(t, ctx.ReserveContainerMemory(ctx.containerMemoryLimitBytes)) + require.False(t, ctx.ReserveContainerMemory(1)) + require.Contains(t, ctx.CheckError().Error(), "maxContainerMemoryBytes") + + ctx = NewReadContext(false) + ctx.initContainerMemoryBudget(10, true) + require.False(t, ctx.HasError()) + require.Equal(t, streamRootBudgetBytes, ctx.containerMemoryLimitBytes) + require.True(t, ctx.ReserveContainerMemory(streamRootBudgetBytes)) + require.False(t, ctx.ReserveContainerMemory(1)) + require.Contains(t, ctx.CheckError().Error(), "maxContainerMemoryBytes") + + ctx = NewReadContext(false) + ctx.maxContainerMemoryBytes = 77 + ctx.initContainerMemoryBudget(10, true) + require.False(t, ctx.HasError()) + require.Equal(t, int64(77), ctx.containerMemoryLimitBytes) +} + +func TestContainerMemoryBudgetKnownVsStreamRoot(t *testing.T) { + writer := New(WithCompatible(false)) + values := make([]any, 12000) + for i := range values { + values[i] = []any{} + } + data, err := writer.Serialize(values) + require.NoError(t, err) + + var fromBytes []any + err = New(WithCompatible(false)).Deserialize(data, &fromBytes) + require.Error(t, err) + require.Contains(t, err.Error(), "maxContainerMemoryBytes") + + var fromStream []any + err = New(WithCompatible(false)).DeserializeFromReader(bytes.NewReader(data), &fromStream) + require.NoError(t, err) + require.Len(t, fromStream, len(values)) +} + +func TestContainerMemoryBudgetExplicitOverride(t *testing.T) { + writer := New(WithCompatible(false)) + values := make([]any, 12000) + for i := range values { + values[i] = []any{} + } + data, err := writer.Serialize(values) + require.NoError(t, err) + + var out []any + err = New(WithCompatible(false), WithMaxContainerMemoryBytes(4*1024*1024)).Deserialize(data, &out) + require.NoError(t, err) + require.Len(t, out, len(values)) +} + +func TestContainerMemoryBudgetEmptyAndCumulative(t *testing.T) { + data, err := New(WithCompatible(false)).Serialize([]any{}) + require.NoError(t, err) + var empty []any + err = New(WithCompatible(false), WithMaxContainerMemoryBytes(sliceObjectBytes-1)).Deserialize(data, &empty) + require.Error(t, err) + require.Contains(t, err.Error(), "maxContainerMemoryBytes") + + writer := New(WithCompatible(false)) + require.NoError(t, writer.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) + data, err = writer.Serialize(&budgetSiblings{A: []string{}, B: []string{}}) + require.NoError(t, err) + reader := New(WithCompatible(false), WithMaxContainerMemoryBytes(sliceObjectBytes)) + require.NoError(t, reader.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) + var out budgetSiblings + err = reader.Deserialize(data, &out) + require.Error(t, err) + require.Contains(t, err.Error(), "maxContainerMemoryBytes") +} + +func TestContainerMemoryBudgetMapAndOverflow(t *testing.T) { + data, err := New().Serialize(map[string]string{"k": "v"}) + require.NoError(t, err) + var out map[string]string + oneEntryBudget := mapObjectBytes + + 2*referenceSlotBytes + + mapEntryOverheadBytes + referenceSlotBytes + + containerSizeOf[string]() + containerSizeOf[string]() + err = New(WithMaxContainerMemoryBytes(oneEntryBudget-1)).Deserialize(data, &out) + require.Error(t, err) + require.Contains(t, err.Error(), "maxContainerMemoryBytes") + + ctx := NewReadContext(false) + ctx.initContainerMemoryBudget(0, true) + require.False(t, ctx.ReserveMapMemory(MaxInt, MaxInt64, 1)) + require.Contains(t, ctx.CheckError().Error(), "overflows") +} + +func TestContainerMemoryBudgetSlicesAndInlineValues(t *testing.T) { + data, err := New().Serialize([]string{"a"}) + require.NoError(t, err) + var stringsOut []string + err = New(WithMaxContainerMemoryBytes(sliceObjectBytes+containerSizeOf[string]()-1)).Deserialize(data, &stringsOut) + require.Error(t, err) + require.Contains(t, err.Error(), "maxContainerMemoryBytes") + + writer := New() + require.NoError(t, writer.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + data, err = writer.Serialize([]budgetItem{{A: 1}}) + require.NoError(t, err) + reader := New(WithMaxContainerMemoryBytes(sliceObjectBytes + containerSizeOf[budgetItem]() - 1)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + var items []budgetItem + err = reader.Deserialize(data, &items) + require.Error(t, err) + require.Contains(t, err.Error(), "maxContainerMemoryBytes") +} + +func TestContainerMemoryBudgetSkipsDenseOwners(t *testing.T) { + f := New(WithMaxContainerMemoryBytes(1)) + + stringData, err := New().Serialize(strings.Repeat("x", 128)) + require.NoError(t, err) + var s string + require.NoError(t, f.Deserialize(stringData, &s)) + require.Len(t, s, 128) + + bytesData, err := New().Serialize([]byte{1, 2, 3, 4}) + require.NoError(t, err) + var b []byte + require.NoError(t, f.Deserialize(bytesData, &b)) + require.Equal(t, []byte{1, 2, 3, 4}, b) + + intsData, err := New().Serialize([]int32{1, 2, 3, 4}) + require.NoError(t, err) + var ints []int32 + require.NoError(t, f.Deserialize(intsData, &ints)) + require.Equal(t, []int32{1, 2, 3, 4}, ints) +} + +func TestContainerMemoryBudgetPreservesByteChecks(t *testing.T) { + buf := NewByteBuffer(nil) + buf.WriteByte_(XLangFlag) + buf.WriteInt8(NotNullValueFlag) + buf.WriteUint8(uint8(LIST)) + buf.WriteLength(1024) + buf.WriteInt8(int8(CollectionIsSameType)) + buf.WriteUint8(uint8(STRING)) + + var stringsOut []string + err := New(WithMaxContainerMemoryBytes(8*1024*1024)).Deserialize(buf.Bytes(), &stringsOut) + require.Error(t, err) + require.Contains(t, err.Error(), "buffer out of bound") + + buf = NewByteBuffer(nil) + buf.WriteByte_(XLangFlag) + buf.WriteInt8(NotNullValueFlag) + buf.WriteUint8(uint8(INT32_ARRAY)) + buf.WriteLength(4096) + + var ints []int32 + err = New(WithMaxContainerMemoryBytes(8*1024*1024)).Deserialize(buf.Bytes(), &ints) + require.Error(t, err) + require.Contains(t, err.Error(), "buffer out of bound") +} diff --git a/go/fory/field_serializer.go b/go/fory/field_serializer.go index d91b6ec5e1..d3ef3a787b 100644 --- a/go/fory/field_serializer.go +++ b/go/fory/field_serializer.go @@ -42,7 +42,7 @@ func serializerNeedsGenericDispatch(serializer Serializer) bool { switch serializer.(type) { case *sliceSerializer, primitiveListSerializer, - sliceDynSerializer, + *sliceDynSerializer, setSerializer, mapSerializer, stringSliceSerializer, @@ -68,10 +68,13 @@ func newDeclaredSliceSerializer(type_ reflect.Type, elemSerializer Serializer, r if elem.Kind() == reflect.Ptr && elem.Elem().Kind() == reflect.Interface { return nil, fmt.Errorf("slice serializer does not support pointer to interface element type: %v", type_) } + elemBytes := int64(elem.Size()) return &sliceSerializer{ type_: type_, elemSerializer: elemSerializer, referencable: referencable, + elemBytes: elemBytes, + maxLength: maxSliceLength(elemBytes), }, nil } diff --git a/go/fory/fory.go b/go/fory/fory.go index 412fc46449..7bfb9867ef 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -69,6 +69,7 @@ type Config struct { MaxDepth int IsXlang bool Compatible bool // Schema evolution compatibility mode + MaxContainerMemoryBytes int64 MaxTypeFields int MaxTypeMetaBytes int MaxSchemaVersionsPerType int @@ -82,6 +83,7 @@ func defaultConfig() Config { MaxDepth: 20, IsXlang: true, MaxTypeFields: 512, + MaxContainerMemoryBytes: -1, MaxTypeMetaBytes: 4096, MaxSchemaVersionsPerType: 10, MaxAverageSchemaVersionsPerType: 3, @@ -110,6 +112,17 @@ func WithMaxDepth(depth int) Option { } } +// WithMaxContainerMemoryBytes sets the maximum estimated container-owned memory accepted during one root deserialization. +// Use -1 for the automatic input-shaped limit. +func WithMaxContainerMemoryBytes(size int64) Option { + if size != -1 && size <= 0 { + panic("MaxContainerMemoryBytes must be positive or -1 for auto") + } + return func(f *Fory) { + f.config.MaxContainerMemoryBytes = size + } +} + // WithXlang sets cross-language serialization mode func WithXlang(enabled bool) Option { return func(f *Fory) { @@ -218,6 +231,7 @@ func New(opts ...Option) *Fory { f.writeCtx.xlang = f.config.IsXlang f.readCtx = NewReadContext(f.config.TrackRef) + f.readCtx.maxContainerMemoryBytes = f.config.MaxContainerMemoryBytes f.readCtx.typeResolver = f.typeResolver f.readCtx.refResolver = f.refResolver f.readCtx.compatible = f.config.Compatible @@ -556,6 +570,10 @@ func (f *Fory) Serialize(value any) ([]byte, error) { func (f *Fory) Deserialize(data []byte, v any) error { defer f.resetReadState() f.readCtx.SetData(data) + f.readCtx.initContainerMemoryBudget(len(data), false) + if f.readCtx.HasError() { + return f.readCtx.TakeError() + } readHeader(f.readCtx) if f.readCtx.HasError() { @@ -1016,6 +1034,10 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { // Reuse context, reset and set new data f.readCtx.Reset() f.readCtx.SetData(data) + f.readCtx.initContainerMemoryBudget(len(data), false) + if f.readCtx.HasError() { + return f.readCtx.TakeError() + } // ReadData and validate header readHeader(f.readCtx) diff --git a/go/fory/map.go b/go/fory/map.go index 8b1d82cc95..fdb8ebbc53 100644 --- a/go/fory/map.go +++ b/go/fory/map.go @@ -303,6 +303,9 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { iface := reflect.TypeOf((*any)(nil)).Elem() mapType = reflect.MapOf(iface, iface) } + if !ctx.reserveMapTypeMemory(size, mapType.Key(), mapType.Elem()) { + return + } if size == 0 { if value.IsNil() { value.Set(reflect.MakeMap(mapType)) diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go index d520e97bd5..5a4e925ade 100644 --- a/go/fory/map_primitive.go +++ b/go/fory/map_primitive.go @@ -25,6 +25,27 @@ import ( // Optimized map serializers for common types // ============================================================================ +var ( + stringStringMapElemBytes = mapElementMemory(stringElementBytes, stringElementBytes) + stringStringMapMaxLength = maxMapLength(stringStringMapElemBytes) + stringInt64MapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[int64]()) + stringInt64MapMaxLength = maxMapLength(stringInt64MapElemBytes) + stringInt32MapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[int32]()) + stringInt32MapMaxLength = maxMapLength(stringInt32MapElemBytes) + stringIntMapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[int]()) + stringIntMapMaxLength = maxMapLength(stringIntMapElemBytes) + stringFloat64MapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[float64]()) + stringFloat64MapMaxLength = maxMapLength(stringFloat64MapElemBytes) + stringBoolMapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[bool]()) + stringBoolMapMaxLength = maxMapLength(stringBoolMapElemBytes) + int32Int32MapElemBytes = mapElementMemory(containerSizeOf[int32](), containerSizeOf[int32]()) + int32Int32MapMaxLength = maxMapLength(int32Int32MapElemBytes) + int64Int64MapElemBytes = mapElementMemory(containerSizeOf[int64](), containerSizeOf[int64]()) + int64Int64MapMaxLength = maxMapLength(int64Int64MapElemBytes) + intIntMapElemBytes = mapElementMemory(containerSizeOf[int](), containerSizeOf[int]()) + intIntMapMaxLength = maxMapLength(intIntMapElemBytes) +) + // writeMapStringString writes map[string]string using chunk protocol // When hasGenerics=true, element types are known so we set DECL_TYPE flags and skip type info func writeMapStringString(buf *ByteBuffer, m map[string]string, hasGenerics bool) { @@ -68,10 +89,16 @@ func writeMapStringString(buf *ByteBuffer, m map[string]string, hasGenerics bool } } -func readTypedMapSize(ctx *ReadContext) (int, bool) { +func readTypedMapSize(ctx *ReadContext, elemBytes int64, maxLength int64) (int, bool) { size := ctx.ReadCollectionLength() - if size == 0 || ctx.HasError() { - return size, false + if ctx.HasError() { + return 0, false + } + if !ctx.reserveMapMemory(size, elemBytes, maxLength) { + return 0, false + } + if size == 0 { + return size, true } if !ctx.Buffer().CheckReadable(size, ctx.Err()) { return 0, false @@ -83,12 +110,11 @@ func readTypedMapSize(ctx *ReadContext) (int, bool) { func readMapStringString(ctx *ReadContext) map[string]string { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]string) + size, ok := readTypedMapSize(ctx, stringStringMapElemBytes, stringStringMapMaxLength) if !ok { - return result + return nil } - result = make(map[string]string, size) + result := make(map[string]string, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -171,12 +197,11 @@ func writeMapStringInt64(buf *ByteBuffer, m map[string]int64, hasGenerics bool) func readMapStringInt64(ctx *ReadContext) map[string]int64 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]int64) + size, ok := readTypedMapSize(ctx, stringInt64MapElemBytes, stringInt64MapMaxLength) if !ok { - return result + return nil } - result = make(map[string]int64, size) + result := make(map[string]int64, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -256,12 +281,11 @@ func writeMapStringInt32(buf *ByteBuffer, m map[string]int32, hasGenerics bool) func readMapStringInt32(ctx *ReadContext) map[string]int32 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]int32) + size, ok := readTypedMapSize(ctx, stringInt32MapElemBytes, stringInt32MapMaxLength) if !ok { - return result + return nil } - result = make(map[string]int32, size) + result := make(map[string]int32, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -341,12 +365,11 @@ func writeMapStringInt(buf *ByteBuffer, m map[string]int, hasGenerics bool) { func readMapStringInt(ctx *ReadContext) map[string]int { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]int) + size, ok := readTypedMapSize(ctx, stringIntMapElemBytes, stringIntMapMaxLength) if !ok { - return result + return nil } - result = make(map[string]int, size) + result := make(map[string]int, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -426,12 +449,11 @@ func writeMapStringFloat64(buf *ByteBuffer, m map[string]float64, hasGenerics bo func readMapStringFloat64(ctx *ReadContext) map[string]float64 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]float64) + size, ok := readTypedMapSize(ctx, stringFloat64MapElemBytes, stringFloat64MapMaxLength) if !ok { - return result + return nil } - result = make(map[string]float64, size) + result := make(map[string]float64, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -511,12 +533,11 @@ func writeMapStringBool(buf *ByteBuffer, m map[string]bool, hasGenerics bool) { func readMapStringBool(ctx *ReadContext) map[string]bool { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[string]bool) + size, ok := readTypedMapSize(ctx, stringBoolMapElemBytes, stringBoolMapMaxLength) if !ok { - return result + return nil } - result = make(map[string]bool, size) + result := make(map[string]bool, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -600,12 +621,11 @@ func writeMapInt32Int32(buf *ByteBuffer, m map[int32]int32, hasGenerics bool) { func readMapInt32Int32(ctx *ReadContext) map[int32]int32 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[int32]int32) + size, ok := readTypedMapSize(ctx, int32Int32MapElemBytes, int32Int32MapMaxLength) if !ok { - return result + return nil } - result = make(map[int32]int32, size) + result := make(map[int32]int32, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -685,12 +705,11 @@ func writeMapInt64Int64(buf *ByteBuffer, m map[int64]int64, hasGenerics bool) { func readMapInt64Int64(ctx *ReadContext) map[int64]int64 { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[int64]int64) + size, ok := readTypedMapSize(ctx, int64Int64MapElemBytes, int64Int64MapMaxLength) if !ok { - return result + return nil } - result = make(map[int64]int64, size) + result := make(map[int64]int64, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -770,12 +789,11 @@ func writeMapIntInt(buf *ByteBuffer, m map[int]int, hasGenerics bool) { func readMapIntInt(ctx *ReadContext) map[int]int { err := ctx.Err() buf := ctx.Buffer() - size, ok := readTypedMapSize(ctx) - result := make(map[int]int) + size, ok := readTypedMapSize(ctx, intIntMapElemBytes, intIntMapMaxLength) if !ok { - return result + return nil } - result = make(map[int]int, size) + result := make(map[int]int, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -831,12 +849,12 @@ func (s stringStringMapSerializer) Write(ctx *WriteContext, refMode RefMode, wri } func (s stringStringMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringString(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringStringMapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -866,12 +884,12 @@ func (s stringInt64MapSerializer) Write(ctx *WriteContext, refMode RefMode, writ } func (s stringInt64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringInt64(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringInt64MapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -901,12 +919,12 @@ func (s stringIntMapSerializer) Write(ctx *WriteContext, refMode RefMode, writeT } func (s stringIntMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringInt(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringIntMapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -936,12 +954,12 @@ func (s stringFloat64MapSerializer) Write(ctx *WriteContext, refMode RefMode, wr } func (s stringFloat64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringFloat64(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringFloat64MapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -971,12 +989,12 @@ func (s stringBoolMapSerializer) Write(ctx *WriteContext, refMode RefMode, write } func (s stringBoolMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapStringBool(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s stringBoolMapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -1006,12 +1024,12 @@ func (s int32Int32MapSerializer) Write(ctx *WriteContext, refMode RefMode, write } func (s int32Int32MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapInt32Int32(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s int32Int32MapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -1041,12 +1059,12 @@ func (s int64Int64MapSerializer) Write(ctx *WriteContext, refMode RefMode, write } func (s int64Int64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapInt64Int64(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s int64Int64MapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -1076,12 +1094,12 @@ func (s intIntMapSerializer) Write(ctx *WriteContext, refMode RefMode, writeType } func (s intIntMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { - if value.IsNil() { - value.Set(reflect.MakeMap(value.Type())) - } - ctx.RefResolver().Reference(value) result := readMapIntInt(ctx) + if ctx.HasError() { + return + } value.Set(reflect.ValueOf(result)) + ctx.RefResolver().Reference(value) } func (s intIntMapSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { diff --git a/go/fory/reader.go b/go/fory/reader.go index 3985bb4e2b..b3d6301d65 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -29,21 +29,60 @@ import ( // ReadContext holds all state needed during deserialization. type ReadContext struct { - buffer *ByteBuffer - refReader *RefReader - trackRef bool // Cached flag to avoid indirection - xlang bool // Cross-language serialization mode - rootHeader byte - compatible bool // Schema evolution compatibility mode - typeResolver *TypeResolver // For complex type deserialization - refResolver *RefResolver // For reference tracking in native-mode paths - outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization - outOfBandIndex int // Current index into out-of-band buffers - depth int // Current nesting depth for cycle detection - maxDepth int // Maximum allowed nesting depth - err Error // Accumulated error state for deferred checking - lastTypePtr uintptr - lastTypeInfo *TypeInfo + buffer *ByteBuffer + refReader *RefReader + trackRef bool // Cached flag to avoid indirection + xlang bool // Cross-language serialization mode + rootHeader byte + compatible bool // Schema evolution compatibility mode + typeResolver *TypeResolver // For complex type deserialization + refResolver *RefResolver // For reference tracking in native-mode paths + outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization + outOfBandIndex int // Current index into out-of-band buffers + depth int // Current nesting depth for cycle detection + maxDepth int // Maximum allowed nesting depth + err Error // Accumulated error state for deferred checking + lastTypePtr uintptr + lastTypeInfo *TypeInfo + maxContainerMemoryBytes int64 + containerMemoryLimitBytes int64 + remainingContainerMemoryBytes int64 +} + +const ( + knownRootBudgetMultiplier = int64(8) + knownRootBudgetSlackBytes = int64(64 * 1024) + streamRootBudgetBytes = int64(128 * 1024 * 1024) + sliceObjectBytes = int64(unsafe.Sizeof([]byte(nil))) + mapObjectBytes = int64(48) + mapEntryOverheadBytes = int64(16) +) + +var referenceSlotBytes = int64(unsafe.Sizeof(uintptr(0))) +var stringElementBytes = containerSizeOf[string]() +var stringSliceMaxLength = maxSliceLength(stringElementBytes) + +func containerSizeOf[T any]() int64 { + var v T + return int64(unsafe.Sizeof(v)) +} + +func maxSliceLength(elemBytes int64) int64 { + if elemBytes == 0 { + return MaxInt64 + } + return (MaxInt64 - sliceObjectBytes) / elemBytes +} + +func mapElementMemory(keyBytes int64, valueBytes int64) int64 { + return keyBytes + valueBytes + mapEntryOverheadBytes + referenceSlotBytes + 2*referenceSlotBytes +} + +func maxMapLength(elemBytes int64) int64 { + if elemBytes == 0 { + return MaxInt64 + } + return (MaxInt64 - mapObjectBytes) / elemBytes } // IsXlang returns whether cross-language serialization mode is enabled @@ -54,10 +93,11 @@ func (c *ReadContext) IsXlang() bool { // NewReadContext creates a new read context func NewReadContext(trackRef bool) *ReadContext { return &ReadContext{ - buffer: NewByteBuffer(nil), - refReader: NewRefReader(trackRef), - trackRef: trackRef, - maxDepth: 128, // Default maximum nesting depth + buffer: NewByteBuffer(nil), + refReader: NewRefReader(trackRef), + trackRef: trackRef, + maxDepth: 128, // Default maximum nesting depth + maxContainerMemoryBytes: -1, } } @@ -67,6 +107,8 @@ func (c *ReadContext) Reset() { c.outOfBandBuffers = nil c.outOfBandIndex = 0 c.err = Error{} // Clear error state + // Container budget state is overwritten by each root read before deserialization. + // Avoid extra reset stores on the successful root hot path. if c.refResolver != nil { c.refResolver.resetRead() } @@ -75,6 +117,157 @@ func (c *ReadContext) Reset() { } } +func (c *ReadContext) initContainerMemoryBudget(rootInputBytes int, unknownLengthInput bool) { + limit := c.maxContainerMemoryBytes + if limit <= 0 { + if unknownLengthInput { + limit = streamRootBudgetBytes + } else { + if rootInputBytes < 0 { + c.setContainerMemoryError("root input size must be non-negative: %d", rootInputBytes) + return + } + if int64(rootInputBytes) > (MaxInt64-knownRootBudgetSlackBytes)/knownRootBudgetMultiplier { + c.setContainerMemoryError("root input size %d overflows automatic container memory budget", rootInputBytes) + return + } + limit = int64(rootInputBytes)*knownRootBudgetMultiplier + knownRootBudgetSlackBytes + } + } + c.containerMemoryLimitBytes = limit + c.remainingContainerMemoryBytes = limit +} + +// ReserveSliceMemory reserves estimated memory for a Go slice backing array before allocation. +func (c *ReadContext) ReserveSliceMemory(length int, elemBytes int64) bool { + if elemBytes < 0 { + c.setContainerMemoryError("negative container element size: %d", elemBytes) + return false + } + return c.reserveSliceMemory(length, elemBytes, maxSliceLength(elemBytes)) +} + +func (c *ReadContext) reserveSliceMemory(length int, elemBytes int64, maxLength int64) bool { + if length < 0 { + c.setContainerMemoryError("negative container length: %d", length) + return false + } + if int64(length) > maxLength { + c.setContainerMemoryError("container memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) + return false + } + bytes := sliceObjectBytes + int64(length)*elemBytes + remaining := c.remainingContainerMemoryBytes + if bytes > remaining { + c.setContainerMemoryExceeded(bytes, remaining) + return false + } + c.remainingContainerMemoryBytes = remaining - bytes + return true +} + +func (c *ReadContext) reserveSliceTypeMemory(length int, elemType reflect.Type) bool { + elemBytes := referenceSlotBytes + if elemType != nil { + elemBytes = int64(elemType.Size()) + } + return c.ReserveSliceMemory(length, elemBytes) +} + +// ReserveMapMemory reserves estimated memory for a Go map before allocation or size hinting. +func (c *ReadContext) ReserveMapMemory(length int, keyBytes int64, valueBytes int64) bool { + if keyBytes < 0 || valueBytes < 0 { + c.setContainerMemoryError("negative map element size: key=%d value=%d", keyBytes, valueBytes) + return false + } + perEntry := keyBytes + valueBytes + if perEntry < keyBytes || perEntry > MaxInt64-mapEntryOverheadBytes-referenceSlotBytes { + c.setContainerMemoryError("map element size overflows: key=%d value=%d", keyBytes, valueBytes) + return false + } + perEntry += mapEntryOverheadBytes + referenceSlotBytes + if perEntry > MaxInt64-2*referenceSlotBytes { + c.setContainerMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + return false + } + elemBytes := perEntry + 2*referenceSlotBytes + return c.reserveMapMemory(length, elemBytes, maxMapLength(elemBytes)) +} + +func (c *ReadContext) reserveMapTypeMemory(length int, keyType reflect.Type, valueType reflect.Type) bool { + keyBytes := referenceSlotBytes + valueBytes := referenceSlotBytes + if keyType != nil { + keyBytes = int64(keyType.Size()) + } + if valueType != nil { + valueBytes = int64(valueType.Size()) + } + return c.ReserveMapMemory(length, keyBytes, valueBytes) +} + +func (c *ReadContext) reserveMapMemory(length int, elemBytes int64, maxLength int64) bool { + if length < 0 { + c.setContainerMemoryError("negative container length: %d", length) + return false + } + if int64(length) > maxLength { + c.setContainerMemoryError("container memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) + return false + } + bytes := mapObjectBytes + int64(length)*elemBytes + remaining := c.remainingContainerMemoryBytes + if bytes > remaining { + c.setContainerMemoryExceeded(bytes, remaining) + return false + } + c.remainingContainerMemoryBytes = remaining - bytes + return true +} + +func (c *ReadContext) reserveCountedMemory(length int, fixedBytes int64, elemBytes int64) bool { + if length < 0 { + c.setContainerMemoryError("negative container length: %d", length) + return false + } + if fixedBytes < 0 || elemBytes < 0 { + c.setContainerMemoryError("negative container memory estimate: fixed=%d elem=%d", fixedBytes, elemBytes) + return false + } + if elemBytes != 0 && int64(length) > (MaxInt64-fixedBytes)/elemBytes { + c.setContainerMemoryError("container memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) + return false + } + return c.ReserveContainerMemory(fixedBytes + int64(length)*elemBytes) +} + +// ReserveContainerMemory reserves raw estimated container-owned bytes. +func (c *ReadContext) ReserveContainerMemory(bytes int64) bool { + if bytes < 0 { + c.setContainerMemoryError("estimated container memory must be non-negative, got %d bytes", bytes) + return false + } + remaining := c.remainingContainerMemoryBytes + if bytes > remaining { + c.setContainerMemoryExceeded(bytes, remaining) + return false + } + c.remainingContainerMemoryBytes = remaining - bytes + return true +} + +//go:noinline +func (c *ReadContext) setContainerMemoryError(format string, args ...any) { + c.SetError(DeserializationErrorf(format, args...)) +} + +//go:noinline +func (c *ReadContext) setContainerMemoryExceeded(bytes int64, remaining int64) { + c.SetError(DeserializationErrorf( + "estimated container memory request %d bytes exceeds maxContainerMemoryBytes remaining budget %d bytes out of effective limit %d bytes", + bytes, remaining, c.containerMemoryLimitBytes)) +} + // SetData sets new input data (for buffer reuse) // Reuses existing buffer to avoid allocation func (c *ReadContext) SetData(data []byte) { @@ -536,7 +729,42 @@ func (c *ReadContext) ReadStringSlice(refMode RefMode, readType bool) []string { if readType { _ = c.buffer.ReadUint8(err) } - return ReadStringSlice(c.buffer, err) + return c.readStringSliceData() +} + +func (c *ReadContext) readStringSliceData() []string { + buf := c.buffer + err := c.Err() + length := buf.ReadLength(err) + if c.HasError() { + return nil + } + if !c.reserveSliceMemory(length, containerSizeOf[string](), stringSliceMaxLength) { + return nil + } + if length == 0 { + return make([]string, 0) + } + collectFlag := buf.ReadInt8(err) + if (collectFlag&CollectionIsSameType) != 0 && (collectFlag&CollectionIsDeclElementType) == 0 { + _ = buf.ReadUint8(err) + } + if c.HasError() || !buf.CheckReadable(length, err) { + return nil + } + result := make([]string, length) + trackRefs := (collectFlag & CollectionTrackingRef) != 0 + hasNull := (collectFlag & CollectionHasNull) != 0 + for i := 0; i < length; i++ { + if trackRefs || hasNull { + rf := buf.ReadInt8(err) + if rf == NullFlag { + continue + } + } + result[i] = readString(buf, err) + } + return result } // ReadStringStringMap reads map[string]string with optional ref/type info diff --git a/go/fory/set.go b/go/fory/set.go index 1a42739547..652f8ddca9 100644 --- a/go/fory/set.go +++ b/go/fory/set.go @@ -318,6 +318,9 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { return } if length == 0 { + if !ctx.reserveMapTypeMemory(length, type_.Key(), type_.Elem()) { + return + } // Initialize empty set if length is 0 value.Set(reflect.MakeMap(type_)) return @@ -356,6 +359,9 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { if !buf.CheckReadable(length, err) { return } + if !ctx.reserveMapTypeMemory(length, type_.Key(), type_.Elem()) { + return + } // Initialize set if nil if value.IsNil() { diff --git a/go/fory/slice.go b/go/fory/slice.go index 6d941b3bf6..56d4d4845f 100644 --- a/go/fory/slice.go +++ b/go/fory/slice.go @@ -124,6 +124,8 @@ type sliceSerializer struct { type_ reflect.Type elemSerializer Serializer referencable bool + elemBytes int64 + maxLength int64 } // newSliceSerializer creates a sliceSerializer for slices with concrete element types. @@ -144,10 +146,13 @@ func newSliceSerializer(type_ reflect.Type, elemSerializer Serializer, xlang boo reflect.Uint8, reflect.Float32, reflect.Float64: return nil, fmt.Errorf("sliceSerializer does not support primitive element type %v: use dedicated primitive slice serializer", type_) } + elemBytes := int64(elem.Size()) return &sliceSerializer{ type_: type_, elemSerializer: elemSerializer, referencable: isRefType(elem, xlang), + elemBytes: elemBytes, + maxLength: maxSliceLength(elemBytes), }, nil } @@ -308,6 +313,9 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } isArrayType := value.Type().Kind() == reflect.Array + if !isArrayType && !ctx.reserveSliceMemory(length, s.elemBytes, s.maxLength) { + return + } if length == 0 { if !isArrayType { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) diff --git a/go/fory/slice_dyn.go b/go/fory/slice_dyn.go index 907fcddd4f..d341e9455b 100644 --- a/go/fory/slice_dyn.go +++ b/go/fory/slice_dyn.go @@ -31,35 +31,43 @@ type sliceDynSerializer struct { elemType reflect.Type isInterfaceElem bool isPointerElem bool + elemBytes int64 + maxLength int64 } // newSliceDynSerializer creates a new sliceDynSerializer. // This serializer is ONLY for slices with interface or pointer to interface element types. // For other slice types, use sliceSerializer instead. -func newSliceDynSerializer(elemType reflect.Type) (sliceDynSerializer, error) { +func newSliceDynSerializer(elemType reflect.Type) (*sliceDynSerializer, error) { // Nil element type is allowed for fully dynamic slices (e.g., []any) if elemType == nil { - return sliceDynSerializer{ + elemBytes := containerSizeOf[any]() + return &sliceDynSerializer{ isInterfaceElem: true, + elemBytes: elemBytes, + maxLength: maxSliceLength(elemBytes), }, nil } // Validate element type is interface or pointer to interface isInterface := elemType.Kind() == reflect.Interface isPointerToInterface := elemType.Kind() == reflect.Ptr && elemType.Elem().Kind() == reflect.Interface if !isInterface && !isPointerToInterface { - return sliceDynSerializer{}, fmt.Errorf( + return nil, fmt.Errorf( "sliceDynSerializer only supports interface or pointer to interface element types, got %v; use sliceSerializer for other types", elemType) } - return sliceDynSerializer{ + elemBytes := int64(elemType.Size()) + return &sliceDynSerializer{ elemType: elemType, isInterfaceElem: isInterface, isPointerElem: isPointerToInterface, + elemBytes: elemBytes, + maxLength: maxSliceLength(elemBytes), }, nil } // mustNewSliceDynSerializer is like newSliceDynSerializer but panics on error. // Used for initialization code where the element type is known to be valid. -func mustNewSliceDynSerializer(elemType reflect.Type) sliceDynSerializer { +func mustNewSliceDynSerializer(elemType reflect.Type) *sliceDynSerializer { s, err := newSliceDynSerializer(elemType) if err != nil { panic(err) @@ -67,7 +75,7 @@ func mustNewSliceDynSerializer(elemType reflect.Type) sliceDynSerializer { return s } -func (s sliceDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { +func (s *sliceDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { done := writeSliceRefAndType(ctx, refMode, writeType, value, LIST) if done || ctx.HasError() { return @@ -75,7 +83,7 @@ func (s sliceDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType s.WriteData(ctx, value) } -func (s sliceDynSerializer) WriteData(ctx *WriteContext, value reflect.Value) { +func (s *sliceDynSerializer) WriteData(ctx *WriteContext, value reflect.Value) { buf := ctx.Buffer() // Get slice length and handle empty slice case length := value.Len() @@ -103,7 +111,7 @@ func (s sliceDynSerializer) WriteData(ctx *WriteContext, value reflect.Value) { // - Type consistency flags // - Element type information (if homogeneous) // Returns pointer to TypeInfo to avoid copy overhead. -func (s sliceDynSerializer) writeHeader(ctx *WriteContext, buf *ByteBuffer, value reflect.Value) (byte, *TypeInfo) { +func (s *sliceDynSerializer) writeHeader(ctx *WriteContext, buf *ByteBuffer, value reflect.Value) (byte, *TypeInfo) { collectFlag := CollectionDefaultFlag var elemTypeInfo *TypeInfo hasNull := false @@ -161,7 +169,7 @@ func (s sliceDynSerializer) writeHeader(ctx *WriteContext, buf *ByteBuffer, valu } // writeSameType efficiently serializes a slice where all elements share the same type -func (s sliceDynSerializer) writeSameType( +func (s *sliceDynSerializer) writeSameType( ctx *WriteContext, buf *ByteBuffer, value reflect.Value, typeInfo *TypeInfo, flag byte) { if typeInfo == nil { return @@ -194,7 +202,7 @@ func (s sliceDynSerializer) writeSameType( } // writeDifferentTypes handles serialization of slices with mixed element types -func (s sliceDynSerializer) writeDifferentTypes(ctx *WriteContext, buf *ByteBuffer, value reflect.Value, flag byte) { +func (s *sliceDynSerializer) writeDifferentTypes(ctx *WriteContext, buf *ByteBuffer, value reflect.Value, flag byte) { trackRefs := (flag & CollectionTrackingRef) != 0 hasNull := (flag & CollectionHasNull) != 0 @@ -246,7 +254,7 @@ func (s sliceDynSerializer) writeDifferentTypes(ctx *WriteContext, buf *ByteBuff } } -func (s sliceDynSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { +func (s *sliceDynSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { done, typeId := readSliceRefAndType(ctx, refMode, readType, value) if done || ctx.HasError() { return @@ -258,11 +266,11 @@ func (s sliceDynSerializer) Read(ctx *ReadContext, refMode RefMode, readType boo s.ReadData(ctx, value) } -func (s sliceDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { +func (s *sliceDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { s.readData(ctx, value, -1) } -func (s sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expectedLength int) { +func (s *sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expectedLength int) { buf := ctx.Buffer() ctxErr := ctx.Err() length := ctx.ReadCollectionLength() @@ -274,6 +282,10 @@ func (s sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expe ctx.SetError(DeserializationErrorf("array length %d does not match serialized length %d", expectedLength, length)) return } + allocatedByCaller := expectedLength >= 0 + if !allocatedByCaller && !ctx.reserveSliceMemory(length, s.elemBytes, s.maxLength) { + return + } if length == 0 { value.Set(reflect.MakeSlice(sliceType, 0, 0)) return @@ -305,7 +317,9 @@ func (s sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expe if !buf.CheckReadable(length, ctxErr) { return } - value.Set(reflect.MakeSlice(sliceType, length, length)) + if !allocatedByCaller { + value.Set(reflect.MakeSlice(sliceType, length, length)) + } ctx.RefResolver().Reference(value) s.readSameType(ctx, buf, value, elemType, elemSerializer, collectFlag, length) return @@ -313,18 +327,20 @@ func (s sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expe if !buf.CheckReadable(length, ctxErr) { return } - value.Set(reflect.MakeSlice(sliceType, length, length)) + if !allocatedByCaller { + value.Set(reflect.MakeSlice(sliceType, length, length)) + } ctx.RefResolver().Reference(value) s.readDifferentTypes(ctx, buf, value, collectFlag, length) } -func (s sliceDynSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { +func (s *sliceDynSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { // typeInfo is already read, don't read it again s.Read(ctx, refMode, false, false, value) } // readSameType handles deserialization of slices where all elements share the same type -func (s sliceDynSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, value reflect.Value, elemType reflect.Type, serializer Serializer, flag int8, length int) { +func (s *sliceDynSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, value reflect.Value, elemType reflect.Type, serializer Serializer, flag int8, length int) { trackRefs := (flag & CollectionTrackingRef) != 0 hasNull := (flag & CollectionHasNull) != 0 ctxErr := ctx.Err() @@ -402,7 +418,7 @@ func (s sliceDynSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, valu } // readDifferentTypes handles deserialization of slices with mixed element types -func (s sliceDynSerializer) readDifferentTypes( +func (s *sliceDynSerializer) readDifferentTypes( ctx *ReadContext, buf *ByteBuffer, value reflect.Value, flag int8, length int) { trackRefs := (flag & CollectionTrackingRef) != 0 hasNull := (flag & CollectionHasNull) != 0 @@ -464,7 +480,7 @@ func (s sliceDynSerializer) readDifferentTypes( // 1. Slice element type is pointer-to-interface and the deserialized type is not a pointer, OR // 2. Slice element type is interface and the deserialized type doesn't directly implement it // but the pointer type does (common case where interface has pointer receivers) -func (s sliceDynSerializer) wrapSerializerIfNeeded(elemType reflect.Type, serializer Serializer) (reflect.Type, Serializer) { +func (s *sliceDynSerializer) wrapSerializerIfNeeded(elemType reflect.Type, serializer Serializer) (reflect.Type, Serializer) { if elemType.Kind() == reflect.Ptr { return elemType, serializer } diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index 9b92691ac8..88e5d50b08 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -652,6 +652,9 @@ func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { return } ptr := (*[]string)(value.Addr().UnsafePointer()) + if !ctx.reserveSliceMemory(length, stringElementBytes, stringSliceMaxLength) { + return + } if length == 0 { *ptr = make([]string, 0) return diff --git a/go/fory/slice_primitive_list.go b/go/fory/slice_primitive_list.go index 0335b2a08e..e033fb4409 100644 --- a/go/fory/slice_primitive_list.go +++ b/go/fory/slice_primitive_list.go @@ -25,6 +25,18 @@ import ( type primitiveListSerializer struct { type_ reflect.Type elemTypeID TypeId + elemBytes int64 + maxLength int64 +} + +func newPrimitiveList(type_ reflect.Type, elemTypeID TypeId, elemType reflect.Type) primitiveListSerializer { + elemBytes := int64(elemType.Size()) + return primitiveListSerializer{ + type_: type_, + elemTypeID: elemTypeID, + elemBytes: elemBytes, + maxLength: maxSliceLength(elemBytes), + } } type compatiblePrimitiveListToArraySerializer struct { @@ -39,43 +51,43 @@ func newPrimitiveListSerializer(type_ reflect.Type, elemTypeID TypeId) (Serializ elemType := type_.Elem() switch elemType.Kind() { case reflect.Bool: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == BOOL + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == BOOL case reflect.Int8: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT8 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT8 case reflect.Uint8: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT8 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT8 case reflect.Int16: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT16 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT16 case reflect.Uint16: if elemType == float16Type { - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == FLOAT16 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == FLOAT16 } if elemType == bfloat16Type { - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == BFLOAT16 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == BFLOAT16 } - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT16 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT16 case reflect.Int32: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT32 || elemTypeID == VARINT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT32 || elemTypeID == VARINT32 case reflect.Uint32: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT32 || elemTypeID == VAR_UINT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT32 || elemTypeID == VAR_UINT32 case reflect.Int64: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT64 || elemTypeID == VARINT64 || elemTypeID == TAGGED_INT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT64 || elemTypeID == VARINT64 || elemTypeID == TAGGED_INT64 case reflect.Uint64: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT64 || elemTypeID == VAR_UINT64 || elemTypeID == TAGGED_UINT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT64 || elemTypeID == VAR_UINT64 || elemTypeID == TAGGED_UINT64 case reflect.Int: if reflect.TypeOf(int(0)).Size() == 8 { - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT64 || elemTypeID == VARINT64 || elemTypeID == TAGGED_INT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT64 || elemTypeID == VARINT64 || elemTypeID == TAGGED_INT64 } - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == INT32 || elemTypeID == VARINT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == INT32 || elemTypeID == VARINT32 case reflect.Uint: if reflect.TypeOf(uint(0)).Size() == 8 { - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT64 || elemTypeID == VAR_UINT64 || elemTypeID == TAGGED_UINT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT64 || elemTypeID == VAR_UINT64 || elemTypeID == TAGGED_UINT64 } - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == UINT32 || elemTypeID == VAR_UINT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == UINT32 || elemTypeID == VAR_UINT32 case reflect.Float32: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == FLOAT32 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == FLOAT32 case reflect.Float64: - return primitiveListSerializer{type_: type_, elemTypeID: elemTypeID}, elemTypeID == FLOAT64 + return newPrimitiveList(type_, elemTypeID, elemType), elemTypeID == FLOAT64 default: return nil, false } @@ -167,6 +179,9 @@ func (s primitiveListSerializer) ReadData(ctx *ReadContext, value reflect.Value) if ctx.HasError() { return } + if !ctx.reserveSliceMemory(length, s.elemBytes, s.maxLength) { + return + } if length == 0 { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) return @@ -228,6 +243,9 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val } if length == 0 { if value.Kind() == reflect.Slice { + if !ctx.reserveSliceMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { + return + } value.Set(reflect.MakeSlice(value.Type(), 0, 0)) } else if value.Len() != 0 { ctx.SetError(DeserializationErrorf("array-compatible list length %d does not match array length %d", length, value.Len())) @@ -266,6 +284,9 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val return } if value.Kind() == reflect.Slice { + if !ctx.reserveSliceMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { + return + } temp := reflect.New(value.Type()).Elem() s.listReader.readValues(buf, err, temp, length, false) if ctx.HasError() { diff --git a/go/fory/stream.go b/go/fory/stream.go index bb86689598..45111695e5 100644 --- a/go/fory/stream.go +++ b/go/fory/stream.go @@ -96,6 +96,13 @@ func (is *InputStream) Shrink() { func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = is.buffer + f.readCtx.initContainerMemoryBudget(0, true) + if f.readCtx.HasError() { + err := f.readCtx.TakeError() + f.readCtx.buffer = origBuffer + f.resetReadState() + return err + } defer func() { f.readCtx.buffer = origBuffer f.resetReadState() @@ -123,6 +130,10 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { defer f.resetReadState() // Always reset to enforce stateless semantics. f.readCtx.buffer.ResetWithReader(r, 0) + f.readCtx.initContainerMemoryBudget(0, true) + if f.readCtx.HasError() { + return f.readCtx.TakeError() + } readHeader(f.readCtx) if f.readCtx.HasError() { diff --git a/go/fory/tests/structs_fory_gen.go b/go/fory/tests/structs_fory_gen.go index 742135a8ba..ca639979a5 100644 --- a/go/fory/tests/structs_fory_gen.go +++ b/go/fory/tests/structs_fory_gen.go @@ -1,12 +1,13 @@ // Code generated by forygen. DO NOT EDIT. // source: structs.go -// generated at: 2026-06-12T06:41:26+08:00 +// generated at: 2026-06-26T15:00:42+08:00 package fory import ( "github.com/apache/fory/go/fory" "reflect" + "unsafe" ) func init() { @@ -189,6 +190,9 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.DynamicSlice = make([]any, 0) } else { @@ -217,6 +221,9 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.DynamicSlice = make([]any, 0) } else { @@ -662,6 +669,12 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if isXlang { // xlang mode: maps are not nullable, read directly without null flag mapLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(int))), int64(unsafe.Sizeof(*new(int)))) { + return ctx.TakeError() + } if mapLen == 0 { v.IntMap = make(map[int]int) } else { @@ -709,6 +722,9 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(int))), int64(unsafe.Sizeof(*new(int)))) { + return ctx.TakeError() + } if mapLen == 0 { v.IntMap = make(map[int]int) } else { @@ -755,6 +771,12 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if isXlang { // xlang mode: maps are not nullable, read directly without null flag mapLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(string))), int64(unsafe.Sizeof(*new(int)))) { + return ctx.TakeError() + } if mapLen == 0 { v.MixedMap = make(map[string]int) } else { @@ -802,6 +824,9 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(string))), int64(unsafe.Sizeof(*new(int)))) { + return ctx.TakeError() + } if mapLen == 0 { v.MixedMap = make(map[string]int) } else { @@ -848,6 +873,12 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if isXlang { // xlang mode: maps are not nullable, read directly without null flag mapLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(string))), int64(unsafe.Sizeof(*new(string)))) { + return ctx.TakeError() + } if mapLen == 0 { v.StringMap = make(map[string]string) } else { @@ -895,6 +926,9 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(string))), int64(unsafe.Sizeof(*new(string)))) { + return ctx.TakeError() + } if mapLen == 0 { v.StringMap = make(map[string]string) } else { @@ -1250,6 +1284,12 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if isXlang { // xlang mode: slices are not nullable, read directly without null flag sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.BoolSlice = make([]bool, 0) } else { @@ -1289,6 +1329,9 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.BoolSlice = make([]bool, 0) } else { @@ -1327,6 +1370,12 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if isXlang { // xlang mode: slices are not nullable, read directly without null flag sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.FloatSlice = make([]float64, 0) } else { @@ -1366,6 +1415,9 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.FloatSlice = make([]float64, 0) } else { @@ -1404,6 +1456,12 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if isXlang { // xlang mode: slices are not nullable, read directly without null flag sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.IntSlice = make([]int32, 0) } else { @@ -1443,6 +1501,9 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.IntSlice = make([]int32, 0) } else { @@ -1481,6 +1542,12 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if isXlang { // xlang mode: slices are not nullable, read directly without null flag sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.StringSlice = make([]string, 0) } else { @@ -1528,6 +1595,9 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } + if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { + return ctx.TakeError() + } if sliceLen == 0 { v.StringSlice = make([]string, 0) } else { diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index f0966b0596..ba4d3e5dc9 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -415,7 +415,7 @@ func (r *TypeResolver) initialize() { {stringPtrType, STRING, ptrToStringSerializer{}}, // Register interface types first so typeIDToTypeInfo maps to generic types // that can hold any element type when deserializing into any - {interfaceSliceType, LIST, sliceDynSerializer{}}, + {interfaceSliceType, LIST, mustNewSliceDynSerializer(interfaceType)}, {interfaceMapType, MAP, mapSerializer{type_: interfaceMapType, keyReferencable: true, valueReferencable: true}}, // stringSliceType uses dedicated stringSliceSerializer for optimized serialization // This ensures CollectionIsDeclElementType is set for Java compatibility @@ -1779,7 +1779,7 @@ func (r *TypeResolver) createSerializer(type_ reflect.Type, mapInStruct bool) (s } // For dynamic types, use dynamic slice serializer if isDynamicType(elem) { - return sliceDynSerializer{}, nil + return newSliceDynSerializer(elem) } else { elemSerializer, err := r.getSerializerByType(type_.Elem(), false) if err != nil { diff --git a/java/fory-core/src/main/java/org/apache/fory/Fory.java b/java/fory-core/src/main/java/org/apache/fory/Fory.java index a7916400ae..5b2bfef3ff 100644 --- a/java/fory-core/src/main/java/org/apache/fory/Fory.java +++ b/java/fory-core/src/main/java/org/apache/fory/Fory.java @@ -425,12 +425,17 @@ public T deserialize(byte[] bytes, Class type) { @Override public T deserialize(MemoryBuffer buffer, Class type) { + return deserialize(buffer, type, false); + } + + private T deserialize(MemoryBuffer buffer, Class type, boolean unknownLengthInput) { ensureRegistrationFinished(); + int rootInputBytes = buffer.remaining(); byte bitmap = buffer.readByte(); if (bitmap != headerBitmap) { checkHeaderBitmapWithoutOutOfBand(bitmap); } - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, rootInputBytes, unknownLengthInput); try { try { jitContext.lock(); @@ -451,7 +456,7 @@ public T deserialize(MemoryBuffer buffer, Class type) { @Override public T deserialize(ForyInputStream inputStream, Class type) { try { - return deserialize(inputStream.getBuffer(), type); + return deserialize(inputStream.getBuffer(), type, true); } finally { inputStream.shrinkBuffer(); } @@ -459,7 +464,7 @@ public T deserialize(ForyInputStream inputStream, Class type) { @Override public T deserialize(ForyReadableChannel channel, Class type) { - return deserialize(channel.getBuffer(), type); + return deserialize(channel.getBuffer(), type, true); } @Override @@ -487,7 +492,13 @@ public Object deserialize(MemoryBuffer buffer) { */ @Override public Object deserialize(MemoryBuffer buffer, Iterable outOfBandBuffers) { + return deserialize(buffer, outOfBandBuffers, false); + } + + private Object deserialize( + MemoryBuffer buffer, Iterable outOfBandBuffers, boolean unknownLengthInput) { ensureRegistrationFinished(); + int rootInputBytes = buffer.remaining(); byte bitmap = buffer.readByte(); boolean peerOutOfBandEnabled = false; if (bitmap != headerBitmap) { @@ -505,7 +516,11 @@ public Object deserialize(MemoryBuffer buffer, Iterable outOfBandB + "produced with bufferCallback null."); } readContext.prepare( - buffer, peerOutOfBandEnabled ? outOfBandBuffers : null, peerOutOfBandEnabled); + buffer, + peerOutOfBandEnabled ? outOfBandBuffers : null, + peerOutOfBandEnabled, + rootInputBytes, + unknownLengthInput); try { try { jitContext.lock(); @@ -532,7 +547,7 @@ public Object deserialize(ForyInputStream inputStream) { public Object deserialize(ForyInputStream inputStream, Iterable outOfBandBuffers) { try { MemoryBuffer buf = inputStream.getBuffer(); - return deserialize(buf, outOfBandBuffers); + return deserialize(buf, outOfBandBuffers, true); } finally { inputStream.shrinkBuffer(); } @@ -546,7 +561,7 @@ public Object deserialize(ForyReadableChannel channel) { @Override public Object deserialize(ForyReadableChannel channel, Iterable outOfBandBuffers) { MemoryBuffer buf = channel.getBuffer(); - return deserialize(buf, outOfBandBuffers); + return deserialize(buf, outOfBandBuffers, true); } @SuppressWarnings("unchecked") diff --git a/java/fory-core/src/main/java/org/apache/fory/config/Config.java b/java/fory-core/src/main/java/org/apache/fory/config/Config.java index 2b8db3ec66..b96e4aa83d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/Config.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/Config.java @@ -68,6 +68,7 @@ public class Config implements Serializable { private final int maxTypeMetaBytes; private final int maxSchemaVersionsPerType; private final int maxAverageSchemaVersionsPerType; + private final long maxContainerMemoryBytes; private final float mapRefLoadFactor; private final boolean forVirtualThread; @@ -114,6 +115,7 @@ public Config(ForyBuilder builder) { maxTypeMetaBytes = builder.maxTypeMetaBytes; maxSchemaVersionsPerType = builder.maxSchemaVersionsPerType; maxAverageSchemaVersionsPerType = builder.maxAverageSchemaVersionsPerType; + maxContainerMemoryBytes = builder.maxContainerMemoryBytes; mapRefLoadFactor = builder.mapRefLoadFactor; forVirtualThread = builder.forVirtualThread; } @@ -320,6 +322,11 @@ public int maxAverageSchemaVersionsPerType() { return maxAverageSchemaVersionsPerType; } + /** Returns the root-operation estimated container memory limit in bytes, or -1 for auto. */ + public long maxContainerMemoryBytes() { + return maxContainerMemoryBytes; + } + /** Returns loadFactor of MacRef's writtenObjects. */ public float mapRefLoadFactor() { return mapRefLoadFactor; @@ -368,6 +375,7 @@ public boolean equals(Object o) { && maxTypeMetaBytes == config.maxTypeMetaBytes && maxSchemaVersionsPerType == config.maxSchemaVersionsPerType && maxAverageSchemaVersionsPerType == config.maxAverageSchemaVersionsPerType + && maxContainerMemoryBytes == config.maxContainerMemoryBytes && Objects.equals(defaultJDKStreamSerializerType, config.defaultJDKStreamSerializerType) && longEncoding == config.longEncoding && forVirtualThread == config.forVirtualThread; @@ -403,6 +411,7 @@ public int hashCode() { maxTypeMetaBytes, maxSchemaVersionsPerType, maxAverageSchemaVersionsPerType, + maxContainerMemoryBytes, metaShareEnabled, scopedMetaShareEnabled, metaCompressor, diff --git a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java index 48d9dcb433..93d4943940 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java @@ -103,6 +103,7 @@ public final class ForyBuilder { int maxTypeMetaBytes = 4096; int maxSchemaVersionsPerType = 10; int maxAverageSchemaVersionsPerType = 3; + long maxContainerMemoryBytes = -1; float mapRefLoadFactor = 0.51f; boolean forVirtualThread = false; TypeChecker typeChecker; @@ -571,6 +572,22 @@ public ForyBuilder withMaxAverageSchemaVersionsPerType(int maxAverageSchemaVersi return this; } + /** + * Sets the maximum estimated container-owned memory accepted during one root deserialization. + * + *

The default is {@code -1}, which derives an automatic per-root budget from the input shape. + * Positive values are explicit byte limits. Other values are invalid. + */ + public ForyBuilder withMaxContainerMemoryBytes(long maxContainerMemoryBytes) { + Preconditions.checkArgument( + maxContainerMemoryBytes == -1 || maxContainerMemoryBytes > 0, + "maxContainerMemoryBytes must be positive or -1 for auto but got %s", + maxContainerMemoryBytes); + this.maxContainerMemoryBytes = maxContainerMemoryBytes; + recordAction(b -> b.withMaxContainerMemoryBytes(maxContainerMemoryBytes)); + return this; + } + /** Set loadFactor of MapRefResolver writtenObjects. Default value is 0.51 */ public ForyBuilder withMapRefLoadFactor(float loadFactor) { Preconditions.checkArgument( diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index b1000161b3..162ccdd9f5 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -51,6 +51,15 @@ */ @SuppressWarnings({"rawtypes", "unchecked"}) public final class ReadContext { + private static final long KNOWN_ROOT_BUDGET_MULTIPLIER = 8L; + private static final long KNOWN_ROOT_BUDGET_SLACK_BYTES = 64L * 1024; + private static final long STREAM_ROOT_BUDGET_BYTES = 128L * 1024 * 1024; + private static final long COLLECTION_OBJECT_BYTES = 24L; + private static final long MAP_OBJECT_BYTES = 48L; + private static final long ARRAY_HEADER_BYTES = 16L; + private static final long MAP_ENTRY_BYTES = 32L; + private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + private final Config config; private final Generics generics; private final TypeResolver typeResolver; @@ -63,6 +72,7 @@ public final class ReadContext { private final boolean compressInt; private final Int64Encoding longEncoding; private final int maxDepth; + private final long maxContainerMemoryBytes; private final boolean scopedMetaShareEnabled; private final boolean forVirtualThread; private final IdentityHashMap contextObjects = new IdentityHashMap<>(); @@ -71,6 +81,8 @@ public final class ReadContext { private MetaReadContext metaReadContext; private boolean peerOutOfBandEnabled; private int depth; + private long containerMemoryLimitBytes; + private long remainingContainerMemoryBytes; /** * Creates read-side runtime state for one {@code Fory} instance. @@ -96,6 +108,7 @@ public ReadContext( compressInt = config.compressInt(); longEncoding = config.longEncoding(); maxDepth = config.maxDepth(); + maxContainerMemoryBytes = config.maxContainerMemoryBytes(); forVirtualThread = config.forVirtualThread(); scopedMetaShareEnabled = config.isScopedMetaShareEnabled(); if (scopedMetaShareEnabled) { @@ -108,10 +121,32 @@ public ReadContext( * flag for one operation. */ public void prepare( - MemoryBuffer buffer, Iterable outOfBandBuffers, boolean peerOutOfBandEnabled) { + MemoryBuffer buffer, + Iterable outOfBandBuffers, + boolean peerOutOfBandEnabled, + int rootInputBytes, + boolean unknownLengthInput) { this.buffer = buffer; this.peerOutOfBandEnabled = peerOutOfBandEnabled; this.outOfBandBuffers = outOfBandBuffers == null ? null : outOfBandBuffers.iterator(); + initContainerMemoryBudget(rootInputBytes, unknownLengthInput); + } + + private void initContainerMemoryBudget(int rootInputBytes, boolean unknownLengthInput) { + long limit = maxContainerMemoryBytes; + if (limit <= 0) { + if (unknownLengthInput) { + limit = STREAM_ROOT_BUDGET_BYTES; + } else { + if (rootInputBytes < 0) { + throw new IllegalArgumentException( + "Root input size must be non-negative: " + rootInputBytes); + } + limit = rootInputBytes * KNOWN_ROOT_BUDGET_MULTIPLIER + KNOWN_ROOT_BUDGET_SLACK_BYTES; + } + } + containerMemoryLimitBytes = limit; + remainingContainerMemoryBytes = limit; } /** @@ -307,6 +342,8 @@ public void reset() { outOfBandBuffers = null; peerOutOfBandEnabled = false; depth = 0; + containerMemoryLimitBytes = 0; + remainingContainerMemoryBytes = 0; } /** Returns the immutable runtime configuration for this context. */ @@ -314,6 +351,52 @@ public Config getConfig() { return config; } + public void reserveCollectionMemory(int numElements) { + reserveContainerMemory(COLLECTION_OBJECT_BYTES + (long) numElements * REFERENCE_BYTES); + } + + public void reserveCollectionCapacity(int numElements, int capacity) { + reserveContainerMemory((long) (capacity - numElements) * REFERENCE_BYTES); + } + + public void reserveMapMemory(int numElements) { + long entries = (long) numElements; + long tableBytes = entries * 2 * REFERENCE_BYTES; + long entryBytes = entries * (MAP_ENTRY_BYTES + 3L * REFERENCE_BYTES); + reserveContainerMemory(MAP_OBJECT_BYTES + tableBytes + entryBytes); + } + + public void reserveObjectArrayMemory(int numElements) { + reserveContainerMemory(ARRAY_HEADER_BYTES + (long) numElements * REFERENCE_BYTES); + } + + public void reserveContainerMemory(long bytes) { + if (bytes < 0) { + throwNegativeContainerMemory(bytes); + } + long remaining = remainingContainerMemoryBytes; + if (bytes > remaining) { + throwContainerMemoryExceeded(bytes, remaining); + } + remainingContainerMemoryBytes = remaining - bytes; + } + + private void throwNegativeContainerMemory(long bytes) { + throw new InsecureException( + "Estimated container memory must be non-negative, but got " + bytes + " bytes."); + } + + private void throwContainerMemoryExceeded(long bytes, long remaining) { + throw new InsecureException( + "Estimated container memory request " + + bytes + + " bytes exceeds maxContainerMemoryBytes remaining budget " + + remaining + + " bytes out of effective limit " + + containerMemoryLimitBytes + + " bytes. If the data is trusted, increase ForyBuilder#withMaxContainerMemoryBytes."); + } + /** Returns the generics stack shared by the owning runtime. */ public Generics getGenerics() { return generics; diff --git a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java index c0f002ca98..9e82d2ba3f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java +++ b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java @@ -77,6 +77,7 @@ public final class MemoryBuffer { private static final int LONG_ARRAY_OFFSET; private static final int FLOAT_ARRAY_OFFSET; private static final int DOUBLE_ARRAY_OFFSET; + private static final int OBJECT_ARRAY_INDEX_SCALE; // GraalVM native-image recognizes arrayBaseOffset only when the call stores directly into the // target static field. Keep these assignments in this shape so native images recompute heap array @@ -91,6 +92,7 @@ public final class MemoryBuffer { LONG_ARRAY_OFFSET = 0; FLOAT_ARRAY_OFFSET = 0; DOUBLE_ARRAY_OFFSET = 0; + OBJECT_ARRAY_INDEX_SCALE = 4; } else { BOOLEAN_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(boolean[].class); BYTE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); @@ -100,6 +102,7 @@ public final class MemoryBuffer { LONG_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(long[].class); FLOAT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(float[].class); DOUBLE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(double[].class); + OBJECT_ARRAY_INDEX_SCALE = UNSAFE.arrayIndexScale(Object[].class); } } @@ -4185,6 +4188,10 @@ public static MemoryBuffer fromDirectByteBuffer( return new MemoryBuffer(offHeapAddress, size, buffer, streamReader); } + public static int objectArrayIndexScale() { + return OBJECT_ARRAY_INDEX_SCALE > 0 ? OBJECT_ARRAY_INDEX_SCALE : 4; + } + /** * Create a heap buffer of specified initial size. The buffer will grow automatically if not * enough. diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java index 9fe08fdfb5..52237e7082 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java @@ -51,6 +51,19 @@ private static void throwInvalidObjectArraySize(int size) { throw new DeserializationException("Object array size must be non-negative: " + size); } + private static int readObjectArraySize(ReadContext readContext) { + MemoryBuffer buffer = readContext.getBuffer(); + int numElements = buffer.readVarUInt32Small7(); + // Keep this as direct primitive branches. Object-array reads allocate immediately; using + // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. + if (numElements < 0) { + throwInvalidObjectArraySize(numElements); + } + readContext.reserveObjectArrayMemory(numElements); + buffer.checkReadableBytes(numElements); + return numElements; + } + /** * Returns the object-array serializer for {@code cls}. * @@ -128,14 +141,7 @@ public Object[] copy(CopyContext copyContext, Object[] originArray) { @Override public Object[] read(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); - // Keep this as direct primitive branches. Object-array reads allocate immediately; using - // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. - if (numElements < 0) { - throwInvalidObjectArraySize(numElements); - } - buffer.checkReadableBytes(numElements); + int numElements = readObjectArraySize(readContext); Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { @@ -213,14 +219,7 @@ public Object[] copy(CopyContext copyContext, Object[] originArray) { @Override public Object[] read(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); - // Keep this as direct primitive branches. Object-array reads allocate immediately; using - // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. - if (numElements < 0) { - throwInvalidObjectArraySize(numElements); - } - buffer.checkReadableBytes(numElements); + int numElements = readObjectArraySize(readContext); Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { @@ -654,14 +653,7 @@ public void write(WriteContext writeContext, Object[] value) { @Override public Object[] read(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); - int numElements = buffer.readVarUInt32Small7(); - // Keep this as direct primitive branches. Object-array reads allocate immediately; using - // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. - if (numElements < 0) { - throwInvalidObjectArraySize(numElements); - } - buffer.checkReadableBytes(numElements); + int numElements = readObjectArraySize(readContext); Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java index 35eeca550a..b5853d433b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java @@ -343,18 +343,18 @@ private static Object readNotNull( if (array == null) { return null; } - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } if (readMode == READ_LIST_TO_LIST) { return readListBodyAsListTarget(readContext, arrayTypeId, elementTypeId, targetType); } if (readMode == READ_ARRAY_TO_LIST) { Object array = readDenseArrayBody(readContext, arrayTypeId); - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } if (readMode == READ_ARRAY_TO_ARRAY) { Object array = readDenseArrayBody(readContext, arrayTypeId); - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } throw new IllegalStateException("Unexpected compatible read mode " + readMode); } @@ -621,7 +621,7 @@ private static Object readListBodyAsListTarget( validateElementCount(numElements); if (numElements == 0) { Object array = readListPrimitiveElements(buffer, 0, arrayTypeId, elementTypeId, false); - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } int flags = buffer.readByte(); boolean hasNull = (flags & CollectionFlags.HAS_NULL) == CollectionFlags.HAS_NULL; @@ -654,11 +654,11 @@ private static Object readListBodyAsListTarget( throw new DeserializationException( "Cannot read null peer list element into local list field"); } - return readNullableListBoxedElements(buffer, numElements, arrayTypeId, elementTypeId); + return readNullableListBoxedElements(readContext, numElements, arrayTypeId, elementTypeId); } Object array = readListPrimitiveElements(buffer, numElements, arrayTypeId, elementTypeId, false); - return materializeTarget(array, arrayTypeId, targetType); + return materializeTarget(readContext, array, arrayTypeId, targetType); } private static Object readDenseArrayBody(ReadContext readContext, int arrayTypeId) { @@ -976,8 +976,11 @@ private static void readNonNullListElement(MemoryBuffer buffer) { } private static List readNullableListBoxedElements( - MemoryBuffer buffer, int numElements, int arrayTypeId, int elementTypeId) { - buffer.checkReadableBytes(minReadablePrimitiveListBytes(numElements, elementTypeId, true)); + ReadContext readContext, int numElements, int arrayTypeId, int elementTypeId) { + MemoryBuffer buffer = readContext.getBuffer(); + int bodyBytes = minReadablePrimitiveListBytes(numElements, elementTypeId, true); + readContext.reserveCollectionMemory(numElements); + buffer.checkReadableBytes(bodyBytes); ArrayList values = new ArrayList<>(numElements); for (int i = 0; i < numElements; i++) { byte headFlag = buffer.readByte(); @@ -1043,7 +1046,8 @@ private static Object readBoxedListElement( } } - private static Object materializeTarget(Object array, int arrayTypeId, Class targetType) { + private static Object materializeTarget( + ReadContext readContext, Object array, int arrayTypeId, Class targetType) { if (targetType.isArray()) { return array; } @@ -1058,7 +1062,7 @@ private static Object materializeTarget(Object array, int arrayTypeId, Class return primitiveList; } if (targetType.isAssignableFrom(ArrayList.class)) { - return materializeBoxedList(array, arrayTypeId); + return materializeBoxedList(readContext, array, arrayTypeId); } throw new DeserializationException("Unsupported compatible list/array target " + targetType); } @@ -1172,8 +1176,10 @@ private static boolean canMaterializePrimitiveListTarget(Class targetType, in } } - private static List materializeBoxedList(Object array, int arrayTypeId) { + private static List materializeBoxedList( + ReadContext readContext, Object array, int arrayTypeId) { int size = java.lang.reflect.Array.getLength(array); + readContext.reserveCollectionMemory(size); ArrayList list = new ArrayList<>(size); switch (arrayTypeId) { case Types.BOOL_ARRAY: diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java index f7840349ef..f58d56d08b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java @@ -249,7 +249,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public T newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); @@ -295,7 +295,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public T newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); @@ -403,7 +403,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java index 3915b5d888..b6151f45a9 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java @@ -461,7 +461,7 @@ public T read(ReadContext readContext) { */ public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - numElements = readCollectionSize(buffer); + numElements = readCollectionSize(readContext); if (AndroidSupport.IS_ANDROID) { try { Constructor constructor = type.getDeclaredConstructor(); @@ -560,9 +560,11 @@ protected void setNumElements(int numElements) { this.numElements = numElements; } - protected final int readCollectionSize(MemoryBuffer buffer) { + protected final int readCollectionSize(ReadContext readContext) { + MemoryBuffer buffer = readContext.getBuffer(); int numElements = buffer.readVarUInt32Small7(); checkCollectionSize(numElements); + readContext.reserveCollectionMemory(numElements); buffer.checkReadableBytes(numElements); return numElements; } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java index a81c38298f..2456377485 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java @@ -127,7 +127,7 @@ public ArrayListSerializer(TypeResolver typeResolver) { @Override public ArrayList newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); ArrayList arrayList = new ArrayList(numElements); readContext.reference(arrayList); @@ -189,7 +189,7 @@ public List read(ReadContext readContext) { @Override public ArrayList newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); ArrayList arrayList = new ArrayList(numElements); readContext.reference(arrayList); @@ -205,7 +205,7 @@ public HashSetSerializer(TypeResolver typeResolver) { @Override public HashSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); HashSet hashSet = new HashSet(numElements); readContext.reference(hashSet); @@ -221,7 +221,7 @@ public LinkedHashSetSerializer(TypeResolver typeResolver) { @Override public LinkedHashSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); LinkedHashSet hashSet = new LinkedHashSet(numElements); readContext.reference(hashSet); @@ -270,7 +270,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public T newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); T collection; Comparator comparator = config.isXlang() ? null : (Comparator) readContext.readRef(); @@ -335,7 +335,7 @@ public void write(WriteContext writeContext, List value) { @Override public List read(ReadContext readContext) { if (config.isXlang()) { - int numElements = readCollectionSize(readContext.getBuffer()); + int numElements = readCollectionSize(readContext); if (numElements != 0) { throw new DeserializationException( "Empty list body must have zero elements but got " + numElements); @@ -356,7 +356,7 @@ public CopyOnWriteArrayListSerializer( @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -390,7 +390,7 @@ public CopyOnWriteArraySetSerializer( @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -542,7 +542,7 @@ public CollectionSnapshot onCollectionWrite( @Override public ConcurrentSkipListSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); if (config.isXlang()) { ConcurrentSkipListSet skipListSet = new ConcurrentSkipListSet(); @@ -726,7 +726,7 @@ public VectorSerializer(TypeResolver typeResolver, Class cls) { @Override public Vector newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); Vector vector = new Vector<>(numElements); readContext.reference(vector); @@ -743,7 +743,7 @@ public ArrayDequeSerializer(TypeResolver typeResolver, Class cls) { @Override public ArrayDeque newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); ArrayDeque deque = new ArrayDeque(numElements); readContext.reference(deque); @@ -786,9 +786,9 @@ public void write(WriteContext writeContext, EnumSet object) { public EnumSet read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); Class elemClass = typeResolver.readTypeInfo(readContext).getType(); + int length = readCollectionSize(readContext); EnumSet object = EnumSet.noneOf(elemClass); Serializer elemSerializer = typeResolver.getSerializer(elemClass); - int length = readCollectionSize(buffer); for (int i = 0; i < length; i++) { object.add(elemSerializer.read(readContext)); } @@ -863,7 +863,7 @@ public Collection newCollection(CopyContext copyContext, Collection collection) public PriorityQueue newCollection(ReadContext readContext) { assert !config.isXlang(); MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); PriorityQueue queue = new PriorityQueue(comparator); @@ -923,10 +923,11 @@ public CollectionSnapshot onCollectionWrite( @Override public ArrayBlockingQueue newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); checkBoundedQueueCapacity(numElements, capacity); + readContext.reserveCollectionCapacity(numElements, capacity); buffer.checkReadableBytes(capacity); ArrayBlockingQueue queue = new ArrayBlockingQueue<>(capacity); readContext.reference(queue); @@ -990,10 +991,12 @@ public CollectionSnapshot onCollectionWrite( @Override public LinkedBlockingQueue newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); checkBoundedQueueCapacity(numElements, capacity); + // LinkedBlockingQueue capacity is a logical bound, not preallocated backing storage. The + // current node storage is already charged by readCollectionSize(numElements). LinkedBlockingQueue queue = new LinkedBlockingQueue<>(capacity); readContext.reference(queue); return queue; @@ -1130,7 +1133,7 @@ public XlangListDefaultSerializer(TypeResolver typeResolver, Class cls) { @Override public List newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); ArrayList list = new ArrayList(numElements); readContext.reference(list); @@ -1146,7 +1149,7 @@ public XlangSetDefaultSerializer(TypeResolver typeResolver, Class cls) { @Override public Set newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); HashSet set = new HashSet(numElements); readContext.reference(set); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java index c28aa04561..a5aba71aaa 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java @@ -94,7 +94,7 @@ public ImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -127,7 +127,7 @@ public RegularImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); return new CollectionContainer(numElements); } @@ -161,7 +161,7 @@ public ImmutableSetSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -203,7 +203,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); return new SortedCollectionContainer(comparator, numElements); @@ -236,7 +236,7 @@ public GuavaMapSerializer(TypeResolver typeResolver, Class cls) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); return new MapContainer(numElements); } @@ -264,7 +264,7 @@ public T onMapRead(Map map) { @Override public T read(ReadContext readContext) { - int size = readMapSize(readContext.getBuffer()); + int size = readMapSize(readContext); Map map = new HashMap(); readElements(readContext, size, map); return xnewInstance(map); @@ -574,7 +574,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); return new SortedMapContainer<>(comparator, numElements); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java index cd69f2b6cf..7a9f9f017d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java @@ -125,7 +125,7 @@ public ImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); if (JdkVersion.MAJOR_VERSION > 8) { return new CollectionContainer<>(numElements); @@ -186,7 +186,7 @@ public ImmutableSetSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); if (JdkVersion.MAJOR_VERSION > 8) { return new CollectionContainer<>(numElements); @@ -247,7 +247,7 @@ public ImmutableMapSerializer(TypeResolver typeResolver, Class cls) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); if (JdkVersion.MAJOR_VERSION > 8) { return new JDKImmutableMapContainer(numElements); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java index 4f6f828d11..9ac93e96a2 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java @@ -895,7 +895,7 @@ public void onMapWriteFinish(Map map) {} */ public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - numElements = readMapSize(buffer); + numElements = readMapSize(readContext); if (AndroidSupport.IS_ANDROID) { try { Constructor constructor = type.getDeclaredConstructor(); @@ -964,12 +964,14 @@ public void setNumElements(int numElements) { this.numElements = numElements; } - protected final int readMapSize(MemoryBuffer buffer) { + protected final int readMapSize(ReadContext readContext) { + MemoryBuffer buffer = readContext.getBuffer(); int numElements = buffer.readVarUInt32Small7(); checkMapSize(numElements); if (numElements > Integer.MAX_VALUE / 2) { throwInvalidMapBodySize(numElements); } + readContext.reserveMapMemory(numElements); buffer.checkReadableBytes(numElements << 1); return numElements; } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java index 9b3825495a..91f9ba2d06 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java @@ -86,7 +86,7 @@ public HashMapSerializer(TypeResolver typeResolver) { @Override public HashMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); HashMap hashMap = new HashMap(numElements); readContext.reference(hashMap); @@ -107,7 +107,7 @@ public LinkedHashMapSerializer(TypeResolver typeResolver) { @Override public LinkedHashMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); LinkedHashMap hashMap = new LinkedHashMap(numElements); readContext.reference(hashMap); @@ -146,7 +146,7 @@ public LazyMapSerializer(TypeResolver typeResolver) { @Override public LazyMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); LazyMap map = new LazyMap(numElements); readContext.reference(map); @@ -200,7 +200,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - setNumElements(readMapSize(buffer)); + setNumElements(readMapSize(readContext)); T map; Comparator comparator = config.isXlang() ? null : (Comparator) readContext.readRef(); if (type == TreeMap.class) { @@ -322,7 +322,7 @@ public ConcurrentHashMapSerializer(TypeResolver typeResolver, Class keyType = typeResolver.readTypeInfo(readContext).getType(); EnumMap map = new EnumMap(keyType); readContext.reference(map); @@ -619,7 +619,7 @@ public Object onMapCopy(Map map) { public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(buffer); + int numElements = readMapSize(readContext); setNumElements(numElements); HashMap map = new HashMap<>(numElements); readContext.reference(map); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java index f11f4b79a3..c011c91277 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java @@ -158,7 +158,7 @@ public List read(ReadContext readContext) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(buffer); + int numElements = readCollectionSize(readContext); setNumElements(numElements); ArrayList list = new ArrayList(numElements); readContext.reference(list); diff --git a/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java b/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java index 52c4d0ce50..4c9347b65e 100644 --- a/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java +++ b/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java @@ -81,6 +81,7 @@ public final class MemoryBuffer { private static final int LONG_ARRAY_OFFSET = 0; private static final int FLOAT_ARRAY_OFFSET = 0; private static final int DOUBLE_ARRAY_OFFSET = 0; + private static final int OBJECT_ARRAY_INDEX_SCALE = 4; private static final VarHandle BYTE_ARRAY_CHAR = MethodHandles.byteArrayViewVarHandle(char[].class, NATIVE_ORDER); private static final VarHandle BYTE_ARRAY_SHORT = @@ -3924,6 +3925,10 @@ public static MemoryBuffer fromDirectByteBuffer( return new MemoryBuffer(offHeapAddress, size, buffer, streamReader); } + public static int objectArrayIndexScale() { + return OBJECT_ARRAY_INDEX_SCALE; + } + /** * Create a heap buffer of specified initial size. The buffer will grow automatically if not * enough. diff --git a/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java b/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java index 63e3ffcdc1..94c0e893b8 100644 --- a/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java +++ b/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java @@ -344,7 +344,7 @@ public static void withWriteContext( public static T withReadContext( Fory fory, MemoryBuffer buffer, Function action) { ReadContext context = (ReadContext) ReflectionUtils.getObjectFieldValue(fory, "readContext"); - context.prepare(buffer, null, false); + context.prepare(buffer, null, false, buffer.remaining(), false); try { return action.apply(context); } finally { diff --git a/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectInputTest.java b/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectInputTest.java index 71f73582c2..3b379c7b3a 100644 --- a/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectInputTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectInputTest.java @@ -53,7 +53,7 @@ public void testForyStructInput(boolean compressNumber) throws IOException { buffer.writeFloat32(4.1f); buffer.writeFloat64(4.2); new StringSerializer(fory.getConfig()).writeString(buffer, "abc"); - fory.getReadContext().prepare(buffer, null, false); + fory.getReadContext().prepare(buffer, null, false, buffer.remaining(), false); try (MemoryBufferObjectInput input = new MemoryBufferObjectInput(fory.getConfig(), fory.getReadContext())) { assertEquals(input.readByte(), 1); diff --git a/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectOutputTest.java b/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectOutputTest.java index a04a0e9f3c..f5b69b8b13 100644 --- a/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectOutputTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectOutputTest.java @@ -46,7 +46,7 @@ public void testForyStructOutput() throws IOException { output.writeChars("abc"); output.writeUTF("abc"); } - fory.getReadContext().prepare(buffer, null, false); + fory.getReadContext().prepare(buffer, null, false, buffer.remaining(), false); try (MemoryBufferObjectInput input = new MemoryBufferObjectInput(fory.getConfig(), fory.getReadContext())) { assertEquals(input.readByte(), 1); diff --git a/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java b/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java index 0b94da34cb..157b3950f5 100644 --- a/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java @@ -360,7 +360,7 @@ public void testRemoteTypeDefChecksTypeChecker() { ReadContext readContext = reader.getReadContext(); readContext.setMetaReadContext(new MetaReadContext()); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(256); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); buffer.writeVarUInt32(0); typeDef.writeTypeDef(buffer); buffer.readerIndex(0); @@ -473,7 +473,7 @@ public void testExactLocalEnumTypeDefBypassesLimit() { ReadContext readContext = fory.getReadContext(); readContext.setMetaReadContext(new MetaReadContext()); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(256); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); buffer.writeVarUInt32(0); exact.writeTypeDef(buffer); buffer.readerIndex(0); @@ -792,7 +792,7 @@ public void testWriteClassName() { } finally { fory.getWriteContext().reset(); } - fory.getReadContext().prepare(buffer, null, false); + fory.getReadContext().prepare(buffer, null, false, buffer.remaining(), false); try { Assert.assertSame(classResolver.readTypeInfo(fory.getReadContext()).getType(), getClass()); Assert.assertSame(classResolver.readTypeInfo(fory.getReadContext()).getType(), getClass()); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java index b0fb46f0f4..752eb77686 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java @@ -373,11 +373,15 @@ private static Object readPrimitiveArrayBody( MemoryBuffer control = MemoryBuffer.newHeapBuffer(1); control.writeBoolean(false); readContext.prepare( - control, Collections.singletonList(MemoryUtils.wrap(new byte[byteSize])), true); + control, + Collections.singletonList(MemoryUtils.wrap(new byte[byteSize])), + true, + control.remaining(), + false); } else { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(byteSize); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); } return fory.getSerializer(arrayType).read(readContext); } @@ -387,8 +391,8 @@ private static Object readTruncatedPrimitiveArrayBody( ReadContext readContext = fory.getReadContext(); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(byteSize); - readContext.prepare( - MemoryBuffer.fromByteArray(buffer.getBytes(0, buffer.writerIndex())), null, false); + MemoryBuffer truncated = MemoryBuffer.fromByteArray(buffer.getBytes(0, buffer.writerIndex())); + readContext.prepare(truncated, null, false, truncated.remaining(), false); return fory.getSerializer(arrayType).read(readContext); } @@ -396,7 +400,7 @@ private static Object readPrimitiveArrayRawBody(Fory fory, Class arrayType) { ReadContext readContext = fory.getReadContext(); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); writeNegativeDecodedVarUInt32(buffer); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); return fory.getSerializer(arrayType).read(readContext); } @@ -404,7 +408,7 @@ private static Object readObjectArrayBody(Fory fory, Class arrayType, int num ReadContext readContext = fory.getReadContext(); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(numElements); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); return fory.getSerializer(arrayType).read(readContext); } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java index cd3d31dcac..492bee83f8 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java @@ -33,6 +33,7 @@ import org.apache.fory.ForyTestBase; import org.apache.fory.TestUtils; import org.apache.fory.config.Language; +import org.apache.fory.context.ReadContext; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.MemoryUtils; import org.apache.fory.serializer.collection.UnmodifiableSerializersTest; @@ -138,14 +139,21 @@ public void testWriteCompatibleBasic() throws Exception { public void testNullableListBodyBounds() throws Exception { Method method = CompatibleCollectionArrayReader.class.getDeclaredMethod( - "readNullableListBoxedElements", MemoryBuffer.class, int.class, int.class, int.class); + "readNullableListBoxedElements", ReadContext.class, int.class, int.class, int.class); method.setAccessible(true); MemoryBuffer buffer = MemoryUtils.buffer(0); - InvocationTargetException exception = - Assert.expectThrows( - InvocationTargetException.class, - () -> method.invoke(null, buffer, 1024, Types.INT32_ARRAY, Types.INT32)); - Assert.assertTrue(exception.getCause() instanceof IndexOutOfBoundsException); + Fory fory = builder().build(); + ReadContext readContext = fory.getReadContext(); + readContext.prepare(buffer, null, false, buffer.remaining(), false); + try { + InvocationTargetException exception = + Assert.expectThrows( + InvocationTargetException.class, + () -> method.invoke(null, readContext, 1024, Types.INT32_ARRAY, Types.INT32)); + Assert.assertTrue(exception.getCause() instanceof IndexOutOfBoundsException); + } finally { + readContext.reset(); + } } @Test diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java new file mode 100644 index 0000000000..09b73c25d8 --- /dev/null +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.serializer; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; + +import java.io.ByteArrayInputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import org.apache.fory.Fory; +import org.apache.fory.ForyTestBase; +import org.apache.fory.collection.Int32List; +import org.apache.fory.context.ReadContext; +import org.apache.fory.exception.DeserializationException; +import org.apache.fory.exception.InsecureException; +import org.apache.fory.io.ForyInputStream; +import org.apache.fory.memory.MemoryBuffer; +import org.testng.annotations.Test; + +public class ContainerMemoryBudgetTest extends ForyTestBase { + private static final long KNOWN_ROOT_MULTIPLIER = 8L; + private static final long KNOWN_ROOT_SLACK_BYTES = 64L * 1024; + private static final long STREAM_ROOT_BYTES = 128L * 1024 * 1024; + private static final long COLLECTION_OBJECT_BYTES = 24L; + private static final long MAP_OBJECT_BYTES = 48L; + private static final long ARRAY_HEADER_BYTES = 16L; + private static final long MAP_ENTRY_BYTES = 32L; + private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + + @Test + public void testConfigValidation() { + assertEquals(newFory(-1).getConfig().maxContainerMemoryBytes(), -1); + assertEquals(newFory(123).getConfig().maxContainerMemoryBytes(), 123); + assertThrows(IllegalArgumentException.class, () -> builder().withMaxContainerMemoryBytes(0)); + assertThrows(IllegalArgumentException.class, () -> builder().withMaxContainerMemoryBytes(-2)); + } + + @Test + public void testKnownAutoBudget() { + Fory fory = newFory(-1); + ReadContext readContext = prepareContext(fory, 17, false); + try { + long budget = knownAutoBytes(17); + readContext.reserveContainerMemory(budget); + assertThrows(InsecureException.class, () -> readContext.reserveContainerMemory(1)); + } finally { + readContext.reset(); + } + } + + @Test + public void testStreamAutoBudget() { + Fory fory = newFory(-1); + ReadContext readContext = prepareContext(fory, 17, true); + try { + readContext.reserveContainerMemory(STREAM_ROOT_BYTES); + assertThrows(InsecureException.class, () -> readContext.reserveContainerMemory(1)); + } finally { + readContext.reset(); + } + + StreamPayload payload = findStreamPayload(); + assertThrows(InsecureException.class, () -> newFory(-1).deserialize(payload.bytes)); + Object copy = + newFory(-1).deserialize(new ForyInputStream(new ByteArrayInputStream(payload.bytes), 1)); + assertEquals(copy, payload.value); + } + + @Test + public void testExplicitBudgetWins() { + Fory fory = newFory(7); + ReadContext readContext = prepareContext(fory, 1024 * 1024, false); + try { + readContext.reserveContainerMemory(7); + assertThrows(InsecureException.class, () -> readContext.reserveContainerMemory(1)); + } finally { + readContext.reset(); + } + } + + @Test + public void testNestedEmptyFixedCost() { + List value = emptyLists(1); + byte[] bytes = newFory(-1).serialize(value); + + assertThrows(InsecureException.class, () -> newFory(collectionBytes(1)).deserialize(bytes)); + assertEquals(newFory(collectionBytes(1) + collectionBytes(0)).deserialize(bytes), value); + } + + @Test + public void testSiblingBudgetIsCumulative() { + List value = nullLists(2, 64); + byte[] bytes = newFory(-1).serialize(value); + long firstChildOnly = collectionBytes(2) + collectionBytes(64); + + assertThrows(InsecureException.class, () -> newFory(firstChildOnly).deserialize(bytes)); + assertEquals(newFory(firstChildOnly + collectionBytes(64)).deserialize(bytes), value); + } + + @Test + public void testMapBudgetAndOverflow() { + Fory fory = newFory(mapBytes(1) - 1); + ReadContext readContext = prepareContext(fory, 8, false); + try { + assertThrows(InsecureException.class, () -> readContext.reserveMapMemory(1)); + } finally { + readContext.reset(); + } + + Fory exactFory = newFory(mapBytes(1)); + ReadContext exactContext = prepareContext(exactFory, 8, false); + try { + exactContext.reserveMapMemory(1); + assertThrows(InsecureException.class, () -> exactContext.reserveContainerMemory(1)); + } finally { + exactContext.reset(); + } + + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(8); + buffer.writeVarUInt32Small7(Integer.MAX_VALUE); + buffer = trimBuffer(buffer); + Fory reader = newFory(STREAM_ROOT_BYTES); + ReadContext mapContext = reader.getReadContext(); + mapContext.prepare(buffer, null, false, buffer.remaining(), false); + try { + assertThrows( + DeserializationException.class, + () -> reader.getSerializer(HashMap.class).read(mapContext)); + } finally { + mapContext.reset(); + } + } + + @Test + public void testObjectArrayBudget() { + Fory lowFory = newFory(objectArrayBytes(0) - 1); + ReadContext lowContext = lowFory.getReadContext(); + MemoryBuffer lowBuffer = objectArraySizeBuffer(0); + lowContext.prepare(lowBuffer, null, false, lowBuffer.remaining(), false); + try { + assertThrows( + InsecureException.class, () -> lowFory.getSerializer(Object[].class).read(lowContext)); + } finally { + lowContext.reset(); + } + + Fory exactFory = newFory(objectArrayBytes(0)); + ReadContext exactContext = exactFory.getReadContext(); + MemoryBuffer exactBuffer = objectArraySizeBuffer(0); + exactContext.prepare(exactBuffer, null, false, exactBuffer.remaining(), false); + try { + Object[] array = (Object[]) exactFory.getSerializer(Object[].class).read(exactContext); + assertEquals(array.length, 0); + } finally { + exactContext.reset(); + } + + Fory slotFory = newFory(objectArrayBytes(2) - 1); + ReadContext slotContext = slotFory.getReadContext(); + MemoryBuffer slotBuffer = objectArraySizeBuffer(2); + slotContext.prepare(slotBuffer, null, false, slotBuffer.remaining(), false); + try { + assertThrows( + InsecureException.class, () -> slotFory.getSerializer(Object[].class).read(slotContext)); + } finally { + slotContext.reset(); + } + } + + @Test + public void testScalarOwnersSkipBudget() { + Fory fory = newFory(1); + assertEquals(fory.deserialize(fory.serialize("container budget")), "container budget"); + + byte[] bytes = new byte[] {1, 2, 3}; + assertTrue(Arrays.equals((byte[]) fory.deserialize(fory.serialize(bytes)), bytes)); + + int[] ints = new int[] {4, 5, 6}; + assertTrue(Arrays.equals((int[]) fory.deserialize(fory.serialize(ints)), ints)); + + Int32List denseList = new Int32List(new int[] {7, 8, 9}); + assertEquals(fory.deserialize(fory.serialize(denseList)), denseList); + } + + @Test + public void testTruncatedCollectionStillFails() { + Fory fory = newFory(collectionBytes(3)); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(8); + buffer.writeVarUInt32Small7(3); + buffer.writeByte(0); + buffer.writeByte(0); + buffer = trimBuffer(buffer); + ReadContext readContext = fory.getReadContext(); + readContext.prepare(buffer, null, false, buffer.remaining(), false); + try { + assertThrows( + IndexOutOfBoundsException.class, + () -> fory.getSerializer(ArrayList.class).read(readContext)); + } finally { + readContext.reset(); + } + } + + private static Fory newFory(long maxContainerMemoryBytes) { + return builder().withMaxContainerMemoryBytes(maxContainerMemoryBytes).build(); + } + + private static ReadContext prepareContext( + Fory fory, int rootInputBytes, boolean unknownLengthInput) { + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(0); + ReadContext readContext = fory.getReadContext(); + readContext.prepare(buffer, null, false, rootInputBytes, unknownLengthInput); + return readContext; + } + + private static long collectionBytes(int numElements) { + return COLLECTION_OBJECT_BYTES + (long) numElements * REFERENCE_BYTES; + } + + private static long mapBytes(int numElements) { + long entries = numElements; + return MAP_OBJECT_BYTES + + entries * 2 * REFERENCE_BYTES + + entries * (MAP_ENTRY_BYTES + 3L * REFERENCE_BYTES); + } + + private static long objectArrayBytes(int numElements) { + return ARRAY_HEADER_BYTES + (long) numElements * REFERENCE_BYTES; + } + + private static long knownAutoBytes(int inputBytes) { + return inputBytes * KNOWN_ROOT_MULTIPLIER + KNOWN_ROOT_SLACK_BYTES; + } + + private static List emptyLists(int numElements) { + List root = new ArrayList<>(numElements); + for (int i = 0; i < numElements; i++) { + root.add(new ArrayList<>()); + } + return root; + } + + private static List nullLists(int siblings, int childElements) { + List root = new ArrayList<>(siblings); + for (int i = 0; i < siblings; i++) { + List child = new ArrayList<>(childElements); + for (int j = 0; j < childElements; j++) { + child.add(null); + } + root.add(child); + } + return root; + } + + private static List emptyMaps(int numElements) { + List root = new ArrayList<>(numElements); + for (int i = 0; i < numElements; i++) { + root.add(new HashMap<>()); + } + return root; + } + + private static MemoryBuffer objectArraySizeBuffer(int numElements) { + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(8); + buffer.writeVarUInt32Small7(numElements); + return trimBuffer(buffer); + } + + private static MemoryBuffer trimBuffer(MemoryBuffer buffer) { + return MemoryBuffer.fromByteArray(buffer.getBytes(0, buffer.writerIndex())); + } + + private static StreamPayload findStreamPayload() { + Fory writer = newFory(-1); + int numElements = 128; + while (numElements <= 1 << 20) { + List value = emptyMaps(numElements); + byte[] bytes = writer.serialize(value); + long estimatedMemory = collectionBytes(numElements) + (long) numElements * mapBytes(0); + if (estimatedMemory > knownAutoBytes(bytes.length) && estimatedMemory < STREAM_ROOT_BYTES) { + return new StreamPayload(value, bytes); + } + numElements <<= 1; + } + throw new AssertionError("Unable to build compact stream-budget payload"); + } + + private static final class StreamPayload { + final List value; + final byte[] bytes; + + StreamPayload(List value, byte[] bytes) { + this.value = value; + this.bytes = bytes; + } + } +} diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java index 2b0505e077..2ec74b12a0 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java @@ -179,7 +179,7 @@ public void testThrowableReadsMainWireOrderWithCyclicCause() { payload.readerIndex(0); ReadContext readContext = fory.getReadContext(); - readContext.prepare(payload, null, false); + readContext.prepare(payload, null, false, payload.remaining(), false); readContext.preserveRefId(); Serializer serializer = (Serializer) fory.getTypeResolver().getSerializer(RuntimeException.class); @@ -251,7 +251,7 @@ public void testThrowableRejectsMismatchedClassLayerCount() { payload.readerIndex(0); ReadContext readContext = fory.getReadContext(); - readContext.prepare(payload, null, false); + readContext.prepare(payload, null, false, payload.remaining(), false); Serializer serializer = (Serializer) fory.getTypeResolver().getSerializer(CustomException.class); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java index bc945d60ff..4c7d650331 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java @@ -307,7 +307,7 @@ private static Object readPrimitiveListBody(Fory fory, Class listType, int he MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(headerSize); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); return fory.getSerializer(listType).read(readContext); } @@ -315,7 +315,7 @@ private static Object readPrimitiveListRawBody(Fory fory, Class listType) { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); writeNegativeDecodedVarUInt32(buffer); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); return fory.getSerializer(listType).read(readContext); } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/ChildContainerSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/ChildContainerSerializersTest.java index afc536be28..cfc4cfd60f 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/ChildContainerSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/ChildContainerSerializersTest.java @@ -143,7 +143,7 @@ public void testChildCollectionRejectsMismatchedClassLayerCount() { payload.readerIndex(0); ReadContext readContext = fory.getReadContext(); - readContext.prepare(payload, null, false); + readContext.prepare(payload, null, false, payload.remaining(), false); Serializer serializer = (Serializer) fory.getTypeResolver().getSerializer(ChildArrayList.class); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java index ff42616527..c8ccbc8310 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java @@ -1421,7 +1421,7 @@ public void testBitSetRejectsNegativeBinary() { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); writeNegativeDecodedVarUInt32(buffer); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); Assert.expectThrows( DeserializationException.class, () -> fory.getSerializer(BitSet.class).read(readContext)); } diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index 3b50a58337..f95c38d72f 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -19,7 +19,11 @@ import { BinaryReader } from "./reader"; import { BinaryWriter } from "./writer"; -import { MetaString, MetaStringDecoder, MetaStringEncoder } from "./meta/MetaString"; +import { + MetaString, + MetaStringDecoder, + MetaStringEncoder, +} from "./meta/MetaString"; import { InnerFieldInfo, TypeMeta } from "./meta/TypeMeta"; import { Type, TypeInfo } from "./typeInfo"; import { Config, RefFlags, Serializer, TypeId } from "./type"; @@ -48,7 +52,9 @@ type CompatibleReadSerializerCacheEntry = { serializer: Serializer; }; -function remoteListElementType(fieldInfo: InnerFieldInfo): InnerFieldInfo | undefined { +function remoteListElementType( + fieldInfo: InnerFieldInfo, +): InnerFieldInfo | undefined { if (fieldInfo.typeId !== TypeId.LIST) { return undefined; } @@ -525,6 +531,13 @@ export class WriteContext { export class ReadContext { private static readonly MIN_REMOTE_TYPE_META_LIMIT = 8192; + private static readonly KNOWN_ROOT_BUDGET_MULTIPLIER = 8; + private static readonly KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024; + private static readonly COLLECTION_OBJECT_BYTES = 24; + private static readonly MAP_OBJECT_BYTES = 48; + private static readonly ARRAY_HEADER_BYTES = 16; + private static readonly MAP_ENTRY_BYTES = 32; + private static readonly REFERENCE_BYTES = 4; readonly reader: BinaryReader; readonly refReader: RefReader; @@ -535,11 +548,18 @@ export class ReadContext { private typeMetaCache: Map = new Map(); private totalAcceptedSchemaVersions = 0; private cachedTypeMeta: TypeMeta | undefined; - private compatibleReadSerializers = new Map(); + private compatibleReadSerializers = new Map< + number, + CompatibleReadSerializerCacheEntry + >(); private _depth = 0; private _maxDepth: number; - private remoteSchemaVersionsByType: Map | undefined = undefined; + private readonly maxContainerMemoryBytes: number; + private effectiveContainerMemoryBytes = 0; + private remainingContainerMemoryBytes = 0; + private remoteSchemaVersionsByType: Map | undefined + = undefined; constructor( readonly typeResolver: TypeResolverLike, @@ -549,6 +569,7 @@ export class ReadContext { this.refReader = new RefReader(this.reader); this.metaStringReader = new MetaStringReader(); this._maxDepth = config.maxDepth ?? 50; + this.maxContainerMemoryBytes = config.maxContainerMemoryBytes; } reset(bytes: Uint8Array) { @@ -557,6 +578,71 @@ export class ReadContext { this.metaStringReader.reset(); this.typeMeta = []; this._depth = 0; + this.effectiveContainerMemoryBytes = this.maxContainerMemoryBytes > 0 + ? this.maxContainerMemoryBytes + : bytes.byteLength * ReadContext.KNOWN_ROOT_BUDGET_MULTIPLIER + + ReadContext.KNOWN_ROOT_BUDGET_SLACK_BYTES; + this.remainingContainerMemoryBytes = this.effectiveContainerMemoryBytes; + } + + reserveCollectionMemory(numElements: number) { + const bytes + = ReadContext.COLLECTION_OBJECT_BYTES + + numElements * ReadContext.REFERENCE_BYTES; + const remaining = this.remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + this.throwContainerBudgetExceeded(bytes); + } + this.remainingContainerMemoryBytes = remaining; + } + + reserveMapMemory(numElements: number) { + const bytes = ReadContext.MAP_OBJECT_BYTES + + numElements + * ( + ReadContext.REFERENCE_BYTES * 2 + + ReadContext.MAP_ENTRY_BYTES + + ReadContext.REFERENCE_BYTES * 3 + ); + const remaining = this.remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + this.throwContainerBudgetExceeded(bytes); + } + this.remainingContainerMemoryBytes = remaining; + } + + reserveTypedArrayMemory(numElements: number, elementBytes: number) { + const bytes = ReadContext.ARRAY_HEADER_BYTES + numElements * elementBytes; + const remaining = this.remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + this.throwContainerBudgetExceeded(bytes); + } + this.remainingContainerMemoryBytes = remaining; + } + + reserveContainerMemory(bytes: number) { + if (!Number.isSafeInteger(bytes) || bytes < 0) { + this.throwContainerMemoryOverflow(bytes); + } + const remaining = this.remainingContainerMemoryBytes - bytes; + if (remaining < 0) { + this.throwContainerBudgetExceeded(bytes); + } + this.remainingContainerMemoryBytes = remaining; + } + + private throwContainerMemoryOverflow(bytes: number): never { + throw new Error( + `maxContainerMemoryBytes overflow: requested ${bytes} estimated container bytes`, + ); + } + + private throwContainerBudgetExceeded(bytes: number): never { + throw new Error( + `maxContainerMemoryBytes exceeded: requested ${bytes} estimated container bytes, ` + + `${this.remainingContainerMemoryBytes} remaining, effective limit ` + + `${this.effectiveContainerMemoryBytes}`, + ); } isCompatible() { @@ -567,8 +653,8 @@ export class ReadContext { this._depth++; if (this._depth > this._maxDepth) { throw new Error( - `Deserialization depth limit exceeded: ${this._depth} > ${this._maxDepth}. ` + - "The data may be malicious, or increase maxDepth if needed.", + `Deserialization depth limit exceeded: ${this._depth} > ${this._maxDepth}. ` + + "The data may be malicious, or increase maxDepth if needed.", ); } } @@ -628,7 +714,12 @@ export class ReadContext { const idOrLen = this.reader.readVarUInt32(); if (idOrLen & 1) { const typeMeta = this.readTypeMetaRef(idOrLen); - this.checkNamedTypeMeta(typeMeta, expectedTypeId, expectedNamespace, expectedTypeName); + this.checkNamedTypeMeta( + typeMeta, + expectedTypeId, + expectedNamespace, + expectedTypeName, + ); return typeMeta; } const dynamicTypeId = idOrLen >> 1; @@ -661,14 +752,21 @@ export class ReadContext { this.typeResolver.config.maxTypeMetaBytes, ); const typeMetaEnd = this.reader.readGetCursor(); - this.checkNamedTypeMeta(typeMeta, expectedTypeId, expectedNamespace, expectedTypeName); + this.checkNamedTypeMeta( + typeMeta, + expectedTypeId, + expectedNamespace, + expectedTypeName, + ); const localSerializer = this.serializerByTypeMeta(typeMeta); if (localSerializer === undefined) { throw new Error( `can't find serializer for TypeMeta ${typeMeta.getNs()}$${typeMeta.getTypeName()}`, ); } - if (this.matchesExactLocalTypeMeta(typeMeta, typeMetaStart, typeMetaEnd)) { + if ( + this.matchesExactLocalTypeMeta(typeMeta, typeMetaStart, typeMetaEnd) + ) { this.cacheTypeMeta(headerHash, typeMeta, undefined); } else { const typeKey = this.checkRemoteTypeMetaLimit(typeMeta); @@ -678,12 +776,20 @@ export class ReadContext { return typeMeta; } } - this.checkNamedTypeMeta(typeMeta, expectedTypeId, expectedNamespace, expectedTypeName); + this.checkNamedTypeMeta( + typeMeta, + expectedTypeId, + expectedNamespace, + expectedTypeName, + ); this.typeMeta[dynamicTypeId] = typeMeta; return typeMeta; } - readCompatibleStructSerializer(localHash: number, original?: Serializer): Serializer | undefined { + readCompatibleStructSerializer( + localHash: number, + original?: Serializer, + ): Serializer | undefined { const idOrLen = this.reader.readVarUInt32(); let typeMeta: TypeMeta; let remoteHash: number; @@ -709,7 +815,12 @@ export class ReadContext { remoteHash = headerHash; } if (localHash !== remoteHash) { - return this.ensureCompatibleReadSerializer(typeMeta, localHash, remoteHash, original); + return this.ensureCompatibleReadSerializer( + typeMeta, + localHash, + remoteHash, + original, + ); } return undefined; } @@ -730,14 +841,14 @@ export class ReadContext { expectedTypeName: string, ) { if ( - typeMeta.getTypeId() !== expectedTypeId || - typeMeta.getNs() !== expectedNamespace || - typeMeta.getTypeName() !== expectedTypeName + typeMeta.getTypeId() !== expectedTypeId + || typeMeta.getNs() !== expectedNamespace + || typeMeta.getTypeName() !== expectedTypeName ) { throw new Error( - `TypeMeta mismatch: expected ${expectedNamespace}$${expectedTypeName} ` + - `type ${expectedTypeId}, got ${typeMeta.getNs()}$${typeMeta.getTypeName()} ` + - `type ${typeMeta.getTypeId()}`, + `TypeMeta mismatch: expected ${expectedNamespace}$${expectedTypeName} ` + + `type ${expectedTypeId}, got ${typeMeta.getNs()}$${typeMeta.getTypeName()} ` + + `type ${typeMeta.getTypeId()}`, ); } } @@ -778,17 +889,25 @@ export class ReadContext { this.typeResolver.config.maxTypeMetaBytes, ); const typeMetaEnd = this.reader.readGetCursor(); - if (this.matchesExactLocalTypeMeta(typeMeta, typeMetaStart, typeMetaEnd)) { + if ( + this.matchesExactLocalTypeMeta(typeMeta, typeMetaStart, typeMetaEnd) + ) { this.cacheTypeMeta(headerHash, typeMeta, undefined); } else { const localSerializer = original ?? this.serializerByTypeMeta(typeMeta); - if (localSerializer === undefined && !TypeId.structType(typeMeta.getTypeId())) { + if ( + localSerializer === undefined + && !TypeId.structType(typeMeta.getTypeId()) + ) { throw new Error( `can't find serializer for TypeMeta ${typeMeta.getNs()}$${typeMeta.getTypeName()}`, ); } const typeKey = this.checkRemoteTypeMetaLimit(typeMeta); - if (localSerializer !== undefined && TypeId.structType(typeMeta.getTypeId())) { + if ( + localSerializer !== undefined + && TypeId.structType(typeMeta.getTypeId()) + ) { const expectedHash = localHash ?? localSerializer.getHash(); if (expectedHash !== typeMeta.getHash()) { this.ensureCompatibleReadSerializer( @@ -798,8 +917,16 @@ export class ReadContext { localSerializer, ); } - } else if (localHash !== undefined && localHash !== typeMeta.getHash()) { - this.ensureCompatibleReadSerializer(typeMeta, localHash, typeMeta.getHash(), original); + } else if ( + localHash !== undefined + && localHash !== typeMeta.getHash() + ) { + this.ensureCompatibleReadSerializer( + typeMeta, + localHash, + typeMeta.getHash(), + original, + ); } this.cacheTypeMeta(headerHash, typeMeta, typeKey); } @@ -832,30 +959,33 @@ export class ReadContext { : typeMeta.getUserTypeId(); const versionsByType = this.remoteSchemaVersionsByType; const versionsForType = versionsByType?.get(typeKey) ?? 0; - const maxSchemaVersionsPerType = this.typeResolver.config.maxSchemaVersionsPerType; + const maxSchemaVersionsPerType + = this.typeResolver.config.maxSchemaVersionsPerType; if (versionsForType >= maxSchemaVersionsPerType) { throw new Error( - `Remote schema version limit exceeded for type ${String(typeKey)}: ` + - `${versionsForType} >= ${maxSchemaVersionsPerType}. The data may ` + - "be malicious. If the data is not malicious, please increase " + - "maxSchemaVersionsPerType.", + `Remote schema version limit exceeded for type ${String(typeKey)}: ` + + `${versionsForType} >= ${maxSchemaVersionsPerType}. The data may ` + + "be malicious. If the data is not malicious, please increase " + + "maxSchemaVersionsPerType.", ); } - const acceptedTypeCount = - versionsForType === 0 ? (versionsByType?.size ?? 0) + 1 : versionsByType!.size; - const maxAverageSchemaVersionsPerType = - this.typeResolver.config.maxAverageSchemaVersionsPerType; + const acceptedTypeCount + = versionsForType === 0 + ? (versionsByType?.size ?? 0) + 1 + : versionsByType!.size; + const maxAverageSchemaVersionsPerType + = this.typeResolver.config.maxAverageSchemaVersionsPerType; const globalLimit = Math.max( ReadContext.MIN_REMOTE_TYPE_META_LIMIT, acceptedTypeCount * maxAverageSchemaVersionsPerType, ); if (this.totalAcceptedSchemaVersions >= globalLimit) { throw new Error( - `Remote schema version limit exceeded: ${this.totalAcceptedSchemaVersions} ` + - `metadata versions for ${acceptedTypeCount} accepted remote types ` + - `exceeds the average limit ${maxAverageSchemaVersionsPerType}. ` + - "The data may be malicious. If the data is not malicious, please " + - "increase maxAverageSchemaVersionsPerType.", + `Remote schema version limit exceeded: ${this.totalAcceptedSchemaVersions} ` + + `metadata versions for ${acceptedTypeCount} accepted remote types ` + + `exceeds the average limit ${maxAverageSchemaVersionsPerType}. ` + + "The data may be malicious. If the data is not malicious, please " + + "increase maxAverageSchemaVersionsPerType.", ); } return typeKey; @@ -883,15 +1013,24 @@ export class ReadContext { private serializerByTypeMeta(typeMeta: TypeMeta) { const typeId = typeMeta.getTypeId(); if (TypeId.isNamedType(typeId)) { - return this.typeResolver.getSerializerByName(`${typeMeta.getNs()}$${typeMeta.getTypeName()}`); + return this.typeResolver.getSerializerByName( + `${typeMeta.getNs()}$${typeMeta.getTypeName()}`, + ); } if (TypeId.needsUserTypeId(typeId)) { - return this.typeResolver.getSerializerById(typeId, typeMeta.getUserTypeId()); + return this.typeResolver.getSerializerById( + typeId, + typeMeta.getUserTypeId(), + ); } return this.typeResolver.getSerializerById(typeId); } - private matchesExactLocalTypeMeta(remoteTypeMeta: TypeMeta, start: number, end: number): boolean { + private matchesExactLocalTypeMeta( + remoteTypeMeta: TypeMeta, + start: number, + end: number, + ): boolean { const serializer = this.serializerByTypeMeta(remoteTypeMeta); const localBytes = serializer?.getTypeMetaBytes?.(); if (localBytes === undefined) { @@ -939,23 +1078,28 @@ export class ReadContext { if (remote === undefined || local === undefined) { return false; } - if (this.canonicalTypeId(remote.typeId) !== this.canonicalFieldTypeId(local)) { + if ( + this.canonicalTypeId(remote.typeId) !== this.canonicalFieldTypeId(local) + ) { return false; } if ( - (remote.trackingRef === true) !== (local.trackingRef === true) || - (remote.nullable === true) !== (local.nullable === true) + (remote.trackingRef === true) !== (local.trackingRef === true) + || (remote.nullable === true) !== (local.nullable === true) ) { return false; } switch (remote.typeId) { case TypeId.MAP: return ( - this.fieldSchemasEqual(remote.options?.key, local.options?.key) && - this.fieldSchemasEqual(remote.options?.value, local.options?.value) + this.fieldSchemasEqual(remote.options?.key, local.options?.key) + && this.fieldSchemasEqual(remote.options?.value, local.options?.value) ); case TypeId.LIST: - return this.fieldSchemasEqual(remote.options?.inner, local.options?.inner); + return this.fieldSchemasEqual( + remote.options?.inner, + local.options?.inner, + ); case TypeId.SET: return this.fieldSchemasEqual(remote.options?.key, local.options?.key); default: @@ -972,39 +1116,62 @@ export class ReadContext { if (this.fieldSchemasEqual(fieldInfo, fallbackTypeInfo)) { return fallbackTypeInfo.clone(); } - const compatible = this.compatibleFieldTypeInfo(fieldInfo, fallbackTypeInfo); + const compatible = this.compatibleFieldTypeInfo( + fieldInfo, + fallbackTypeInfo, + ); if (compatible) { return compatible; } if ( - isCompatibleScalarType(fieldInfo.typeId) && - isCompatibleScalarType(fallbackTypeInfo.typeId) && - ((fieldInfo.trackingRef === true) !== (fallbackTypeInfo.trackingRef === true) || - ((fieldInfo.trackingRef === true || fallbackTypeInfo.trackingRef === true) && - (fieldInfo.typeId !== fallbackTypeInfo.typeId || - fieldInfo.nullable !== fallbackTypeInfo.nullable))) + isCompatibleScalarType(fieldInfo.typeId) + && isCompatibleScalarType(fallbackTypeInfo.typeId) + && ((fieldInfo.trackingRef === true) + !== (fallbackTypeInfo.trackingRef === true) + || ((fieldInfo.trackingRef === true + || fallbackTypeInfo.trackingRef === true) + && (fieldInfo.typeId !== fallbackTypeInfo.typeId + || fieldInfo.nullable !== fallbackTypeInfo.nullable))) ) { - throw new Error("unsupported compatible scalar tracking-ref schema mismatch"); + throw new Error( + "unsupported compatible scalar tracking-ref schema mismatch", + ); } if ( - isCompatibleScalarPair(fieldInfo.typeId, fallbackTypeInfo.typeId) && - fieldInfo.typeId !== fallbackTypeInfo.typeId && - (fieldInfo.trackingRef === true || fallbackTypeInfo.trackingRef === true) + isCompatibleScalarPair(fieldInfo.typeId, fallbackTypeInfo.typeId) + && fieldInfo.typeId !== fallbackTypeInfo.typeId + && (fieldInfo.trackingRef === true + || fallbackTypeInfo.trackingRef === true) ) { - throw new Error("unsupported compatible scalar tracking-ref schema mismatch"); + throw new Error( + "unsupported compatible scalar tracking-ref schema mismatch", + ); } - if (this.hasUnsupportedListArrayMismatch(fieldInfo, fallbackTypeInfo, topLevel)) { + if ( + this.hasUnsupportedListArrayMismatch( + fieldInfo, + fallbackTypeInfo, + topLevel, + ) + ) { throw new Error("unsupported compatible list/array schema mismatch"); } if ( - fieldInfo.typeId !== TypeId.UNKNOWN && - this.canonicalFieldTypeId(fallbackTypeInfo) !== TypeId.UNKNOWN && - this.canonicalTypeId(fieldInfo.typeId) !== this.canonicalFieldTypeId(fallbackTypeInfo) + fieldInfo.typeId !== TypeId.UNKNOWN + && this.canonicalFieldTypeId(fallbackTypeInfo) !== TypeId.UNKNOWN + && this.canonicalTypeId(fieldInfo.typeId) + !== this.canonicalFieldTypeId(fallbackTypeInfo) ) { throw new Error("unsupported compatible field schema mismatch"); } } - if (this.hasUnsupportedListArrayMismatch(fieldInfo, fallbackTypeInfo, topLevel)) { + if ( + this.hasUnsupportedListArrayMismatch( + fieldInfo, + fallbackTypeInfo, + topLevel, + ) + ) { throw new Error("unsupported compatible list/array schema mismatch"); } if (this.hasNestedSchemaMismatch(fieldInfo, fallbackTypeInfo, topLevel)) { @@ -1013,7 +1180,11 @@ export class ReadContext { switch (fieldInfo.typeId) { case TypeId.MAP: return Type.map( - this.fieldInfoToTypeInfo(fieldInfo.options!.key!, fallbackTypeInfo?.options?.key, false), + this.fieldInfoToTypeInfo( + fieldInfo.options!.key!, + fallbackTypeInfo?.options?.key, + false, + ), this.fieldInfoToTypeInfo( fieldInfo.options!.value!, fallbackTypeInfo?.options?.value, @@ -1030,7 +1201,11 @@ export class ReadContext { ); case TypeId.SET: return Type.set( - this.fieldInfoToTypeInfo(fieldInfo.options!.key!, fallbackTypeInfo?.options?.key, false), + this.fieldInfoToTypeInfo( + fieldInfo.options!.key!, + fallbackTypeInfo?.options?.key, + false, + ), ); default: { // Remote TypeMeta only carries the nested user-defined type kind, not the @@ -1075,37 +1250,53 @@ export class ReadContext { return false; } if ( - this.schemaMatchTypeId(remote.typeId) !== - this.schemaMatchTypeId(this.typeResolver.computeTypeId(local)) + this.schemaMatchTypeId(remote.typeId) + !== this.schemaMatchTypeId(this.typeResolver.computeTypeId(local)) ) { return true; } const remoteTracksRef = remote.trackingRef === true; const localTracksRef = local.trackingRef === true; if ( - remoteTracksRef !== localTracksRef || - ((remoteTracksRef || localTracksRef) && - (remote.nullable === true) !== (local.nullable === true)) + remoteTracksRef !== localTracksRef + || ((remoteTracksRef || localTracksRef) + && (remote.nullable === true) !== (local.nullable === true)) ) { return true; } switch (remote.typeId) { case TypeId.MAP: return ( - local.options?.key === undefined || - local.options?.value === undefined || - this.hasNestedSchemaMismatch(remote.options!.key!, local.options.key, false) || - this.hasNestedSchemaMismatch(remote.options!.value!, local.options.value, false) + local.options?.key === undefined + || local.options?.value === undefined + || this.hasNestedSchemaMismatch( + remote.options!.key!, + local.options.key, + false, + ) + || this.hasNestedSchemaMismatch( + remote.options!.value!, + local.options.value, + false, + ) ); case TypeId.LIST: return ( - local.options?.inner === undefined || - this.hasNestedSchemaMismatch(remote.options!.inner!, local.options.inner, false) + local.options?.inner === undefined + || this.hasNestedSchemaMismatch( + remote.options!.inner!, + local.options.inner, + false, + ) ); case TypeId.SET: return ( - local.options?.key === undefined || - this.hasNestedSchemaMismatch(remote.options!.key!, local.options.key, false) + local.options?.key === undefined + || this.hasNestedSchemaMismatch( + remote.options!.key!, + local.options.key, + false, + ) ); default: return false; @@ -1116,22 +1307,25 @@ export class ReadContext { return this.canonicalTypeId(typeId); } - private compatibleFieldTypeInfo(remote: InnerFieldInfo, local: TypeInfo): TypeInfo | undefined { + private compatibleFieldTypeInfo( + remote: InnerFieldInfo, + local: TypeInfo, + ): TypeInfo | undefined { if (this.isByteSequenceRootPair(remote, local)) { if ( - (remote.nullable === true) !== (local.nullable === true) || - (remote.trackingRef === true) !== (local.trackingRef === true) + (remote.nullable === true) !== (local.nullable === true) + || (remote.trackingRef === true) !== (local.trackingRef === true) ) { return undefined; } return local.clone(); } if ( - this.isListArrayRootPair(remote, local) && - (remote.nullable === true || - local.nullable === true || - remote.trackingRef === true || - local.trackingRef === true) + this.isListArrayRootPair(remote, local) + && (remote.nullable === true + || local.nullable === true + || remote.trackingRef === true + || local.trackingRef === true) ) { return undefined; } @@ -1151,20 +1345,22 @@ export class ReadContext { } const remoteArrayElement = denseArrayElementTypeId(remote.typeId); if ( - remoteArrayElement !== undefined && - local.typeId === TypeId.LIST && - local.options?.inner && - compatibleArrayElementTypeId(local.options.inner.typeId) === remoteArrayElement + remoteArrayElement !== undefined + && local.typeId === TypeId.LIST + && local.options?.inner + && compatibleArrayElementTypeId(local.options.inner.typeId) + === remoteArrayElement ) { return compatibleArrayToListTypeInfo(remoteArrayElement); } if ( - remote.trackingRef !== true && - local.trackingRef !== true && - !( - remote.typeId === local.typeId && (remote.nullable === true) === (local.nullable === true) - ) && - isCompatibleScalarPair(remote.typeId, local.typeId) + remote.trackingRef !== true + && local.trackingRef !== true + && !( + remote.typeId === local.typeId + && (remote.nullable === true) === (local.nullable === true) + ) + && isCompatibleScalarPair(remote.typeId, local.typeId) ) { return markCompatibleScalarRead(local.clone(), { remoteTypeId: remote.typeId, @@ -1192,8 +1388,16 @@ export class ReadContext { switch (remote.typeId) { case TypeId.MAP: return ( - this.hasUnsupportedListArrayMismatch(remote.options!.key!, local.options?.key, false) || - this.hasUnsupportedListArrayMismatch(remote.options!.value!, local.options?.value, false) + this.hasUnsupportedListArrayMismatch( + remote.options!.key!, + local.options?.key, + false, + ) + || this.hasUnsupportedListArrayMismatch( + remote.options!.value!, + local.options?.value, + false, + ) ); case TypeId.LIST: return this.hasUnsupportedListArrayMismatch( @@ -1212,17 +1416,26 @@ export class ReadContext { } } - private isListArrayRootPair(remote: InnerFieldInfo, local: TypeInfo): boolean { + private isListArrayRootPair( + remote: InnerFieldInfo, + local: TypeInfo, + ): boolean { return ( - (remote.typeId === TypeId.LIST && denseArrayElementTypeId(local.typeId) !== undefined) || - (denseArrayElementTypeId(remote.typeId) !== undefined && local.typeId === TypeId.LIST) + (remote.typeId === TypeId.LIST + && denseArrayElementTypeId(local.typeId) !== undefined) + || (denseArrayElementTypeId(remote.typeId) !== undefined + && local.typeId === TypeId.LIST) ); } - private isByteSequenceRootPair(remote: InnerFieldInfo, local: TypeInfo): boolean { + private isByteSequenceRootPair( + remote: InnerFieldInfo, + local: TypeInfo, + ): boolean { return ( - (remote.typeId === TypeId.BINARY && local.typeId === TypeId.UINT8_ARRAY) || - (remote.typeId === TypeId.UINT8_ARRAY && local.typeId === TypeId.BINARY) + (remote.typeId === TypeId.BINARY + && local.typeId === TypeId.UINT8_ARRAY) + || (remote.typeId === TypeId.UINT8_ARRAY && local.typeId === TypeId.BINARY) ); } @@ -1236,7 +1449,10 @@ export class ReadContext { const named = `${typeMeta.getNs()}$${typeMeta.getTypeName()}`; original = this.typeResolver.getSerializerByName(named); } else { - original = this.typeResolver.getSerializerById(typeId, typeMeta.getUserTypeId()); + original = this.typeResolver.getSerializerById( + typeId, + typeMeta.getUserTypeId(), + ); } } let typeInfo: TypeInfo; @@ -1251,18 +1467,25 @@ export class ReadContext { }); } const localProps = original?.getTypeInfo().options?.props; - const fieldEntries = typeMeta.remapFieldNames(localProps).map((fieldInfo) => { - const localFieldTypeInfo = localProps?.[fieldInfo.getFieldName()]; - let fieldTypeInfo = this.fieldInfoToTypeInfo(fieldInfo, localFieldTypeInfo) - .setNullable(fieldInfo.nullable) - .setTrackingRef(fieldInfo.trackingRef) - .setId(fieldInfo.fieldId); - if (localFieldTypeInfo === undefined) { - fieldTypeInfo = markCompatibleSkipRead(fieldTypeInfo); - } - return { key: fieldInfo.getFieldName(), typeInfo: fieldTypeInfo }; - }); - const props = Object.fromEntries(fieldEntries.map(({ key, typeInfo }) => [key, typeInfo])); + const fieldEntries = typeMeta + .remapFieldNames(localProps) + .map((fieldInfo) => { + const localFieldTypeInfo = localProps?.[fieldInfo.getFieldName()]; + let fieldTypeInfo = this.fieldInfoToTypeInfo( + fieldInfo, + localFieldTypeInfo, + ) + .setNullable(fieldInfo.nullable) + .setTrackingRef(fieldInfo.trackingRef) + .setId(fieldInfo.fieldId); + if (localFieldTypeInfo === undefined) { + fieldTypeInfo = markCompatibleSkipRead(fieldTypeInfo); + } + return { key: fieldInfo.getFieldName(), typeInfo: fieldTypeInfo }; + }); + const props = Object.fromEntries( + fieldEntries.map(({ key, typeInfo }) => [key, typeInfo]), + ); typeInfo.options = { ...typeInfo.options, preserveFieldOrder: true, diff --git a/javascript/packages/core/lib/fory.ts b/javascript/packages/core/lib/fory.ts index f50d2fcebf..979216a5b9 100644 --- a/javascript/packages/core/lib/fory.ts +++ b/javascript/packages/core/lib/fory.ts @@ -38,6 +38,7 @@ const DEFAULT_MAX_TYPE_FIELDS = 512 as const; const DEFAULT_MAX_TYPE_META_BYTES = 4096 as const; const DEFAULT_MAX_SCHEMA_VERSIONS_PER_TYPE = 10 as const; const DEFAULT_MAX_AVERAGE_SCHEMA_VERSIONS_PER_TYPE = 3 as const; +const DEFAULT_MAX_CONTAINER_MEMORY_BYTES = -1 as const; export default class Fory { readonly typeResolver: TypeResolver; readonly anySerializer: Serializer; @@ -88,10 +89,21 @@ export default class Fory { `maxAverageSchemaVersionsPerType must be a positive integer but got ${maxAverageSchemaVersionsPerType}`, ); } + const maxContainerMemoryBytes + = config?.maxContainerMemoryBytes ?? DEFAULT_MAX_CONTAINER_MEMORY_BYTES; + if ( + !Number.isSafeInteger(maxContainerMemoryBytes) + || (maxContainerMemoryBytes !== -1 && maxContainerMemoryBytes <= 0) + ) { + throw new Error( + `maxContainerMemoryBytes must be -1 or a positive safe integer but got ${maxContainerMemoryBytes}`, + ); + } return { ref: Boolean(config?.ref), useSliceString: Boolean(config?.useSliceString), maxDepth: config?.maxDepth, + maxContainerMemoryBytes, maxTypeFields, maxTypeMetaBytes, maxSchemaVersionsPerType, diff --git a/javascript/packages/core/lib/gen/collection.ts b/javascript/packages/core/lib/gen/collection.ts index 03139bebb1..c551dffef4 100644 --- a/javascript/packages/core/lib/gen/collection.ts +++ b/javascript/packages/core/lib/gen/collection.ts @@ -97,6 +97,30 @@ function compatibleArrayCollectionExpr(elementTypeId: number, len: string): stri } } +function compatibleArrayElementBytes(elementTypeId: number): number { + switch (elementTypeId) { + case TypeId.BOOL: + case TypeId.INT8: + case TypeId.UINT8: + return 1; + case TypeId.INT16: + case TypeId.UINT16: + case TypeId.FLOAT16: + case TypeId.BFLOAT16: + return 2; + case TypeId.INT32: + case TypeId.UINT32: + case TypeId.FLOAT32: + return 4; + case TypeId.INT64: + case TypeId.UINT64: + case TypeId.FLOAT64: + return 8; + default: + return 4; + } +} + function compatibleArrayPutAccessor( elementTypeId: number, result: string, @@ -234,10 +258,12 @@ class CollectionAnySerializer { ): any { void fromRef; const len = this.readContext.reader.readVarUint32Small7(); + this.readContext.reserveCollectionMemory(len); if (len === 0) { return createCollection(len); } const flags = this.readContext.reader.readUint8(); + this.readContext.reader.checkReadableBytes(len); const result = createCollection(len); // IMPORTANT: collection readers must obey the ref/null bits written on the // wire, not local TypeScript metadata that may imply a different ref @@ -418,6 +444,9 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera const newCollection = compatibleListToArray ? compatibleArrayCollectionExpr(compatibleReadAction!.elementTypeId, len) : this.newCollection(len); + const reserveMemory = compatibleListToArray + ? `${readContextName}.reserveTypedArrayMemory(${len}, ${compatibleArrayElementBytes(compatibleReadAction!.elementTypeId)});` + : `${readContextName}.reserveCollectionMemory(${len});`; const putAccessor = (item: string, index: string) => compatibleListToArray ? compatibleArrayPutAccessor(compatibleReadAction!.elementTypeId, result, item, index) @@ -449,6 +478,7 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera : `${elemSerializer} = ${anyHelper}.detectSerializer(${readContextName});`; return ` const ${len} = ${this.builder.reader.readVarUint32Small7()}; + ${reserveMemory} let ${flags} = 0; if (${len} > 0) { ${flags} = ${this.builder.reader.readUint8()}; diff --git a/javascript/packages/core/lib/gen/map.ts b/javascript/packages/core/lib/gen/map.ts index db0e147a4d..ebb5f3b588 100644 --- a/javascript/packages/core/lib/gen/map.ts +++ b/javascript/packages/core/lib/gen/map.ts @@ -272,6 +272,7 @@ class MapAnySerializer { read(fromRef: boolean): any { let count = this.readContext.reader.readVarUint32Small7(); + this.readContext.reserveMapMemory(count); const result = new Map(); if (fromRef) { this.readContext.reference(result); @@ -491,6 +492,7 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { return ` let ${count} = ${this.builder.reader.readVarUint32Small7()}; + ${readContextName}.reserveMapMemory(${count}); const ${result} = new Map(); if (${refState}) { ${this.builder.referenceResolver.reference(result)} diff --git a/javascript/packages/core/lib/type.ts b/javascript/packages/core/lib/type.ts index ddbef54ec9..6acf24bc96 100644 --- a/javascript/packages/core/lib/type.ts +++ b/javascript/packages/core/lib/type.ts @@ -291,6 +291,7 @@ export interface Config { ref: boolean; useSliceString: boolean; maxDepth?: number; + maxContainerMemoryBytes: number; maxTypeFields: number; maxTypeMetaBytes: number; maxSchemaVersionsPerType: number; diff --git a/javascript/test/containerMemoryBudget.test.ts b/javascript/test/containerMemoryBudget.test.ts new file mode 100644 index 0000000000..77907ea3e3 --- /dev/null +++ b/javascript/test/containerMemoryBudget.test.ts @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import Fory, { Type } from '../packages/core/index'; +import { describe, expect, test } from '@jest/globals'; + +const KNOWN_SLACK_BYTES = 64 * 1024; + +function serializeAny(value: unknown) { + return new Fory({ compatible: false, ref: true }).serialize(value); +} + +function deserializeAny(bytes: Uint8Array, maxContainerMemoryBytes: number) { + return new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes, + }).deserialize(bytes); +} + +describe('container memory budget', () => { + test('uses known length auto budget', () => { + const inputBytes = 17; + const fory = new Fory({ compatible: false }); + const budget = inputBytes * 8 + KNOWN_SLACK_BYTES; + + fory.readContext.reset(new Uint8Array(inputBytes)); + expect(() => fory.readContext.reserveContainerMemory(budget)).not.toThrow(); + expect(() => fory.readContext.reserveContainerMemory(1)).toThrow( + /maxContainerMemoryBytes/, + ); + }); + + test('validates explicit config', () => { + expect(() => new Fory({ maxContainerMemoryBytes: 0 })).toThrow( + /maxContainerMemoryBytes/, + ); + expect(() => new Fory({ maxContainerMemoryBytes: -2 })).toThrow( + /maxContainerMemoryBytes/, + ); + + const fory = new Fory({ maxContainerMemoryBytes: 24 }); + fory.readContext.reset(new Uint8Array(1)); + expect(() => fory.readContext.reserveCollectionMemory(0)).not.toThrow(); + expect(() => fory.readContext.reserveContainerMemory(1)).toThrow( + /maxContainerMemoryBytes/, + ); + }); + + test('charges nested empty containers', () => { + const typeInfo = Type.struct('budget.nested.empty', { + values: Type.list(Type.list(Type.int32({ encoding: 'fixed' }))).setId(1), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ values: [[]] }); + const passingReader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 52, + }).register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 51, + }).register(typeInfo); + + expect(() => failingReader.deserialize(bytes)).toThrow( + /maxContainerMemoryBytes/, + ); + expect(passingReader.deserialize(bytes)).toEqual({ values: [[]] }); + }); + + test('charges sibling containers cumulatively', () => { + const typeInfo = Type.struct('budget.sibling.empty', { + values: Type.list(Type.list(Type.int32({ encoding: 'fixed' }))).setId(1), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + values: [[], [], []], + }); + const passingReader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 108, + }).register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 107, + }).register(typeInfo); + + expect(() => failingReader.deserialize(bytes)).toThrow( + /maxContainerMemoryBytes/, + ); + expect(passingReader.deserialize(bytes)).toEqual({ + values: [[], [], []], + }); + }); + + test('charges map entries', () => { + const bytes = serializeAny(new Map([[1, 2]])); + + expect(() => deserializeAny(bytes, 99)).toThrow(/maxContainerMemoryBytes/); + expect(deserializeAny(bytes, 100)).toEqual(new Map([[1, 2]])); + }); + + test('charges generated containers', () => { + const typeInfo = Type.struct('budget.generated', { + list: Type.list(Type.int32({ encoding: 'fixed' })).setId(1), + set: Type.set(Type.string()).setId(2), + map: Type.map(Type.string(), Type.int32({ encoding: 'fixed' })).setId(3), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + list: [1], + set: new Set(['a']), + map: new Map([['k', 1]]), + }); + const passingReader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 156, + }).register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 155, + }).register(typeInfo); + + expect(() => failingReader.deserialize(bytes)).toThrow( + /maxContainerMemoryBytes/, + ); + expect(passingReader.deserialize(bytes)).toEqual({ + list: [1], + set: new Set(['a']), + map: new Map([['k', 1]]), + }); + }); + + test('charges compatible typed arrays', () => { + const writerType = Type.struct(9010, { + values: Type.list(Type.int32({ encoding: 'fixed' })).setId(1), + }); + const readerType = Type.struct(9010, { + values: Type.int32Array().setId(1), + }); + const writer = new Fory({ compatible: true }); + const bytes = writer.register(writerType).serialize({ values: [1, 2, 3] }); + const passingReader = new Fory({ + compatible: true, + maxContainerMemoryBytes: 28, + }).register(readerType); + const failingReader = new Fory({ + compatible: true, + maxContainerMemoryBytes: 27, + }).register(readerType); + + expect(() => failingReader.deserialize(bytes)).toThrow( + /maxContainerMemoryBytes/, + ); + expect(Array.from(passingReader.deserialize(bytes).values)).toEqual([ + 1, + 2, + 3, + ]); + }); + + test('skips scalar dense owners', () => { + const typeInfo = Type.struct('budget.skipped', { + text: Type.string().setId(1), + binary: Type.binary().setId(2), + values: Type.int32Array().setId(3), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + text: 'hello', + binary: new Uint8Array([1, 2, 3]), + values: new Int32Array([1, 2, 3]), + }); + const reader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 1, + }).register(typeInfo); + + expect(reader.deserialize(bytes)).toEqual({ + text: 'hello', + binary: new Uint8Array([1, 2, 3]), + values: new Int32Array([1, 2, 3]), + }); + }); + + test('keeps byte checks', () => { + const typeInfo = Type.struct('budget.bytecheck', { + values: Type.int32Array().setId(1), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + values: new Int32Array([1, 2, 3]), + }); + const reader = new Fory({ + compatible: false, + ref: true, + maxContainerMemoryBytes: 1024 * 1024, + }).register(typeInfo); + + expect(() => reader.deserialize(bytes.slice(0, bytes.length - 1))).toThrow(); + }); +}); diff --git a/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt b/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt index e3b36d4785..be2980313f 100644 --- a/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt +++ b/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt @@ -21,7 +21,6 @@ package org.apache.fory.serializer.kotlin import org.apache.fory.context.ReadContext import org.apache.fory.context.WriteContext -import org.apache.fory.exception.DeserializationException import org.apache.fory.resolver.TypeResolver import org.apache.fory.serializer.collection.CollectionLikeSerializer @@ -57,15 +56,8 @@ public class KotlinArrayDequeSerializer( } override fun newCollection(readContext: ReadContext): Collection { - val buffer = readContext.buffer - val numElements = buffer.readVarUInt32Small7() - if (numElements < 0) { - throw DeserializationException("Collection size must be non-negative: $numElements") - } + val numElements = readCollectionSize(readContext) setNumElements(numElements) - if (numElements != 0) { - buffer.checkReadableBytes(numElements) - } return ArrayDequeBuilder(ArrayDeque(numElements)) } } diff --git a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt index 5684821ebe..1d17f39b91 100644 --- a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt +++ b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt @@ -20,8 +20,10 @@ package org.apache.fory.serializer.kotlin import org.apache.fory.Fory +import org.apache.fory.exception.InsecureException import org.apache.fory.kotlin.ForyKotlin import org.testng.Assert.assertEquals +import org.testng.Assert.fail import org.testng.annotations.Test class CollectionSerializerTest { @@ -33,6 +35,22 @@ class CollectionSerializerTest { assertEquals(arrayDeque, fory.deserialize(fory.serialize(arrayDeque))) } + @Test + fun testArrayDequeContainerMemoryBudget() { + val writer: Fory = ForyKotlin.builder().withXlang(false).requireClassRegistration(true).build() + val reader: Fory = + ForyKotlin.builder() + .withXlang(false) + .requireClassRegistration(true) + .withMaxContainerMemoryBytes(23) + .build() + + try { + reader.deserialize(writer.serialize(ArrayDeque())) + fail("Expected container memory budget failure") + } catch (ignored: InsecureException) {} + } + @Test fun testSerializeArrayList() { val fory: Fory = ForyKotlin.builder().withXlang(false).requireClassRegistration(true).build() diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index e4819ba424..0d6f4d0591 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -124,6 +124,7 @@ class Fory: "strict", "buffer", "max_depth", + "max_container_memory_bytes", "field_nullable", "policy", ) @@ -139,6 +140,7 @@ def __init__( max_type_meta_bytes: int = 4096, max_schema_versions_per_type: int = 10, max_average_schema_versions_per_type: int = 3, + max_container_memory_bytes: int = -1, policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, @@ -183,6 +185,9 @@ def __init__( max_average_schema_versions_per_type: Average remote metadata versions allowed across accepted remote types. + max_container_memory_bytes: Maximum estimated container-owned memory per root + deserialization. `-1` means auto; positive values are explicit byte limits. + policy: Custom deserialization policy for security checks. When provided, it controls which types can be deserialized, overriding the default policy. **Strongly recommended** when strict=False to maintain security controls. @@ -213,6 +218,13 @@ def __init__( raise ValueError("max_schema_versions_per_type must be a positive integer") if not isinstance(max_average_schema_versions_per_type, int) or max_average_schema_versions_per_type <= 0: raise ValueError("max_average_schema_versions_per_type must be a positive integer") + if ( + not isinstance(max_container_memory_bytes, int) + or (max_container_memory_bytes != -1 and max_container_memory_bytes <= 0) + or max_container_memory_bytes > (1 << 63) - 1 + ): + raise ValueError("max_container_memory_bytes must be -1 or a positive 63-bit integer") + self.max_container_memory_bytes = max_container_memory_bytes self.config = Config( xlang=xlang, track_ref=ref, @@ -225,6 +237,7 @@ def __init__( max_type_meta_bytes=max_type_meta_bytes, max_schema_versions_per_type=max_schema_versions_per_type, max_average_schema_versions_per_type=max_average_schema_versions_per_type, + max_container_memory_bytes=max_container_memory_bytes, field_nullable=field_nullable, policy=self.policy, meta_compressor=meta_compressor, @@ -559,6 +572,7 @@ def _deserialize( buffers=buffers, unsupported_objects=unsupported_objects, peer_out_of_band_enabled=peer_out_of_band_enabled, + root_input_bytes=buffer.size() - reader_index, ) return read_context.read_ref() diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi index 0183b26231..6dd5c5c4dc 100644 --- a/python/pyfory/collection.pxi +++ b/python/pyfory/collection.pxi @@ -466,10 +466,24 @@ cdef class ListSerializer(CollectionSerializer): cdef int8_t head_flag cdef int32_t ref_id cdef int64_t i + cdef int64_t container_bytes if len_ == 0: + container_bytes = read_context.remaining_container_memory_bytes - _COLLECTION_OBJECT_BYTES + if container_bytes < 0: + read_context.reserve_container_memory_fast(_COLLECTION_OBJECT_BYTES) + else: + read_context.remaining_container_memory_bytes = container_bytes list_ = PyList_New(0) return list_ + if len_ < 0: + read_context.reserve_collection_memory_c(len_) + else: + container_bytes = _COLLECTION_OBJECT_BYTES + len_ * _REFERENCE_BYTES + if container_bytes > read_context.remaining_container_memory_bytes: + read_context.reserve_container_memory_fast(container_bytes) + else: + read_context.remaining_container_memory_bytes -= container_bytes read_context.check_readable_bytes(len_) collect_flag = buffer.read_int8() @@ -583,10 +597,24 @@ cdef class TupleSerializer(CollectionSerializer): cdef bint has_null cdef int8_t head_flag cdef int64_t i + cdef int64_t container_bytes if len_ == 0: + container_bytes = read_context.remaining_container_memory_bytes - _COLLECTION_OBJECT_BYTES + if container_bytes < 0: + read_context.reserve_container_memory_fast(_COLLECTION_OBJECT_BYTES) + else: + read_context.remaining_container_memory_bytes = container_bytes tuple_ = PyTuple_New(0) return tuple_ + if len_ < 0: + read_context.reserve_collection_memory_c(len_) + else: + container_bytes = _COLLECTION_OBJECT_BYTES + len_ * _REFERENCE_BYTES + if container_bytes > read_context.remaining_container_memory_bytes: + read_context.reserve_container_memory_fast(container_bytes) + else: + read_context.remaining_container_memory_bytes -= container_bytes read_context.check_readable_bytes(len_) collect_flag = buffer.read_int8() @@ -684,7 +712,7 @@ cdef class StringArraySerializer(ListSerializer): @cython.final cdef class SetSerializer(CollectionSerializer): cpdef read(self, ReadContext read_context): - cdef set instance = set() + cdef set instance cdef int32_t len_ cdef int8_t collect_flag cdef TypeInfo typeinfo @@ -701,11 +729,29 @@ cdef class SetSerializer(CollectionSerializer): cdef int8_t head_flag cdef int32_t ref_id cdef int64_t i + cdef int64_t container_bytes - read_context.reference(instance) len_ = buffer.read_var_uint32() if len_ == 0: + container_bytes = read_context.remaining_container_memory_bytes - _COLLECTION_OBJECT_BYTES + if container_bytes < 0: + read_context.reserve_container_memory_fast(_COLLECTION_OBJECT_BYTES) + else: + read_context.remaining_container_memory_bytes = container_bytes + instance = set() + read_context.reference(instance) return instance + if len_ < 0: + read_context.reserve_collection_memory_c(len_) + else: + container_bytes = _COLLECTION_OBJECT_BYTES + len_ * _REFERENCE_BYTES + if container_bytes > read_context.remaining_container_memory_bytes: + read_context.reserve_container_memory_fast(container_bytes) + else: + read_context.remaining_container_memory_bytes -= container_bytes + read_context.check_readable_bytes(len_) + instance = set() + read_context.reference(instance) collect_flag = buffer.read_int8() if (collect_flag & COLL_IS_SAME_TYPE) != 0: @@ -1048,9 +1094,23 @@ cdef class MapSerializer(Serializer): cdef int32_t ref_id cdef dict map_ cdef int8_t chunk_header = 0 + cdef int64_t container_bytes if size == 0: + container_bytes = read_context.remaining_container_memory_bytes - _MAP_OBJECT_BYTES + if container_bytes < 0: + read_context.reserve_container_memory_fast(_MAP_OBJECT_BYTES) + else: + read_context.remaining_container_memory_bytes = container_bytes + map_ = {} + elif size < 0: + read_context.reserve_map_memory_c(size) map_ = {} else: + container_bytes = _MAP_OBJECT_BYTES + size * (_MAP_ENTRY_BYTES + 5 * _REFERENCE_BYTES) + if container_bytes > read_context.remaining_container_memory_bytes: + read_context.reserve_container_memory_fast(container_bytes) + else: + read_context.remaining_container_memory_bytes -= container_bytes read_context.check_readable_bytes(size) chunk_header = read_context.read_uint8() map_ = _PyDict_NewPresized(size) diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py index d78673a6dc..c2e2e2a058 100644 --- a/python/pyfory/collection.py +++ b/python/pyfory/collection.py @@ -176,6 +176,9 @@ def _write_different_types(self, write_context, value, collect_flag=0): def read(self, read_context): length = read_context.read_var_uint32() + read_context.reserve_collection_memory(length) + if length != 0: + read_context.check_readable_bytes(length) collection_ = self.new_instance(read_context, self.type_) if length == 0: return collection_ @@ -455,6 +458,9 @@ def write(self, write_context, obj): def read(self, read_context): size = read_context.read_var_uint32() + read_context.reserve_map_memory(size) + if size != 0: + read_context.check_readable_bytes(size) map_ = {} ref_reader = read_context.ref_reader read_context.reference(map_) diff --git a/python/pyfory/context.pxi b/python/pyfory/context.pxi index 702f09769c..a27084d466 100644 --- a/python/pyfory/context.pxi +++ b/python/pyfory/context.pxi @@ -30,6 +30,14 @@ STRING_TYPE_ID = TypeId.STRING SMALL_STRING_THRESHOLD = 16 cdef int32_t MAX_CACHED_META_STRINGS = 8192 cdef int32_t MAX_CACHED_META_STRING_LENGTH = 2048 +cdef int64_t _KNOWN_ROOT_BUDGET_MULTIPLIER = 8 +cdef int64_t _KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 +cdef int64_t _STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 +cdef int64_t _COLLECTION_OBJECT_BYTES = 56 +cdef int64_t _MAP_OBJECT_BYTES = 64 +cdef int64_t _MAP_ENTRY_BYTES = 32 +cdef int64_t _REFERENCE_BYTES = sizeof(PyObject*) +cdef int64_t _MAX_CONTAINER_MEMORY_BYTES = 9223372036854775807 cdef inline uint64_t _mix64(uint64_t x): @@ -746,6 +754,9 @@ cdef class ReadContext: cdef readonly bint field_nullable cdef readonly object policy cdef readonly int32_t max_depth + cdef public int64_t max_container_memory_bytes + cdef public int64_t container_memory_limit_bytes + cdef public int64_t remaining_container_memory_bytes cdef readonly RefReader ref_reader cdef readonly MetaStringReader meta_string_reader cdef readonly MetaShareReadContext meta_share_context @@ -766,6 +777,9 @@ cdef class ReadContext: self.field_nullable = config.field_nullable self.policy = config.policy self.max_depth = config.max_depth + self.max_container_memory_bytes = config.max_container_memory_bytes + self.container_memory_limit_bytes = 0 + self.remaining_container_memory_bytes = 0 self.ref_reader = RefReader(self.track_ref) self.meta_string_reader = MetaStringReader(self.type_resolver.shared_registry) self.meta_share_context = MetaShareReadContext() if config.scoped_meta_share_enabled else None @@ -783,12 +797,26 @@ cdef class ReadContext: buffers=None, unsupported_objects=None, bint peer_out_of_band_enabled=False, + int64_t root_input_bytes=-1, ): + cdef int64_t limit + if self.max_container_memory_bytes > 0: + limit = self.max_container_memory_bytes + elif buffer.has_input_stream(): + limit = _STREAM_ROOT_BUDGET_BYTES + else: + if root_input_bytes < 0: + root_input_bytes = buffer.size() - buffer.get_reader_index() + if root_input_bytes > (_MAX_CONTAINER_MEMORY_BYTES - _KNOWN_ROOT_BUDGET_SLACK_BYTES) // _KNOWN_ROOT_BUDGET_MULTIPLIER: + raise ValueError("max_container_memory_bytes auto budget overflow") + limit = root_input_bytes * _KNOWN_ROOT_BUDGET_MULTIPLIER + _KNOWN_ROOT_BUDGET_SLACK_BYTES self.buffer = buffer self.c_buffer = buffer.c_buffer self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled + self.container_memory_limit_bytes = limit + self.remaining_container_memory_bytes = limit self.depth = 0 cpdef inline reset(self): @@ -803,8 +831,61 @@ cdef class ReadContext: self.buffers = None self.unsupported_objects = None self.peer_out_of_band_enabled = False + self.container_memory_limit_bytes = 0 + self.remaining_container_memory_bytes = 0 self.depth = 0 + cdef inline void reserve_container_memory_c(self, int64_t num_bytes): + cdef int64_t used + if num_bytes < 0: + raise ValueError("Estimated container memory is negative") + if num_bytes > _MAX_CONTAINER_MEMORY_BYTES: + raise ValueError("Estimated container memory overflow") + if num_bytes > self.remaining_container_memory_bytes: + used = self.container_memory_limit_bytes - self.remaining_container_memory_bytes + raise ValueError( + f"Estimated container memory budget exceeded: requested {num_bytes} bytes, " + f"used {used} bytes, limit {self.container_memory_limit_bytes} bytes. " + "Increase Fory(..., max_container_memory_bytes=...) for trusted larger payloads." + ) + self.remaining_container_memory_bytes -= num_bytes + + cdef inline void reserve_container_memory_fast(self, int64_t num_bytes): + cdef int64_t used + if num_bytes > self.remaining_container_memory_bytes: + used = self.container_memory_limit_bytes - self.remaining_container_memory_bytes + raise ValueError( + f"Estimated container memory budget exceeded: requested {num_bytes} bytes, " + f"used {used} bytes, limit {self.container_memory_limit_bytes} bytes. " + "Increase Fory(..., max_container_memory_bytes=...) for trusted larger payloads." + ) + self.remaining_container_memory_bytes -= num_bytes + + cpdef inline reserve_container_memory(self, int64_t num_bytes): + self.reserve_container_memory_c(num_bytes) + + cdef inline void reserve_collection_memory_c(self, int64_t num_elements): + if num_elements < 0: + raise ValueError("Container element count is negative") + if num_elements > (_MAX_CONTAINER_MEMORY_BYTES - _COLLECTION_OBJECT_BYTES) // _REFERENCE_BYTES: + raise ValueError("Estimated container memory overflow") + self.reserve_container_memory_c(_COLLECTION_OBJECT_BYTES + num_elements * _REFERENCE_BYTES) + + cpdef inline reserve_collection_memory(self, int64_t num_elements): + self.reserve_collection_memory_c(num_elements) + + cdef inline void reserve_map_memory_c(self, int64_t num_elements): + cdef int64_t bytes_per_entry + if num_elements < 0: + raise ValueError("Map entry count is negative") + bytes_per_entry = _MAP_ENTRY_BYTES + 5 * _REFERENCE_BYTES + if num_elements > (_MAX_CONTAINER_MEMORY_BYTES - _MAP_OBJECT_BYTES) // bytes_per_entry: + raise ValueError("Estimated container memory overflow") + self.reserve_container_memory_c(_MAP_OBJECT_BYTES + num_elements * bytes_per_entry) + + cpdef inline reserve_map_memory(self, int64_t num_elements): + self.reserve_map_memory_c(num_elements) + cpdef inline add_context_object(self, key, obj): self.context_objects[id(key)] = obj diff --git a/python/pyfory/context.py b/python/pyfory/context.py index 3abfb46e3d..a923731c4b 100644 --- a/python/pyfory/context.py +++ b/python/pyfory/context.py @@ -17,6 +17,8 @@ from __future__ import annotations +import struct + from pyfory.serialization import Config from pyfory.lib import mmh3 from pyfory.meta.metastring import Encoding @@ -37,6 +39,14 @@ FLOAT64_TYPE_ID = TypeId.FLOAT64 BOOL_TYPE_ID = TypeId.BOOL STRING_TYPE_ID = TypeId.STRING +_KNOWN_ROOT_BUDGET_MULTIPLIER = 8 +_KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 +_STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 +_COLLECTION_OBJECT_BYTES = 56 +_MAP_OBJECT_BYTES = 64 +_MAP_ENTRY_BYTES = 32 +_REFERENCE_BYTES = struct.calcsize("P") +_MAX_CONTAINER_MEMORY_BYTES = (1 << 63) - 1 def _mix64(x: int) -> int: @@ -470,6 +480,9 @@ class ReadContext: "field_nullable", "policy", "max_depth", + "max_container_memory_bytes", + "container_memory_limit_bytes", + "remaining_container_memory_bytes", "ref_reader", "meta_string_reader", "meta_share_context", @@ -490,6 +503,9 @@ def __init__(self, config: Config, type_resolver): self.field_nullable = config.field_nullable self.policy = config.policy self.max_depth = config.max_depth + self.max_container_memory_bytes = config.max_container_memory_bytes + self.container_memory_limit_bytes = 0 + self.remaining_container_memory_bytes = 0 self.ref_reader = MapRefReader() if self.track_ref else NoRefReader() self.meta_string_reader = MetaStringReader(type_resolver.shared_registry) self.meta_share_context = MetaShareReadContext() if config.scoped_meta_share_enabled else None @@ -520,11 +536,26 @@ def prepare( buffers=None, unsupported_objects=None, peer_out_of_band_enabled=False, + root_input_bytes=None, ): + if self.max_container_memory_bytes > 0: + limit = self.max_container_memory_bytes + elif buffer.has_input_stream(): + limit = _STREAM_ROOT_BUDGET_BYTES + else: + if root_input_bytes is None: + root_input_bytes = buffer.size() - buffer.get_reader_index() + if root_input_bytes < 0: + raise ValueError("root input byte count is negative") + if root_input_bytes > (_MAX_CONTAINER_MEMORY_BYTES - _KNOWN_ROOT_BUDGET_SLACK_BYTES) // _KNOWN_ROOT_BUDGET_MULTIPLIER: + raise ValueError("max_container_memory_bytes auto budget overflow") + limit = root_input_bytes * _KNOWN_ROOT_BUDGET_MULTIPLIER + _KNOWN_ROOT_BUDGET_SLACK_BYTES self.buffer = buffer self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled + self.container_memory_limit_bytes = limit + self.remaining_container_memory_bytes = limit self.depth = 0 def reset(self): @@ -538,8 +569,40 @@ def reset(self): self.buffers = None self.unsupported_objects = None self.peer_out_of_band_enabled = False + self.container_memory_limit_bytes = 0 + self.remaining_container_memory_bytes = 0 self.depth = 0 + def reserve_container_memory(self, num_bytes): + if num_bytes < 0: + raise ValueError("Estimated container memory is negative") + if num_bytes > _MAX_CONTAINER_MEMORY_BYTES: + raise ValueError("Estimated container memory overflow") + remaining = self.remaining_container_memory_bytes + if num_bytes > remaining: + used = self.container_memory_limit_bytes - remaining + raise ValueError( + f"Estimated container memory budget exceeded: requested {num_bytes} bytes, " + f"used {used} bytes, limit {self.container_memory_limit_bytes} bytes. " + "Increase Fory(..., max_container_memory_bytes=...) for trusted larger payloads." + ) + self.remaining_container_memory_bytes = remaining - num_bytes + + def reserve_collection_memory(self, num_elements): + if num_elements < 0: + raise ValueError("Container element count is negative") + if num_elements > (_MAX_CONTAINER_MEMORY_BYTES - _COLLECTION_OBJECT_BYTES) // _REFERENCE_BYTES: + raise ValueError("Estimated container memory overflow") + self.reserve_container_memory(_COLLECTION_OBJECT_BYTES + num_elements * _REFERENCE_BYTES) + + def reserve_map_memory(self, num_elements): + if num_elements < 0: + raise ValueError("Map entry count is negative") + bytes_per_entry = _MAP_ENTRY_BYTES + 5 * _REFERENCE_BYTES + if num_elements > (_MAX_CONTAINER_MEMORY_BYTES - _MAP_OBJECT_BYTES) // bytes_per_entry: + raise ValueError("Estimated container memory overflow") + self.reserve_container_memory(_MAP_OBJECT_BYTES + num_elements * bytes_per_entry) + def add_context_object(self, key, obj): self.context_objects[id(key)] = obj diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 899adcaf3c..2e4fede422 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -113,6 +113,8 @@ cdef class Config: max_type_meta_bytes: Maximum accepted body size in one received TypeDef. max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. + max_container_memory_bytes: Maximum estimated container-owned memory per root + deserialization. -1 means auto; positive values are explicit byte limits. field_nullable: Treats struct/dataclass fields as nullable by default. policy: Deserialization policy used for security-sensitive checks. meta_compressor: Optional typedef/meta compressor implementation. @@ -129,6 +131,7 @@ cdef class Config: cdef public int32_t max_type_meta_bytes cdef public int32_t max_schema_versions_per_type cdef public int32_t max_average_schema_versions_per_type + cdef public int64_t max_container_memory_bytes cdef public bint field_nullable cdef public object policy cdef public object meta_compressor @@ -147,6 +150,7 @@ cdef class Config: max_type_meta_bytes, max_schema_versions_per_type, max_average_schema_versions_per_type, + max_container_memory_bytes, field_nullable, policy, meta_compressor, @@ -166,6 +170,8 @@ cdef class Config: max_type_meta_bytes: Maximum accepted body size in one received TypeDef. max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. + max_container_memory_bytes: Maximum estimated container-owned memory per root + deserialization. -1 means auto; positive values are explicit byte limits. field_nullable: Treat all struct fields as nullable by default. policy: Deserialization policy implementation. meta_compressor: Optional typedef/meta compressor. @@ -185,10 +191,17 @@ cdef class Config: raise ValueError("max_schema_versions_per_type must be a positive integer") if max_average_schema_versions_per_type <= 0: raise ValueError("max_average_schema_versions_per_type must be a positive integer") + if ( + not isinstance(max_container_memory_bytes, int) + or (max_container_memory_bytes != -1 and max_container_memory_bytes <= 0) + or max_container_memory_bytes > 9223372036854775807 + ): + raise ValueError("max_container_memory_bytes must be -1 or a positive 63-bit integer") self.max_type_fields = max_type_fields self.max_type_meta_bytes = max_type_meta_bytes self.max_schema_versions_per_type = max_schema_versions_per_type self.max_average_schema_versions_per_type = max_average_schema_versions_per_type + self.max_container_memory_bytes = max_container_memory_bytes self.field_nullable = field_nullable self.policy = policy self.meta_compressor = meta_compressor @@ -829,6 +842,7 @@ cdef class Fory: cdef public bint compatible cdef public bint field_nullable cdef public int32_t max_depth + cdef public int64_t max_container_memory_bytes cdef public object policy cdef public Config config cdef public TypeResolver type_resolver @@ -847,6 +861,7 @@ cdef class Fory: max_type_meta_bytes=4096, max_schema_versions_per_type=10, max_average_schema_versions_per_type=3, + max_container_memory_bytes=-1, policy=None, field_nullable=False, meta_compressor=None, @@ -865,6 +880,8 @@ cdef class Fory: max_type_meta_bytes: Maximum accepted body size in one received TypeDef. max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. + max_container_memory_bytes: Maximum estimated container-owned memory per root + deserialization. -1 means auto; positive values are explicit byte limits. policy: Optional deserialization policy implementation. field_nullable: Treat struct fields as nullable by default. meta_compressor: Optional typedef/meta compressor implementation. @@ -882,6 +899,13 @@ cdef class Fory: self.compatible = compatible self.field_nullable = field_nullable self.max_depth = max_depth + if ( + not isinstance(max_container_memory_bytes, int) + or (max_container_memory_bytes != -1 and max_container_memory_bytes <= 0) + or max_container_memory_bytes > 9223372036854775807 + ): + raise ValueError("max_container_memory_bytes must be -1 or a positive 63-bit integer") + self.max_container_memory_bytes = max_container_memory_bytes self.config = Config( xlang=xlang, track_ref=ref, @@ -894,6 +918,7 @@ cdef class Fory: max_type_meta_bytes=max_type_meta_bytes, max_schema_versions_per_type=max_schema_versions_per_type, max_average_schema_versions_per_type=max_average_schema_versions_per_type, + max_container_memory_bytes=max_container_memory_bytes, field_nullable=field_nullable, policy=self.policy, meta_compressor=meta_compressor, @@ -1051,6 +1076,8 @@ cdef class Fory: cdef int32_t reader_index cdef uint8_t bitmap cdef bint peer_out_of_band_enabled + cdef int64_t root_input_bytes + cdef int64_t container_memory_limit if isinstance(buffer, bytes): buffer = Buffer(buffer) read_buffer = buffer @@ -1066,6 +1093,13 @@ cdef class Fory: raise ValueError("Out-of-band buffers are required by the root header") if not peer_out_of_band_enabled and buffers is not None: raise ValueError("Out-of-band buffers were provided for an in-band root payload") + if self.max_container_memory_bytes > 0: + container_memory_limit = self.max_container_memory_bytes + elif read_buffer.has_input_stream(): + container_memory_limit = _STREAM_ROOT_BUDGET_BYTES + else: + root_input_bytes = read_buffer.size() - reader_index + container_memory_limit = root_input_bytes * _KNOWN_ROOT_BUDGET_MULTIPLIER + _KNOWN_ROOT_BUDGET_SLACK_BYTES # Keep the root context setup inline. Top-level deserialize is a hot path, # so it should not pay an extra method call just to bind the active buffer. read_context.buffer = read_buffer @@ -1075,6 +1109,8 @@ cdef class Fory: iter(unsupported_objects) if unsupported_objects is not None else None ) read_context.peer_out_of_band_enabled = peer_out_of_band_enabled + read_context.container_memory_limit_bytes = container_memory_limit + read_context.remaining_container_memory_bytes = container_memory_limit read_context.depth = 0 return read_context.read_ref() diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index d3e43de30f..8ed4aa2255 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -933,6 +933,7 @@ def read(self, read_context): if dtype.kind == "O": length = read_context.read_varint32() _check_non_negative_size(length, "ndarray object") + read_context.reserve_collection_memory(length) read_context.check_readable_bytes(length) items = [read_context.read_ref() for _ in range(length)] return np.array(items, dtype=object) diff --git a/python/pyfory/tests/test_container_memory_budget.py b/python/pyfory/tests/test_container_memory_budget.py new file mode 100644 index 0000000000..09069d412b --- /dev/null +++ b/python/pyfory/tests/test_container_memory_budget.py @@ -0,0 +1,220 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import array +import struct + +import pytest + +import pyfory +from pyfory.serialization import Buffer +from pyfory.serializer import ListSerializer + +try: + import numpy as np +except ImportError: + np = None + + +KNOWN_ROOT_BUDGET_MULTIPLIER = 8 +KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 +STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 +COLLECTION_OBJECT_BYTES = 56 +MAP_OBJECT_BYTES = 64 +MAP_ENTRY_BYTES = 32 +REFERENCE_BYTES = struct.calcsize("P") +MAX_CONTAINER_MEMORY_BYTES = (1 << 63) - 1 + + +class OneByteStream: + def __init__(self, data: bytes): + self._data = data + self._offset = 0 + + def read(self, size=-1): + if self._offset >= len(self._data): + return b"" + if size < 0: + size = len(self._data) - self._offset + if size == 0: + return b"" + read_size = min(1, size, len(self._data) - self._offset) + start = self._offset + self._offset += read_size + return self._data[start : start + read_size] + + def readinto(self, buffer): + if self._offset >= len(self._data): + return 0 + view = memoryview(buffer).cast("B") + if len(view) == 0: + return 0 + read_size = min(1, len(view), len(self._data) - self._offset) + start = self._offset + self._offset += read_size + view[:read_size] = self._data[start : start + read_size] + return read_size + + def recv_into(self, buffer, size=-1): + if self._offset >= len(self._data): + return 0 + view = memoryview(buffer).cast("B") + if size < 0 or size > len(view): + size = len(view) + if size == 0: + return 0 + read_size = min(1, size, len(self._data) - self._offset) + start = self._offset + self._offset += read_size + view[:read_size] = self._data[start : start + read_size] + return read_size + + +def collection_memory(num_elements): + return COLLECTION_OBJECT_BYTES + num_elements * REFERENCE_BYTES + + +def map_memory(num_entries): + return MAP_OBJECT_BYTES + num_entries * (MAP_ENTRY_BYTES + 5 * REFERENCE_BYTES) + + +def new_fory(limit=-1, *, xlang=True): + return pyfory.Fory( + xlang=xlang, + ref=True, + strict=False, + compatible=xlang, + max_container_memory_bytes=limit, + ) + + +def expect_budget(value, budget, *, xlang=True): + writer = new_fory(xlang=xlang) + data = writer.serialize(value) + with pytest.raises(ValueError, match="Estimated container memory budget exceeded"): + new_fory(budget - 1, xlang=xlang).deserialize(data) + return new_fory(budget, xlang=xlang).deserialize(data) + + +def varuint_payload(value): + buffer = Buffer.allocate(16) + buffer.write_var_uint32(value) + return buffer.to_bytes(0, buffer.get_writer_index()) + + +def test_known_length_auto_budget(): + fory = new_fory(xlang=False) + root_input_bytes = 17 + try: + fory.read_context.prepare(Buffer(b"x" * root_input_bytes), root_input_bytes=root_input_bytes) + expected = root_input_bytes * KNOWN_ROOT_BUDGET_MULTIPLIER + KNOWN_ROOT_BUDGET_SLACK_BYTES + assert fory.read_context.container_memory_limit_bytes == expected + fory.read_context.reserve_container_memory(expected) + with pytest.raises(ValueError, match="Estimated container memory budget exceeded"): + fory.read_context.reserve_container_memory(1) + finally: + fory.reset_read() + + +def test_stream_auto_budget(): + fory = new_fory(xlang=False) + try: + buffer = Buffer.from_stream(OneByteStream(b"streamed")) + fory.read_context.prepare(buffer, root_input_bytes=1) + assert fory.read_context.container_memory_limit_bytes == STREAM_ROOT_BUDGET_BYTES + finally: + fory.reset_read() + + +def test_explicit_config_overrides_auto(): + value = [1] + budget = collection_memory(1) + assert expect_budget(value, budget) == value + + +def test_nested_empty_containers_charge_fixed_cost(): + value = [[]] + budget = collection_memory(1) + collection_memory(0) + assert expect_budget(value, budget) == value + + +def test_sibling_nested_containers_are_cumulative(): + value = [[], [], []] + budget = collection_memory(3) + 3 * collection_memory(0) + assert expect_budget(value, budget) == value + + +def test_map_entry_budget_and_overflow(): + value = {"a": 1} + assert expect_budget(value, map_memory(1)) == value + + fory = new_fory(xlang=False) + try: + fory.read_context.prepare(Buffer(b""), root_input_bytes=0) + max_map_entries = (MAX_CONTAINER_MEMORY_BYTES - MAP_OBJECT_BYTES) // (MAP_ENTRY_BYTES + 5 * REFERENCE_BYTES) + with pytest.raises(ValueError, match="Estimated container memory overflow"): + fory.read_context.reserve_map_memory(max_map_entries + 1) + finally: + fory.reset_read() + + +def test_object_reference_array_budget(): + value = (1, 2, 3) + assert expect_budget(value, collection_memory(3), xlang=False) == value + + +def test_object_ndarray_budget(): + if np is None: + pytest.skip("numpy is not installed") + value = np.array([1, 2, 3], dtype=object) + restored = expect_budget(value, collection_memory(3), xlang=False) + np.testing.assert_array_equal(restored, value) + + +def test_string_binary_and_dense_arrays_skip_budget(): + values = [ + "x" * 256, + b"x" * 256, + array.array("i", range(32)), + ] + if np is not None: + values.append(np.array(list(range(32)), dtype=np.int32)) + for value in values: + fory = new_fory(1, xlang=False) + restored = fory.deserialize(fory.serialize(value)) + if np is not None and isinstance(value, np.ndarray): + np.testing.assert_array_equal(restored, value) + else: + assert restored == value + + +def test_declared_large_list_still_needs_bytes(): + fory = new_fory(10_000_000, xlang=False) + serializer = ListSerializer(fory.type_resolver, list) + try: + fory.read_context.prepare(Buffer(varuint_payload(1000)), root_input_bytes=1) + with pytest.raises(Exception) as exc_info: + serializer.read(fory.read_context) + assert "Estimated container memory" not in str(exc_info.value) + finally: + fory.reset_read() + + +@pytest.mark.parametrize("limit", [0, -2, 1 << 63]) +def test_invalid_config(limit): + with pytest.raises(ValueError, match="max_container_memory_bytes"): + new_fory(limit) diff --git a/rust/fory-core/src/config.rs b/rust/fory-core/src/config.rs index 167d67e72b..f5d003cd0f 100644 --- a/rust/fory-core/src/config.rs +++ b/rust/fory-core/src/config.rs @@ -40,6 +40,9 @@ pub struct Config { /// When enabled, shared references and circular references are tracked /// and preserved during serialization/deserialization. pub track_ref: bool, + /// Maximum estimated container-owned memory accepted during one root deserialization. + /// `-1` selects the automatic input-shaped limit. + pub max_container_memory_bytes: i64, /// Maximum accepted field count in one received struct TypeMeta. pub max_type_fields: u32, /// Maximum accepted body size in one received TypeMeta. @@ -61,6 +64,7 @@ impl Default for Config { max_dyn_depth: 5, check_struct_version: false, track_ref: false, + max_container_memory_bytes: -1, max_type_fields: 512, max_type_meta_bytes: 4096, max_schema_versions_per_type: 10, @@ -123,6 +127,12 @@ impl Config { self.track_ref } + /// Get maximum estimated container-owned memory per root deserialization. + #[inline(always)] + pub fn max_container_memory_bytes(&self) -> i64 { + self.max_container_memory_bytes + } + /// Get maximum accepted field count in one received struct TypeMeta. #[inline(always)] pub fn max_type_fields(&self) -> usize { diff --git a/rust/fory-core/src/context.rs b/rust/fory-core/src/context.rs index 260f94ea4c..f36e150d2f 100644 --- a/rust/fory-core/src/context.rs +++ b/rust/fory-core/src/context.rs @@ -31,6 +31,13 @@ use crate::type_id as types; use crate::TypeId; use std::rc::Rc; +const KNOWN_ROOT_BUDGET_MULTIPLIER: usize = 8; +const KNOWN_ROOT_BUDGET_SLACK_BYTES: usize = 64 * 1024; +const VEC_OBJECT_BYTES: usize = mem::size_of::>(); +const MAP_ENTRY_OVERHEAD_BYTES: usize = 16; +const REFERENCE_SLOT_BYTES: usize = mem::size_of::(); +const MAX_CONTAINER_LEN: usize = u32::MAX as usize; + /// Thread-local context cache with fast path for single Fory instance. /// Uses (cached_id, context) for O(1) access when using same Fory instance repeatedly. /// Falls back to HashMap for multiple Fory instances per thread. @@ -359,6 +366,9 @@ pub struct ReadContext<'a> { max_dyn_depth: u32, check_struct_version: bool, check_string_read: bool, + max_container_memory_bytes: i64, + container_memory_limit_bytes: usize, + remaining_container_memory_bytes: usize, // Context-specific fields pub reader: Reader<'a>, @@ -388,6 +398,9 @@ impl<'a> ReadContext<'a> { max_dyn_depth: config.max_dyn_depth, check_struct_version: config.check_struct_version, check_string_read: config.check_string_read, + max_container_memory_bytes: config.max_container_memory_bytes, + container_memory_limit_bytes: 0, + remaining_container_memory_bytes: 0, reader: Reader::default(), meta_resolver: MetaReaderResolver::default(), meta_string_resolver: MetaStringReaderResolver::default(), @@ -443,6 +456,112 @@ impl<'a> ReadContext<'a> { self.reader = reader; } + #[inline(always)] + pub(crate) fn init_container_memory_budget( + &mut self, + root_input_bytes: usize, + ) -> Result<(), Error> { + let limit = if self.max_container_memory_bytes > 0 { + usize::try_from(self.max_container_memory_bytes).map_err(|_| { + container_memory_error("max_container_memory_bytes does not fit usize") + })? + } else { + if root_input_bytes + > (usize::MAX - KNOWN_ROOT_BUDGET_SLACK_BYTES) / KNOWN_ROOT_BUDGET_MULTIPLIER + { + return Err(container_memory_error( + "root input size overflows automatic container memory budget", + )); + } + root_input_bytes * KNOWN_ROOT_BUDGET_MULTIPLIER + KNOWN_ROOT_BUDGET_SLACK_BYTES + }; + self.container_memory_limit_bytes = limit; + self.remaining_container_memory_bytes = limit; + Ok(()) + } + + #[inline(always)] + pub(crate) fn reserve_vec_memory(&mut self, len: u32) -> Result { + let len = len as usize; + self.reserve_counted_memory(len, VEC_OBJECT_BYTES, mem::size_of::())?; + Ok(len) + } + + #[inline(always)] + pub(crate) fn reserve_collection_memory(&mut self, len: u32) -> Result { + let len = len as usize; + let elem_size = mem::size_of::(); + if elem_size > usize::MAX - REFERENCE_SLOT_BYTES { + return Err(container_memory_overflow(len, elem_size)); + } + let elem_bytes = elem_size + REFERENCE_SLOT_BYTES; + self.reserve_counted_memory(len, mem::size_of::(), elem_bytes)?; + Ok(len) + } + + #[inline(always)] + pub(crate) fn reserve_map_memory(&mut self, len: u32) -> Result { + let len = len as usize; + let key_size = mem::size_of::(); + let value_size = mem::size_of::(); + let overhead = MAP_ENTRY_OVERHEAD_BYTES + REFERENCE_SLOT_BYTES * 3; + if key_size > usize::MAX - value_size || key_size + value_size > usize::MAX - overhead { + return Err(container_memory_overflow(len, key_size)); + } + let elem_bytes = key_size + value_size + overhead; + self.reserve_counted_memory(len, mem::size_of::(), elem_bytes)?; + Ok(len) + } + + #[inline(always)] + pub(crate) fn reserve_container_bytes(&mut self, bytes: usize) -> Result<(), Error> { + let remaining = self.remaining_container_memory_bytes; + if bytes > remaining { + return Err(container_memory_exceeded( + bytes, + remaining, + self.container_memory_limit_bytes, + )); + } + self.remaining_container_memory_bytes = remaining - bytes; + Ok(()) + } + + #[inline(always)] + fn reserve_counted_memory( + &mut self, + len: usize, + fixed_bytes: usize, + elem_bytes: usize, + ) -> Result<(), Error> { + if len == 0 { + return self.reserve_container_bytes(fixed_bytes); + } + if elem_bytes <= (usize::MAX - fixed_bytes) / MAX_CONTAINER_LEN { + return self.reserve_container_bytes(len * elem_bytes + fixed_bytes); + } + self.reserve_counted_memory_checked(len, fixed_bytes, elem_bytes) + } + + #[cold] + #[inline(never)] + fn reserve_counted_memory_checked( + &mut self, + len: usize, + fixed_bytes: usize, + elem_bytes: usize, + ) -> Result<(), Error> { + let elem_total = match len.checked_mul(elem_bytes) { + Some(bytes) => bytes, + None => return Err(container_memory_overflow(len, elem_bytes)), + }; + let bytes = match elem_total.checked_add(fixed_bytes) { + Some(bytes) => bytes, + None => return Err(container_memory_overflow(len, elem_bytes)), + }; + self.reserve_container_bytes(bytes) + } + #[inline(always)] pub fn detach_reader(&mut self) -> Reader<'_> { mem::take(&mut self.reader) @@ -552,3 +671,27 @@ impl<'a> ReadContext<'a> { self.current_depth = 0; } } + +#[cold] +#[inline(never)] +fn container_memory_error(message: &'static str) -> Error { + Error::invalid_data(message) +} + +#[cold] +#[inline(never)] +fn container_memory_overflow(len: usize, elem_bytes: usize) -> Error { + Error::invalid_data(format!( + "container memory estimate overflows: length={} elementBytes={}", + len, elem_bytes + )) +} + +#[cold] +#[inline(never)] +fn container_memory_exceeded(bytes: usize, remaining: usize, limit: usize) -> Error { + Error::invalid_data(format!( + "estimated container memory request {} bytes exceeds max_container_memory_bytes remaining budget {} bytes out of effective limit {} bytes", + bytes, remaining, limit + )) +} diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index 4b6c98419a..8eb9d3794f 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -261,6 +261,18 @@ impl ForyBuilder { self } + /// Sets the maximum estimated container-owned memory accepted during one root deserialization. + /// + /// Use `-1` for the automatic input-shaped limit. Positive values are explicit byte limits. + pub fn max_container_memory_bytes(mut self, max_bytes: i64) -> Self { + assert!( + max_bytes == -1 || max_bytes > 0, + "max_container_memory_bytes must be positive or -1 for auto" + ); + self.config.max_container_memory_bytes = max_bytes; + self + } + /// Sets the maximum depth for nested dynamic object serialization. /// /// # Arguments @@ -988,7 +1000,13 @@ impl Fory { self.with_read_context(|context| { let outlive_buffer = unsafe { mem::transmute::<&[u8], &[u8]>(bf) }; context.attach_reader(Reader::new(outlive_buffer)); - let result = self.deserialize_with_context(context); + let result = match context.init_container_memory_budget(bf.len()) { + Ok(()) => self.deserialize_with_context(context), + Err(err) => { + context.reset(); + Err(err) + } + }; context.detach_reader(); result }) @@ -1050,8 +1068,15 @@ impl Fory { let outlive_buffer = unsafe { mem::transmute::<&[u8], &[u8]>(reader.bf) }; let mut new_reader = Reader::new(outlive_buffer); new_reader.set_cursor(reader.cursor); + let root_input_bytes = reader.bf.len().saturating_sub(reader.cursor); context.attach_reader(new_reader); - let result = self.deserialize_with_context(context); + let result = match context.init_container_memory_budget(root_input_bytes) { + Ok(()) => self.deserialize_with_context(context), + Err(err) => { + context.reset(); + Err(err) + } + }; let end = context.detach_reader().get_cursor(); reader.set_cursor(end); result diff --git a/rust/fory-core/src/serializer/codec.rs b/rust/fory-core/src/serializer/codec.rs index 34059103f5..675fc133e9 100644 --- a/rust/fory-core/src/serializer/codec.rs +++ b/rust/fory-core/src/serializer/codec.rs @@ -1700,6 +1700,7 @@ where fn read_data(context: &mut ReadContext) -> Result, Error> { let len = context.reader.read_var_u32()?; + context.reserve_vec_memory::(len)?; if len == 0 { return Ok(Vec::new()); } @@ -1728,6 +1729,7 @@ where remote_field_type: &FieldType, ) -> Result, Error> { let len = context.reader.read_var_u32()?; + context.reserve_vec_memory::(len)?; if len == 0 { return Ok(Vec::new()); } @@ -2270,6 +2272,7 @@ where fn read_data(context: &mut ReadContext) -> Result, Error> { let len = context.reader.read_var_u32()?; + context.reserve_map_memory::, K, V>(len)?; if len == 0 { return Ok(HashMap::new()); } @@ -2289,6 +2292,7 @@ where remote_field_type: &FieldType, ) -> Result, Error> { let len = context.reader.read_var_u32()?; + let capacity = context.reserve_map_memory::, K, V>(len)?; if len == 0 { return Ok(HashMap::new()); } @@ -2299,7 +2303,8 @@ where { return read_map_dynamic::(context, len, remote_field_type); } - let mut map = HashMap::with_capacity(check_map_len(context, len)?); + context.reader.check_bound(capacity)?; + let mut map = HashMap::with_capacity(capacity); let mut len_counter = 0; while len_counter < len { let header = context.reader.read_u8()?; diff --git a/rust/fory-core/src/serializer/collection.rs b/rust/fory-core/src/serializer/collection.rs index ee16166bb4..b2dd1950f9 100644 --- a/rust/fory-core/src/serializer/collection.rs +++ b/rust/fory-core/src/serializer/collection.rs @@ -239,6 +239,7 @@ where C: FromIterator, { let len = context.reader.read_var_u32()?; + let len_usize = context.reserve_collection_memory::(len)?; if len == 0 { return Ok(C::from_iter(std::iter::empty())); } @@ -257,7 +258,7 @@ where (header & IS_SAME_TYPE) != 0, Error::type_error("Type inconsistent, target type is not polymorphic") ); - let _ = check_collection_len(context, len)?; + context.reader.check_bound(len_usize)?; if !has_null { (0..len) .map(|_| T::fory_read_data(context)) @@ -281,6 +282,7 @@ where T: Serializer + ForyDefault, { let len = context.reader.read_var_u32()?; + let len_usize = context.reserve_vec_memory::(len)?; if len == 0 { return Ok(Vec::new()); } @@ -297,7 +299,8 @@ where (header & IS_SAME_TYPE) != 0, Error::type_error("Type inconsistent, target type is not polymorphic") ); - let mut vec = Vec::with_capacity(check_collection_len(context, len)?); + context.reader.check_bound(len_usize)?; + let mut vec = Vec::with_capacity(len_usize); if !has_null { for _ in 0..len { vec.push(T::fory_read_data(context)?); @@ -343,7 +346,8 @@ where } else { T::fory_get_type_info(context.get_type_resolver())? }; - let mut vec = Vec::with_capacity(check_collection_len(context, len)?); + let len_usize = check_collection_len(context, len)?; + let mut vec = Vec::with_capacity(len_usize); if elem_ref_mode == RefMode::None { for _ in 0..len { vec.push(T::fory_read_with_type_info( @@ -363,7 +367,8 @@ where } Ok(vec) } else { - let mut vec = Vec::with_capacity(check_collection_len(context, len)?); + let len_usize = check_collection_len(context, len)?; + let mut vec = Vec::with_capacity(len_usize); for _ in 0..len { vec.push(T::fory_read(context, elem_ref_mode, true)?); } @@ -724,6 +729,7 @@ where { let element_type = generic_field_type(remote_field_type, 0, "list")?; let len = context.reader.read_var_u32()?; + let len_usize = context.reserve_vec_memory::(len)?; if len == 0 { return Ok(Vec::new()); } @@ -748,8 +754,8 @@ where "array-compatible list must declare element type", )); } - context.reader.check_bound(len as usize)?; - let mut vec = Vec::with_capacity(len as usize); + context.reader.check_bound(len_usize)?; + let mut vec = Vec::with_capacity(len_usize); for _ in 0..len { vec.push(T::read_list_array_element(context, element_type.type_id)?); } diff --git a/rust/fory-core/src/serializer/map.rs b/rust/fory-core/src/serializer/map.rs index 3d0dc094e7..158e020edc 100644 --- a/rust/fory-core/src/serializer/map.rs +++ b/rust/fory-core/src/serializer/map.rs @@ -35,12 +35,6 @@ const TRACKING_VALUE_REF: u8 = 0b1000; pub const VALUE_NULL: u8 = 0b10000; pub const DECL_VALUE_TYPE: u8 = 0b100000; -fn check_map_len(context: &ReadContext, len: u32) -> Result { - let len = len as usize; - context.reader.check_bound(len)?; - Ok(len) -} - fn write_chunk_size(context: &mut WriteContext, header_offset: usize, size: u8) { context.writer.set_bytes(header_offset + 1, &[size]); } @@ -559,10 +553,11 @@ impl Result { let len = context.reader.read_var_u32()?; + let capacity = context.reserve_map_memory::, K, V>(len)?; if len == 0 { return Ok(HashMap::new()); } - let capacity = check_map_len(context, len)?; + context.reader.check_bound(capacity)?; if K::fory_is_polymorphic() || K::fory_is_shared_ref() || V::fory_is_polymorphic() @@ -711,10 +706,11 @@ impl Result { let len = context.reader.read_var_u32()?; + let len_usize = context.reserve_map_memory::, K, V>(len)?; if len == 0 { return Ok(BTreeMap::new()); } - let _ = check_map_len(context, len)?; + context.reader.check_bound(len_usize)?; let mut map = BTreeMap::::new(); if K::fory_is_polymorphic() || K::fory_is_shared_ref() diff --git a/rust/tests/tests/mod.rs b/rust/tests/tests/mod.rs index c66d727c8a..74f62c87d3 100644 --- a/rust/tests/tests/mod.rs +++ b/rust/tests/tests/mod.rs @@ -18,6 +18,7 @@ mod compatible; mod test_any; mod test_collection; +mod test_container_memory_budget; mod test_field_meta; mod test_max_dyn_depth; mod test_tuple; diff --git a/rust/tests/tests/test_container_memory_budget.rs b/rust/tests/tests/test_container_memory_budget.rs new file mode 100644 index 0000000000..29f70d10bf --- /dev/null +++ b/rust/tests/tests/test_container_memory_budget.rs @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use fory_core::{Error, Fory, Reader}; +use fory_derive::ForyStruct; +use std::collections::HashMap; +use std::panic; + +#[derive(ForyStruct, Debug, PartialEq)] +struct BudgetSiblings { + first: Vec, + second: Vec, +} + +#[derive(ForyStruct, Debug, PartialEq)] +struct BudgetItem { + left: u64, + right: u64, +} + +#[derive(ForyStruct, Debug)] +struct ListWireInts { + values: Vec>, +} + +#[derive(ForyStruct, Debug, PartialEq)] +struct DenseWireInts { + values: Vec, +} + +fn fory_with_budget(max_container_memory_bytes: i64) -> Fory { + let mut fory = Fory::builder() + .xlang(false) + .compatible(false) + .max_container_memory_bytes(max_container_memory_bytes) + .build(); + fory.register_by_name::("BudgetSiblings") + .unwrap(); + fory.register_by_name::("BudgetItem").unwrap(); + fory +} + +fn compatible_fory(max_container_memory_bytes: i64) -> Fory +where + T: fory_core::Serializer + fory_core::StructSerializer + fory_core::ForyDefault, +{ + let mut fory = Fory::builder() + .xlang(false) + .compatible(true) + .max_container_memory_bytes(max_container_memory_bytes) + .build(); + fory.register::(88_001).unwrap(); + fory +} + +fn compact_empty_lists(count: usize) -> Vec> { + (0..count).map(|_| Vec::new()).collect() +} + +fn assert_budget_error(err: Error, effective_limit: usize) { + let message = err.to_string(); + assert!( + message.contains("estimated container memory request"), + "{message}" + ); + assert!( + message.contains(&format!("effective limit {effective_limit}")), + "{message}" + ); +} + +#[test] +fn config_validation() { + assert!(panic::catch_unwind(|| Fory::builder().max_container_memory_bytes(0)).is_err()); + assert!(panic::catch_unwind(|| Fory::builder().max_container_memory_bytes(-2)).is_err()); + let _ = Fory::builder().max_container_memory_bytes(-1).build(); + let _ = Fory::builder().max_container_memory_bytes(1).build(); +} + +#[test] +fn known_auto_budget() { + let value = compact_empty_lists(3000); + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + let auto_limit = bytes.len() * 8 + 64 * 1024; + + let err = writer.deserialize::>>(&bytes).unwrap_err(); + assert_budget_error(err, auto_limit); + + let explicit = fory_with_budget(auto_limit as i64); + let err = explicit + .deserialize::>>(&bytes) + .unwrap_err(); + assert_budget_error(err, auto_limit); +} + +#[test] +fn reader_known_auto_budget() { + let value = compact_empty_lists(3000); + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + let auto_limit = bytes.len() * 8 + 64 * 1024; + + let mut reader = Reader::new(&bytes); + let err = writer + .deserialize_from::>>(&mut reader) + .unwrap_err(); + assert_budget_error(err, auto_limit); +} + +#[test] +fn explicit_override() { + let value = compact_empty_lists(3000); + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + assert!(writer.deserialize::>>(&bytes).is_err()); + + let vec_bytes = std::mem::size_of::>(); + let estimate = std::mem::size_of::>>() + value.len() * vec_bytes * 2; + let explicit = fory_with_budget(estimate as i64); + let decoded: Vec> = explicit.deserialize(&bytes).unwrap(); + assert_eq!(decoded, value); +} + +#[test] +fn empty_container_cost() { + let value: Vec = Vec::new(); + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + let fixed = std::mem::size_of::>() as i64; + + let limited = fory_with_budget(fixed - 1); + assert!(limited.deserialize::>(&bytes).is_err()); +} + +#[test] +fn sibling_cumulative_budget() { + let value = BudgetSiblings { + first: Vec::new(), + second: Vec::new(), + }; + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + let one_vec = std::mem::size_of::>() as i64; + + let limited = fory_with_budget(one_vec); + assert!(limited.deserialize::(&bytes).is_err()); +} + +#[test] +fn map_budget() { + let value: HashMap = HashMap::new(); + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + let fixed = std::mem::size_of::>() as i64; + + let limited = fory_with_budget(fixed - 1); + assert!(limited.deserialize::>(&bytes).is_err()); +} + +#[test] +fn inline_value_vec_budget() { + let value = (0..16) + .map(|i| BudgetItem { + left: i, + right: i + 1, + }) + .collect::>(); + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + let under_inline = + std::mem::size_of::>() + value.len() * std::mem::size_of::(); + + let limited = fory_with_budget(under_inline as i64); + assert!(limited.deserialize::>(&bytes).is_err()); +} + +#[test] +fn compatible_list_array_budget() { + let value = ListWireInts { + values: (0..64).map(Some).collect(), + }; + let writer = compatible_fory::(-1); + let bytes = writer.serialize(&value).unwrap(); + + let limited = compatible_fory::(std::mem::size_of::>() as i64); + assert!(limited.deserialize::(&bytes).is_err()); + + let enough = compatible_fory::(i64::MAX); + let decoded = enough.deserialize::(&bytes).unwrap(); + assert_eq!( + decoded, + DenseWireInts { + values: (0..64).collect() + } + ); +} + +#[test] +fn dense_paths_skipped() { + let fory = fory_with_budget(1); + + let string_bytes = fory_with_budget(-1) + .serialize(&"hello".to_string()) + .unwrap(); + let decoded: String = fory.deserialize(&string_bytes).unwrap(); + assert_eq!(decoded, "hello"); + + let binary = vec![1_u8, 2, 3, 4]; + let binary_bytes = fory_with_budget(-1).serialize(&binary).unwrap(); + let decoded: Vec = fory.deserialize(&binary_bytes).unwrap(); + assert_eq!(decoded, binary); + + let ints = vec![1_i32, 2, 3, 4]; + let int_bytes = fory_with_budget(-1).serialize(&ints).unwrap(); + let decoded: Vec = fory.deserialize(&int_bytes).unwrap(); + assert_eq!(decoded, ints); +} + +#[test] +fn byte_check_preserved() { + let writer = fory_with_budget(-1); + let mut bytes = writer.serialize(&Vec::::new()).unwrap(); + let last = bytes.len() - 1; + bytes[last] = 64; + + let reader = fory_with_budget(i64::MAX); + let err = reader.deserialize::>(&bytes).unwrap_err(); + assert!(matches!(err, Error::BufferOutOfBound(..)), "{err}"); +} diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala index 066e24c629..8b688bcb37 100644 --- a/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala +++ b/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala @@ -53,15 +53,10 @@ abstract class AbstractScalaCollectionSerializer[A, T <: Iterable[A]]( value: T): util.Collection[_] override def newCollection(readContext: ReadContext): util.Collection[_] = { - val buffer = readContext.getBuffer - val numElements = buffer.readVarUInt32() - checkCollectionSize(numElements) + val numElements = readCollectionSize(readContext) setNumElements(numElements) val factory = readContext.readRef().asInstanceOf[Factory[A, T]] val builder = factory.newBuilder - if (numElements != 0) { - buffer.checkReadableBytes(numElements) - } builder.sizeHint(numElements) new JavaCollectionBuilder[A, T](builder) } diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala index 9c21954b7d..3891361615 100644 --- a/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala +++ b/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala @@ -50,15 +50,10 @@ abstract class AbstractScalaMapSerializer[K, V, T](typeResolver: TypeResolver, c def onMapWrite(writeContext: WriteContext, value: T): util.Map[_, _] override def newMap(readContext: ReadContext): util.Map[_, _] = { - val buffer = readContext.getBuffer - val numElements = buffer.readVarUInt32() - checkMapSize(numElements) + val numElements = readMapSize(readContext) setNumElements(numElements) val factory = readContext.readRef().asInstanceOf[Factory[(K, V), T]] val builder = factory.newBuilder - if (numElements != 0) { - buffer.checkReadableBytes(numElements) - } builder.sizeHint(numElements) new MapBuilder[K, V, T](builder) } diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala index 9eeab286d2..9439f3493e 100644 --- a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala +++ b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala @@ -43,12 +43,8 @@ abstract class AbstractScalaXlangCollectionSerializer[A, T <: scala.collection.I } override def newCollection(readContext: ReadContext): util.Collection[_] = { - val buffer = readContext.getBuffer - val numElements = readCollectionSize(buffer) + val numElements = readCollectionSize(readContext) setNumElements(numElements) - if (numElements != 0) { - buffer.checkReadableBytes(numElements) - } val builder = newBuilder(numElements) if (ScalaXlangCollectionShape.hasOptionElement(readContext)) { new XlangOptionCollectionBuilder[A, T](builder) @@ -368,12 +364,8 @@ abstract class AbstractScalaXlangMapSerializer[K, V, T <: scala.collection.Map[K } override def newMap(readContext: ReadContext): util.Map[_, _] = { - val buffer = readContext.getBuffer - val numElements = readMapSize(buffer) + val numElements = readMapSize(readContext) setNumElements(numElements) - if (numElements != 0) { - buffer.checkReadableBytes(numElements) - } val builder = ScalaXlangCollectionShape.mapBuilder[K, V, T](cls, numElements) val optionKey = ScalaXlangCollectionShape.hasOptionKey(readContext) diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala index d1b7e67952..fc386639dd 100644 --- a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala @@ -20,6 +20,7 @@ package org.apache.fory.serializer.scala import org.apache.fory.Fory +import org.apache.fory.exception.InsecureException import org.apache.fory.scala.ForyScala import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec @@ -89,6 +90,37 @@ class CollectionSerializerTest extends AnyWordSpec with Matchers { } } } + + "fory scala container memory budget" should { + def runtime(maxContainerMemoryBytes: Long = -1): Fory = { + val builder = ForyScala.builder() + .withXlang(false) + .withRefTracking(true) + .requireClassRegistration(false) + .suppressClassRegistrationWarnings(false) + .withSerializerFactory(new ScalaSerializerFactory()) + if (maxContainerMemoryBytes > 0) { + builder.withMaxContainerMemoryBytes(maxContainerMemoryBytes) + } + builder.build() + } + + "charge scala collection fixed cost" in { + val writer = runtime() + val reader = runtime(maxContainerMemoryBytes = 23) + intercept[InsecureException] { + reader.deserialize(writer.serialize(List.empty[String])) + } + } + + "charge scala map fixed cost" in { + val writer = runtime() + val reader = runtime(maxContainerMemoryBytes = 47) + intercept[InsecureException] { + reader.deserialize(writer.serialize(Map("k" -> "v"))) + } + } + } } case class CollectionStruct1(list: List[String]) diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala index 6cc4f880a1..b637bc4f6c 100644 --- a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala @@ -20,6 +20,7 @@ package org.apache.fory.serializer.scala import org.apache.fory.Fory +import org.apache.fory.exception.InsecureException import org.apache.fory.scala.ForyScala import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec @@ -120,5 +121,24 @@ class ScalaXlangSerializerTest extends AnyWordSpec with Matchers { copiedCyclic should not be theSameInstanceAs(cyclic) copiedCyclic(0) shouldBe theSameInstanceAs(copiedCyclic) } + + "enforce container memory budget" in { + val writer = fory + val reader = ForyScala.builder() + .withXlang(true) + .withRefTracking(true) + .withRefCopy(true) + .requireClassRegistration(false) + .suppressClassRegistrationWarnings(false) + .withMaxContainerMemoryBytes(23) + .build() + + intercept[InsecureException] { + reader.deserialize(writer.serialize(List.empty[String])) + } + intercept[InsecureException] { + reader.deserialize(writer.serialize(Map("k" -> 1))) + } + } } } diff --git a/swift/Sources/Fory/AnySerializer.swift b/swift/Sources/Fory/AnySerializer.swift index fdd99b38a4..ca31cce2da 100644 --- a/swift/Sources/Fory/AnySerializer.swift +++ b/swift/Sources/Fory/AnySerializer.swift @@ -570,7 +570,11 @@ public func readListOfAny( refMode: refMode, readTypeInfo: readTypeInfo ) - return wrapped?.map { $0.anyValueForCollection() } + guard let wrapped else { + return nil + } + try context.reserveReferenceArrayMemory(count: wrapped.count) + return wrapped.map { $0.anyValueForCollection() } } public func writeMapStringToAny( @@ -604,6 +608,7 @@ public func readMapStringToAny( guard let wrapped else { return nil } + try context.reserveReferenceMapMemory(count: wrapped.count) var map: [String: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { @@ -643,6 +648,7 @@ public func readMapInt32ToAny( guard let wrapped else { return nil } + try context.reserveReferenceMapMemory(count: wrapped.count) var map: [Int32: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { @@ -682,6 +688,7 @@ public func readMapAnyHashableToAny( guard let wrapped else { return nil } + try context.reserveReferenceMapMemory(count: wrapped.count) var map: [AnyHashable: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { @@ -693,8 +700,10 @@ public func readMapAnyHashableToAny( func readDynamicAnyMapValue(context: ReadContext) throws -> Any { let map = try readMapAnyHashableToAny(context: context, refMode: .none) ?? [:] if map.isEmpty { + try context.reserveReferenceMapMemory(count: 0) return [String: Any]() } + try context.reserveReferenceMapMemory(count: map.count) var stringMap: [String: Any] = [:] stringMap.reserveCapacity(map.count) for pair in map { @@ -708,6 +717,7 @@ func readDynamicAnyMapValue(context: ReadContext) throws -> Any { return stringMap } + try context.reserveReferenceMapMemory(count: map.count) var int32Map: [Int32: Any] = [:] int32Map.reserveCapacity(map.count) for pair in map { diff --git a/swift/Sources/Fory/CollectionSerializers.swift b/swift/Sources/Fory/CollectionSerializers.swift index 1be59fb6b4..a7b943a4b6 100644 --- a/swift/Sources/Fory/CollectionSerializers.swift +++ b/swift/Sources/Fory/CollectionSerializers.swift @@ -234,18 +234,35 @@ func writePrimitiveArray(_ value: [Element], context: Write } } -func readPrimitiveArray(_ context: ReadContext) throws -> [Element] { +@inline(__always) +private func preparePrimitiveArray( + _ context: ReadContext, + chargeContainerMemory: Bool, + type: Element.Type, + count: Int, + label: String +) throws { + try context.ensureCollectionLength(count, label: label) + if chargeContainerMemory { + try context.reserveArrayMemory(type, count: count) + } +} + +func readPrimitiveArray( + _ context: ReadContext, + chargeContainerMemory: Bool = false +) throws -> [Element] { let byteSize = Int(try context.buffer.readVarUInt32()) try context.ensureRemainingBytes(byteSize, label: "primitive_array_bytes") if Element.self == UInt8.self { - try context.ensureCollectionLength(byteSize, label: "uint8_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: byteSize, label: "uint8_array") let bytes = try context.buffer.readBytes(count: byteSize) return uncheckedArrayCast(bytes, to: Element.self) } if Element.self == Bool.self { - try context.ensureCollectionLength(byteSize, label: "bool_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: byteSize, label: "bool_array") let out = try readArrayUninitialized(count: byteSize) { destination in for index in 0..(_ context: ReadContext) throws -> [ } if Element.self == Int8.self { - try context.ensureCollectionLength(byteSize, label: "int8_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: byteSize, label: "int8_array") var out = Array(repeating: Int8(0), count: byteSize) try out.withUnsafeMutableBytes { rawBytes in try context.buffer.readBytes(into: rawBytes) @@ -266,7 +283,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == Int16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("int16 array byte size mismatch") } let count = byteSize / 2 - try context.ensureCollectionLength(count, label: "int16_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "int16_array") if hostIsLittleEndian { var out = Array(repeating: Int16(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -285,7 +302,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == Int32.self { if byteSize % 4 != 0 { throw ForyError.invalidData("int32 array byte size mismatch") } let count = byteSize / 4 - try context.ensureCollectionLength(count, label: "int32_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "int32_array") if hostIsLittleEndian { var out = Array(repeating: Int32(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -304,7 +321,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == UInt32.self { if byteSize % 4 != 0 { throw ForyError.invalidData("uint32 array byte size mismatch") } let count = byteSize / 4 - try context.ensureCollectionLength(count, label: "uint32_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "uint32_array") if hostIsLittleEndian { var out = Array(repeating: UInt32(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -323,7 +340,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == Int64.self { if byteSize % 8 != 0 { throw ForyError.invalidData("int64 array byte size mismatch") } let count = byteSize / 8 - try context.ensureCollectionLength(count, label: "int64_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "int64_array") if hostIsLittleEndian { var out = Array(repeating: Int64(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -342,7 +359,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == UInt64.self { if byteSize % 8 != 0 { throw ForyError.invalidData("uint64 array byte size mismatch") } let count = byteSize / 8 - try context.ensureCollectionLength(count, label: "uint64_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "uint64_array") if hostIsLittleEndian { var out = Array(repeating: UInt64(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -361,7 +378,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == UInt16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("uint16 array byte size mismatch") } let count = byteSize / 2 - try context.ensureCollectionLength(count, label: "uint16_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "uint16_array") if hostIsLittleEndian { var out = Array(repeating: UInt16(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -380,7 +397,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if Element.self == Float16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("float16 array byte size mismatch") } let count = byteSize / 2 - try context.ensureCollectionLength(count, label: "float16_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "float16_array") let values = try readArrayUninitialized(count: count) { destination in for index in 0..(_ context: ReadContext) throws -> [ if Element.self == BFloat16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("bfloat16 array byte size mismatch") } let count = byteSize / 2 - try context.ensureCollectionLength(count, label: "bfloat16_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "bfloat16_array") let values = try readArrayUninitialized(count: count) { destination in for index in 0..(_ context: ReadContext) throws -> [ if Element.self == Float.self { if byteSize % 4 != 0 { throw ForyError.invalidData("float32 array byte size mismatch") } let count = byteSize / 4 - try context.ensureCollectionLength(count, label: "float32_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "float32_array") if hostIsLittleEndian { var out = Array(repeating: Float(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -422,7 +439,7 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [ if byteSize % 8 != 0 { throw ForyError.invalidData("float64 array byte size mismatch") } let count = byteSize / 8 - try context.ensureCollectionLength(count, label: "float64_array") + try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "float64_array") if hostIsLittleEndian { var out = Array(repeating: Double(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -532,6 +549,7 @@ extension Array: Serializer where Element: Serializer { let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { + try context.reserveArrayMemory(Element.self, count: length) return [] } @@ -541,6 +559,7 @@ extension Array: Serializer where Element: Serializer { let declared = (header & CollectionHeader.declaredElementType) != 0 let sameType = (header & CollectionHeader.sameType) != 0 if !sameType { + try context.reserveArrayMemory(Element.self, count: length) try context.ensureRemainingBytes(length, label: "array") if trackRef { return try readArrayUninitialized(count: length) { destination in @@ -579,6 +598,7 @@ extension Array: Serializer where Element: Serializer { } let elementTypeInfo = declared ? nil : try Element.foryReadTypeInfo(context) + try context.reserveArrayMemory(Element.self, count: length) try context.ensureRemainingBytes(length, label: "array") return try context.withTypeInfo(elementTypeInfo, for: Element.self) { if trackRef { @@ -637,7 +657,9 @@ extension Set: Serializer where Element: Serializer & Hashable { } public static func foryReadData(_ context: ReadContext) throws -> Set { - Set(try [Element].foryReadData(context)) + let values = try [Element].foryReadData(context) + try context.reserveSetMemory(Element.self, count: values.count) + return Set(values) } } @@ -864,11 +886,13 @@ extension Dictionary: Serializer where Key: Serializer & Hashable, Value: Serial let totalLength = Int(try context.buffer.readVarUInt32()) try context.ensureCollectionLength(totalLength, label: "map") if totalLength == 0 { + try context.reserveMapMemory(key: Key.self, value: Value.self, count: totalLength) return [:] } - var map: [Key: Value] = [:] + try context.reserveMapMemory(key: Key.self, value: Value.self, count: totalLength) try context.ensureRemainingBytes(totalLength, label: "map") + var map: [Key: Value] = [:] map.reserveCapacity(totalLength) let keyDynamicType = Key.staticTypeId == .unknown let valueDynamicType = Value.staticTypeId == .unknown diff --git a/swift/Sources/Fory/FieldCodecs.swift b/swift/Sources/Fory/FieldCodecs.swift index 3db01e7c3c..4eabf460e7 100644 --- a/swift/Sources/Fory/FieldCodecs.swift +++ b/swift/Sources/Fory/FieldCodecs.swift @@ -841,7 +841,9 @@ public enum SetFieldCodec: FieldCodec where ElementCod } public static func readPayload(_ context: ReadContext) throws -> Value { - Set(try readCollectionPayload(context, elementCodec: ElementCodec.self)) + let values = try readCollectionPayload(context, elementCodec: ElementCodec.self) + try context.reserveFieldSetMemory(ElementCodec.self, count: values.count) + return Set(values) } } @@ -960,11 +962,13 @@ where KeyCodec.Value: Hashable { let totalLength = Int(try context.buffer.readVarUInt32()) try context.ensureCollectionLength(totalLength, label: "map") if totalLength == 0 { + try context.reserveFieldMapMemory(key: KeyCodec.self, value: ValueCodec.self, count: totalLength) return [:] } - var map: Value = [:] + try context.reserveFieldMapMemory(key: KeyCodec.self, value: ValueCodec.self, count: totalLength) try context.ensureRemainingBytes(totalLength, label: "map") + var map: Value = [:] map.reserveCapacity(totalLength) var readCount = 0 while readCount < totalLength { @@ -1324,8 +1328,11 @@ private func writeUIntArrayPayload(_ value: [UInt], _ context: WriteContext) { } } -private func readIntArrayPayload(_ context: ReadContext) throws -> [Int] { +private func readIntArrayPayload(_ context: ReadContext, chargeContainerMemory: Bool = false) throws -> [Int] { let count = try readPackedArrayElementCount(context, width: 8, label: "int64_array") + if chargeContainerMemory { + try context.reserveArrayMemory(Int.self, count: count) + } var values: [Int] = [] values.reserveCapacity(count) for _ in 0.. [Int] { return values } -private func readUIntArrayPayload(_ context: ReadContext) throws -> [UInt] { +private func readUIntArrayPayload(_ context: ReadContext, chargeContainerMemory: Bool = false) throws -> [UInt] { let count = try readPackedArrayElementCount(context, width: 8, label: "uint64_array") + if chargeContainerMemory { + try context.reserveArrayMemory(UInt.self, count: count) + } var values: [UInt] = [] values.reserveCapacity(count) for _ in 0..( elementCodec _: ElementCodec.Type ) throws -> [ElementCodec.Value] { if ElementCodec.self == BoolCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Bool], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Bool], to: ElementCodec.Value.self) } if ElementCodec.self == Int8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int8], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Int8], to: ElementCodec.Value.self) } if ElementCodec.self == Int16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Int16], to: ElementCodec.Value.self) } if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int32], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Int32], to: ElementCodec.Value.self) } if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self || ElementCodec.self == Int64TaggedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int64], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Int64], to: ElementCodec.Value.self) } if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self || ElementCodec.self == IntTaggedCodec.self { - return uncheckedPackedArrayCast(try readIntArrayPayload(context), to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readIntArrayPayload(context, chargeContainerMemory: true), to: ElementCodec.Value.self) } if ElementCodec.self == UInt8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt8], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [UInt8], to: ElementCodec.Value.self) } if ElementCodec.self == UInt16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [UInt16], to: ElementCodec.Value.self) } if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt32], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [UInt32], to: ElementCodec.Value.self) } if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self || ElementCodec.self == UInt64TaggedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt64], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [UInt64], to: ElementCodec.Value.self) } if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self || ElementCodec.self == UIntTaggedCodec.self { - return uncheckedPackedArrayCast(try readUIntArrayPayload(context), to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readUIntArrayPayload(context, chargeContainerMemory: true), to: ElementCodec.Value.self) } if ElementCodec.self == Float16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Float16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Float16], to: ElementCodec.Value.self) } if ElementCodec.self == BFloat16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [BFloat16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [BFloat16], to: ElementCodec.Value.self) } if ElementCodec.self == FloatCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Float], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Float], to: ElementCodec.Value.self) } if ElementCodec.self == DoubleCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Double], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Double], to: ElementCodec.Value.self) } throw ForyError.invalidData("unsupported compatible array-to-list field element codec \(ElementCodec.self)") } @@ -1592,6 +1602,7 @@ private func readCollectionPayload( let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { + try context.reserveFieldArrayMemory(ElementCodec.self, count: length) return [] } @@ -1606,6 +1617,7 @@ private func readCollectionPayload( let sameType = (header & CollectionHeader.sameType) != 0 var result: [ElementCodec.Value] = [] + try context.reserveFieldArrayMemory(ElementCodec.self, count: length) try context.ensureRemainingBytes(length, label: "array") result.reserveCapacity(length) @@ -1691,6 +1703,7 @@ private func readListPayloadAsArrayPayload( let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { + try context.reserveFieldArrayMemory(ElementCodec.self, count: length) return [] } @@ -1718,6 +1731,7 @@ private func readListPayloadAsArrayPayload( } try context.ensureRemainingBytes(length, label: "array") var result: [ElementCodec.Value] = [] + try context.reserveFieldArrayMemory(ElementCodec.self, count: length) result.reserveCapacity(length) return try ElementCodec.withTypeInfo(elementTypeInfo, context) { for _ in 0.. 0, "maxTypeFields must be positive") - precondition(maxTypeMetaBytes > 0, "maxTypeMetaBytes must be positive") - precondition(maxSchemaVersionsPerType > 0, "maxSchemaVersionsPerType must be positive") - precondition( - maxAverageSchemaVersionsPerType > 0, - "maxAverageSchemaVersionsPerType must be positive") - let effectiveCompatible = compatible ?? true - let effectiveCheckClassVersion = checkClassVersion ?? !effectiveCompatible - self.trackRef = trackRef - self.compatible = effectiveCompatible - self.checkClassVersion = effectiveCheckClassVersion - self.maxDepth = maxDepth - self.maxTypeFields = maxTypeFields - self.maxTypeMetaBytes = maxTypeMetaBytes - self.maxSchemaVersionsPerType = maxSchemaVersionsPerType - self.maxAverageSchemaVersionsPerType = maxAverageSchemaVersionsPerType - } + public let trackRef: Bool + public let compatible: Bool + public let checkClassVersion: Bool + public let maxDepth: Int + public let maxContainerMemoryBytes: Int64 + public let maxTypeFields: Int + public let maxTypeMetaBytes: Int + public let maxSchemaVersionsPerType: Int + public let maxAverageSchemaVersionsPerType: Int + + public init( + trackRef: Bool = false, + compatible: Bool? = nil, + checkClassVersion: Bool? = nil, + maxDepth: Int = 5, + maxContainerMemoryBytes: Int64 = -1, + maxTypeFields: Int = 512, + maxTypeMetaBytes: Int = 4096, + maxSchemaVersionsPerType: Int = 10, + maxAverageSchemaVersionsPerType: Int = 3 + ) { + precondition( + maxContainerMemoryBytes == -1 || maxContainerMemoryBytes > 0, + "maxContainerMemoryBytes must be positive or -1 for auto") + precondition(maxTypeFields > 0, "maxTypeFields must be positive") + precondition(maxTypeMetaBytes > 0, "maxTypeMetaBytes must be positive") + precondition(maxSchemaVersionsPerType > 0, "maxSchemaVersionsPerType must be positive") + precondition( + maxAverageSchemaVersionsPerType > 0, + "maxAverageSchemaVersionsPerType must be positive") + let effectiveCompatible = compatible ?? true + let effectiveCheckClassVersion = checkClassVersion ?? !effectiveCompatible + self.trackRef = trackRef + self.compatible = effectiveCompatible + self.checkClassVersion = effectiveCheckClassVersion + self.maxDepth = maxDepth + self.maxContainerMemoryBytes = maxContainerMemoryBytes + self.maxTypeFields = maxTypeFields + self.maxTypeMetaBytes = maxTypeMetaBytes + self.maxSchemaVersionsPerType = maxSchemaVersionsPerType + self.maxAverageSchemaVersionsPerType = maxAverageSchemaVersionsPerType + } } /// Single-threaded Fory runtime. @@ -62,498 +68,481 @@ public struct Config { /// reusable read/write context pair and must not be used concurrently from /// multiple threads. public final class Fory { - let typeResolver: TypeResolver - private let writeContext: WriteContext - private let readContext: ReadContext - public let config: Config - - public convenience init( - ref: Bool = false, - compatible: Bool? = nil, - checkClassVersion: Bool? = nil, - maxDepth: Int = 5, - maxTypeFields: Int = 512, - maxTypeMetaBytes: Int = 4096, - maxSchemaVersionsPerType: Int = 10, - maxAverageSchemaVersionsPerType: Int = 3 - ) { - self.init( - config: Config( - trackRef: ref, - compatible: compatible, - checkClassVersion: checkClassVersion, - maxDepth: maxDepth, - maxTypeFields: maxTypeFields, - maxTypeMetaBytes: maxTypeMetaBytes, - maxSchemaVersionsPerType: maxSchemaVersionsPerType, - maxAverageSchemaVersionsPerType: maxAverageSchemaVersionsPerType - )) - } - - public init(config: Config) { - self.typeResolver = TypeResolver(trackRef: config.trackRef) - self.writeContext = WriteContext( - buffer: ByteBuffer(), - typeResolver: typeResolver, - trackRef: config.trackRef, - compatible: config.compatible, - checkClassVersion: config.checkClassVersion, - maxDepth: config.maxDepth, - metaStringWriteState: MetaStringWriteState() - ) - self.readContext = ReadContext( - buffer: ByteBuffer(), - typeResolver: typeResolver, - config: config - ) - self.config = config - } - - public func register(_ type: T.Type, id: UInt32) { - typeResolver.register(type, id: id) - } - - /// Registers a user type by name. The last `.` separates namespace from the final type name. - public func register(_ type: T.Type, name: String) throws { - try typeResolver.register(type, name: name) - } - - public func serialize(_ value: T) throws -> Data { - try serializeRoot { context in - try writeRootTypedValue(value, context: context) - } - } - - public func deserialize(_ data: Data, as _: T.Type = T.self) throws -> T { - try deserializeRoot( - data: data - ) { context in - try readRootTypedValue(context: context) - } - } - - public func serialize(_ value: T, to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try writeRootTypedValue(value, context: context) - } - } - - public func deserialize(from buffer: ByteBuffer, as _: T.Type = T.self) throws -> T { - try deserializeRoot( - from: buffer - ) { context in - try readRootTypedValue(context: context) - } - } - - @_disfavoredOverload - public func serialize(_ value: Any) throws -> Data { - try serializeRoot { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: Any.Type = Any.self) throws -> Any { - try deserializeRoot( - data: data - ) { context in - try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), - to: Any.self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: AnyObject) throws -> Data { - try serializeRoot { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: AnyObject.Type = AnyObject.self) throws -> AnyObject { - try deserializeRoot( - data: data - ) { context in - try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), - to: AnyObject.self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: any Serializer) throws -> Data { - try serializeRoot { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize( - _ data: Data, as _: (any Serializer).Type = (any Serializer).self - ) throws - -> any Serializer - { - try deserializeRoot( - data: data - ) { context in - try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), - to: (any Serializer).self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: [Any]) throws -> Data { - try serializeRoot { context in - try context.writeListOfAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: [Any].Type = [Any].self) throws -> [Any] { - try deserializeRoot( - data: data - ) { context in - try context.readListOfAny(refMode: refMode, readTypeInfo: true) ?? [] - } - } - - @_disfavoredOverload - public func serialize(_ value: [String: Any]) throws -> Data { - try serializeRoot { context in - try context.writeMapStringToAny( - value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize( - _ data: Data, as _: [String: Any].Type = [String: Any].self - ) throws - -> [String: Any] - { - try deserializeRoot( - data: data - ) { context in - try context.readMapStringToAny(refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @_disfavoredOverload - public func serialize(_ value: [Int32: Any]) throws -> Data { - try serializeRoot { context in - try context.writeMapInt32ToAny( - value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize( - _ data: Data, as _: [Int32: Any].Type = [Int32: Any].self - ) throws - -> [Int32: Any] - { - try deserializeRoot( - data: data - ) { context in - try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @_disfavoredOverload - public func serialize(_ value: [AnyHashable: Any]) throws -> Data { - try serializeRoot { context in - try context.writeMapAnyHashableToAny( - value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize( - _ data: Data, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self + let typeResolver: TypeResolver + private let writeContext: WriteContext + private let readContext: ReadContext + public let config: Config + + public convenience init( + ref: Bool = false, + compatible: Bool? = nil, + checkClassVersion: Bool? = nil, + maxDepth: Int = 5, + maxContainerMemoryBytes: Int64 = -1, + maxTypeFields: Int = 512, + maxTypeMetaBytes: Int = 4096, + maxSchemaVersionsPerType: Int = 10, + maxAverageSchemaVersionsPerType: Int = 3 + ) { + self.init( + config: Config( + trackRef: ref, + compatible: compatible, + checkClassVersion: checkClassVersion, + maxDepth: maxDepth, + maxContainerMemoryBytes: maxContainerMemoryBytes, + maxTypeFields: maxTypeFields, + maxTypeMetaBytes: maxTypeMetaBytes, + maxSchemaVersionsPerType: maxSchemaVersionsPerType, + maxAverageSchemaVersionsPerType: maxAverageSchemaVersionsPerType + )) + } + + public init(config: Config) { + self.typeResolver = TypeResolver(trackRef: config.trackRef) + self.writeContext = WriteContext( + buffer: ByteBuffer(), + typeResolver: typeResolver, + trackRef: config.trackRef, + compatible: config.compatible, + checkClassVersion: config.checkClassVersion, + maxDepth: config.maxDepth, + metaStringWriteState: MetaStringWriteState() ) - throws -> [AnyHashable: Any] - { - try deserializeRoot( - data: data - ) { context in - try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @_disfavoredOverload - public func serialize(_ value: [Any], to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeListOfAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func serialize(_ value: Any, to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(from buffer: ByteBuffer, as _: Any.Type = Any.self) throws -> Any { - try deserializeRoot( - from: buffer - ) { context in - try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), - to: Any.self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: AnyObject, to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize( - from buffer: ByteBuffer, as _: AnyObject.Type = AnyObject.self - ) throws - -> AnyObject - { - try deserializeRoot( - from: buffer - ) { context in - try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), - to: AnyObject.self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: any Serializer, to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize( - from buffer: ByteBuffer, - as _: (any Serializer).Type = (any Serializer).self - ) throws -> any Serializer { - try deserializeRoot( - from: buffer - ) { context in - try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), - to: (any Serializer).self - ) - } - } - - @_disfavoredOverload - public func deserialize(from buffer: ByteBuffer, as _: [Any].Type = [Any].self) throws -> [Any] { - try deserializeRoot( - from: buffer - ) { context in - try context.readListOfAny(refMode: refMode, readTypeInfo: true) ?? [] - } - } - - @_disfavoredOverload - public func serialize(_ value: [String: Any], to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeMapStringToAny( - value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize( - from buffer: ByteBuffer, as _: [String: Any].Type = [String: Any].self + self.readContext = ReadContext( + buffer: ByteBuffer(), + typeResolver: typeResolver, + config: config ) - throws -> [String: Any] - { - try deserializeRoot( - from: buffer - ) { context in - try context.readMapStringToAny(refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @_disfavoredOverload - public func serialize(_ value: [Int32: Any], to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeMapInt32ToAny( - value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func serialize(_ value: [AnyHashable: Any], to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeMapAnyHashableToAny( - value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize( - from buffer: ByteBuffer, as _: [Int32: Any].Type = [Int32: Any].self + self.config = config + } + + public func register(_ type: T.Type, id: UInt32) { + typeResolver.register(type, id: id) + } + + /// Registers a user type by name. The last `.` separates namespace from the final type name. + public func register(_ type: T.Type, name: String) throws { + try typeResolver.register(type, name: name) + } + + public func serialize(_ value: T) throws -> Data { + try serializeRoot { context in + try writeRootTypedValue(value, context: context) + } + } + + public func deserialize(_ data: Data, as _: T.Type = T.self) throws -> T { + try deserializeRoot( + data: data + ) { context in + try readRootTypedValue(context: context) + } + } + + public func serialize(_ value: T, to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try writeRootTypedValue(value, context: context) + } + } + + public func deserialize(from buffer: ByteBuffer, as _: T.Type = T.self) throws -> T { + try deserializeRoot( + from: buffer + ) { context in + try readRootTypedValue(context: context) + } + } + + @_disfavoredOverload + public func serialize(_ value: Any) throws -> Data { + try serializeRoot { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: Any.Type = Any.self) throws -> Any { + try deserializeRoot( + data: data + ) { context in + try castAnyDynamicValue( + context.readAny(refMode: refMode, readTypeInfo: true), + to: Any.self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: AnyObject) throws -> Data { + try serializeRoot { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: AnyObject.Type = AnyObject.self) throws -> AnyObject { + try deserializeRoot( + data: data + ) { context in + try castAnyDynamicValue( + context.readAny(refMode: refMode, readTypeInfo: true), + to: AnyObject.self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: any Serializer) throws -> Data { + try serializeRoot { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: (any Serializer).Type = (any Serializer).self) throws + -> any Serializer { + try deserializeRoot( + data: data + ) { context in + try castAnyDynamicValue( + context.readAny(refMode: refMode, readTypeInfo: true), + to: (any Serializer).self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: [Any]) throws -> Data { + try serializeRoot { context in + try context.writeListOfAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: [Any].Type = [Any].self) throws -> [Any] { + try deserializeRoot( + data: data + ) { context in + try context.readListOfAny(refMode: refMode, readTypeInfo: true) ?? [] + } + } + + @_disfavoredOverload + public func serialize(_ value: [String: Any]) throws -> Data { + try serializeRoot { context in + try context.writeMapStringToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: [String: Any].Type = [String: Any].self) throws + -> [String: Any] { + try deserializeRoot( + data: data + ) { context in + try context.readMapStringToAny(refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @_disfavoredOverload + public func serialize(_ value: [Int32: Any]) throws -> Data { + try serializeRoot { context in + try context.writeMapInt32ToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: [Int32: Any].Type = [Int32: Any].self) throws + -> [Int32: Any] { + try deserializeRoot( + data: data + ) { context in + try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @_disfavoredOverload + public func serialize(_ value: [AnyHashable: Any]) throws -> Data { + try serializeRoot { context in + try context.writeMapAnyHashableToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self) + throws -> [AnyHashable: Any] { + try deserializeRoot( + data: data + ) { context in + try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @_disfavoredOverload + public func serialize(_ value: [Any], to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeListOfAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func serialize(_ value: Any, to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(from buffer: ByteBuffer, as _: Any.Type = Any.self) throws -> Any { + try deserializeRoot( + from: buffer + ) { context in + try castAnyDynamicValue( + context.readAny(refMode: refMode, readTypeInfo: true), + to: Any.self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: AnyObject, to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(from buffer: ByteBuffer, as _: AnyObject.Type = AnyObject.self) throws + -> AnyObject { + try deserializeRoot( + from: buffer + ) { context in + try castAnyDynamicValue( + context.readAny(refMode: refMode, readTypeInfo: true), + to: AnyObject.self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: any Serializer, to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize( + from buffer: ByteBuffer, + as _: (any Serializer).Type = (any Serializer).self + ) throws -> any Serializer { + try deserializeRoot( + from: buffer + ) { context in + try castAnyDynamicValue( + context.readAny(refMode: refMode, readTypeInfo: true), + to: (any Serializer).self + ) + } + } + + @_disfavoredOverload + public func deserialize(from buffer: ByteBuffer, as _: [Any].Type = [Any].self) throws -> [Any] { + try deserializeRoot( + from: buffer + ) { context in + try context.readListOfAny(refMode: refMode, readTypeInfo: true) ?? [] + } + } + + @_disfavoredOverload + public func serialize(_ value: [String: Any], to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeMapStringToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(from buffer: ByteBuffer, as _: [String: Any].Type = [String: Any].self) + throws -> [String: Any] { + try deserializeRoot( + from: buffer + ) { context in + try context.readMapStringToAny(refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @_disfavoredOverload + public func serialize(_ value: [Int32: Any], to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeMapInt32ToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func serialize(_ value: [AnyHashable: Any], to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeMapAnyHashableToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(from buffer: ByteBuffer, as _: [Int32: Any].Type = [Int32: Any].self) + throws -> [Int32: Any] { + try deserializeRoot( + from: buffer + ) { context in + try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @_disfavoredOverload + public func deserialize( + from buffer: ByteBuffer, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self + ) throws -> [AnyHashable: Any] { + try deserializeRoot( + from: buffer + ) { context in + try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @inlinable + @inline(__always) + func writeHead(buffer: ByteBuffer) { + buffer.writeUInt8(ForyHeaderFlag.isXlang) + } + + @inlinable + @inline(__always) + func readHead(buffer: ByteBuffer) throws { + let bitmap = try buffer.readUInt8() + let expected = ForyHeaderFlag.isXlang + if bitmap != expected { + try readHeadSlow(bitmap: bitmap, expected: expected) + } + } + + @usableFromInline + @inline(never) + func readHeadSlow(bitmap: UInt8, expected: UInt8) throws { + if (bitmap & ~ForyHeaderFlag.knownMask) != 0 || (bitmap & ForyHeaderFlag.isOutOfBand) != 0 { + throw ForyError.invalidData("unsupported root header bitmap 0x\(String(bitmap, radix: 16))") + } + if (bitmap & ForyHeaderFlag.isXlang) != (expected & ForyHeaderFlag.isXlang) { + throw ForyError.invalidData("xlang bitmap mismatch") + } + } + + @inline(__always) + private var refMode: RefMode { + config.trackRef ? .tracking : .nullOnly + } + + private func writeRootTypedValue( + _ value: T, + context: WriteContext + ) throws { + try value.foryWrite( + context, + refMode: refMode, + writeTypeInfo: true, + hasGenerics: false ) - throws -> [Int32: Any] - { - try deserializeRoot( - from: buffer - ) { context in - try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @_disfavoredOverload - public func deserialize( - from buffer: ByteBuffer, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self - ) throws -> [AnyHashable: Any] { - try deserializeRoot( - from: buffer - ) { context in - try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @inlinable - @inline(__always) - func writeHead(buffer: ByteBuffer) { - buffer.writeUInt8(ForyHeaderFlag.isXlang) - } - - @inlinable - @inline(__always) - func readHead(buffer: ByteBuffer) throws { - let bitmap = try buffer.readUInt8() - let expected = ForyHeaderFlag.isXlang - if bitmap != expected { - try readHeadSlow(bitmap: bitmap, expected: expected) - } - } - - @usableFromInline - @inline(never) - func readHeadSlow(bitmap: UInt8, expected: UInt8) throws { - if (bitmap & ~ForyHeaderFlag.knownMask) != 0 || (bitmap & ForyHeaderFlag.isOutOfBand) != 0 { - throw ForyError.invalidData("unsupported root header bitmap 0x\(String(bitmap, radix: 16))") - } - if (bitmap & ForyHeaderFlag.isXlang) != (expected & ForyHeaderFlag.isXlang) { - throw ForyError.invalidData("xlang bitmap mismatch") - } - } - - @inline(__always) - private var refMode: RefMode { - config.trackRef ? .tracking : .nullOnly - } - - private func writeRootTypedValue( - _ value: T, - context: WriteContext - ) throws { - try value.foryWrite( - context, - refMode: refMode, - writeTypeInfo: true, - hasGenerics: false - ) - } - - @inline(__always) - private func readRootTypedValue( - context: ReadContext - ) throws -> T { - return try T.foryRead( - context, - refMode: refMode, - readTypeInfo: true - ) - } - - @inline(__always) - func withReusableReadContext( - data: Data, - _ body: (ReadContext) throws -> R - ) rethrows -> R { - readContext.buffer.replace(with: data) - defer { - readContext.reset() - } - return try body(readContext) - } - - @inline(__always) - private func serializeRoot( - _ body: (WriteContext) throws -> Void - ) throws -> Data { - try typeResolver.finishRegistration() - let context = writeContext - context.buffer.clear() - defer { - context.reset() - } - writeHead(buffer: context.buffer) - try body(context) - return context.buffer.copyToData() - } - - @inline(__always) - private func appendSerializedRoot( - to output: inout Data, - _ body: (WriteContext) throws -> Void - ) throws { - try typeResolver.finishRegistration() - let context = writeContext - context.buffer.clear() - defer { - context.reset() - } - writeHead(buffer: context.buffer) - try body(context) - output.append(contentsOf: context.buffer.storage.prefix(context.buffer.count)) - } - - @inline(__always) - private func deserializeRoot( - data: Data, - _ body: (ReadContext) throws -> R - ) throws -> R { - try typeResolver.finishRegistration() - return try withReusableReadContext(data: data) { context in - try readHead(buffer: context.buffer) - let value = try body(context) - if context.buffer.remaining != 0 { - throw ForyError.invalidData( - "unexpected trailing bytes at root: \(context.buffer.remaining)") - } - return value - } - } - - @inline(__always) - private func deserializeRoot( - from buffer: ByteBuffer, - _ body: (ReadContext) throws -> R - ) throws -> R { - try typeResolver.finishRegistration() - readContext.buffer.swapState(with: buffer) - defer { - readContext.buffer.swapState(with: buffer) - readContext.reset() - } - try readHead(buffer: readContext.buffer) - return try body(readContext) - } + } + + @inline(__always) + private func readRootTypedValue( + context: ReadContext + ) throws -> T { + return try T.foryRead( + context, + refMode: refMode, + readTypeInfo: true + ) + } + + @inline(__always) + func withReusableReadContext( + data: Data, + _ body: (ReadContext) throws -> R + ) throws -> R { + readContext.buffer.replace(with: data) + try readContext.initContainerMemoryBudgetKnown(rootBytes: data.count) + defer { + readContext.reset() + } + return try body(readContext) + } + + @inline(__always) + private func serializeRoot( + _ body: (WriteContext) throws -> Void + ) throws -> Data { + try typeResolver.finishRegistration() + let context = writeContext + context.buffer.clear() + defer { + context.reset() + } + writeHead(buffer: context.buffer) + try body(context) + return context.buffer.copyToData() + } + + @inline(__always) + private func appendSerializedRoot( + to output: inout Data, + _ body: (WriteContext) throws -> Void + ) throws { + try typeResolver.finishRegistration() + let context = writeContext + context.buffer.clear() + defer { + context.reset() + } + writeHead(buffer: context.buffer) + try body(context) + output.append(contentsOf: context.buffer.storage.prefix(context.buffer.count)) + } + + @inline(__always) + private func deserializeRoot( + data: Data, + _ body: (ReadContext) throws -> R + ) throws -> R { + try typeResolver.finishRegistration() + return try withReusableReadContext(data: data) { context in + try readHead(buffer: context.buffer) + let value = try body(context) + if context.buffer.remaining != 0 { + throw ForyError.invalidData( + "unexpected trailing bytes at root: \(context.buffer.remaining)") + } + return value + } + } + + @inline(__always) + private func deserializeRoot( + from buffer: ByteBuffer, + _ body: (ReadContext) throws -> R + ) throws -> R { + try typeResolver.finishRegistration() + readContext.buffer.swapState(with: buffer) + try readContext.initContainerMemoryBudgetKnown(rootBytes: readContext.buffer.remaining) + defer { + readContext.buffer.swapState(with: buffer) + readContext.reset() + } + try readHead(buffer: readContext.buffer) + return try body(readContext) + } } diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index 84505e738b..fd5fb63678 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -20,777 +20,939 @@ import Foundation private let typeMetaSizeMask = 0xFF public final class ReadContext { - public let buffer: ByteBuffer - let typeResolver: TypeResolver - public let trackRef: Bool - public let compatible: Bool - public let checkClassVersion: Bool - public let maxDepth: Int - public let refReader: RefReader - private let compatibleTypeDefTypeInfos = ReusableArray(defaultValue: nil, reserve: 2) - private let metaStrings = ReusableArray(defaultValue: nil, reserve: 16) - private var dynamicAnyDepth = 0 - - private var typeInfoStack = UInt64Map(initialCapacity: 8) - private var typeInfoScopeStack: [(typeKey: UInt64, previousTypeInfo: TypeInfo?)] = [] - private var lastTypeInfo = TypeInfo.uncached - private let config: Config - - init( - buffer: ByteBuffer, - typeResolver: TypeResolver, - config: Config - ) { - self.buffer = buffer - self.typeResolver = typeResolver - self.trackRef = config.trackRef - self.compatible = config.compatible - self.checkClassVersion = config.checkClassVersion - self.maxDepth = config.maxDepth - self.config = config - self.refReader = RefReader() - } - - @inline(__always) - func enterDynamicAnyDepth() throws { - if maxDepth < 0 { - throw ForyError.invalidData("configured maxDepth \(maxDepth) is negative") - } - let nextDepth = dynamicAnyDepth + 1 - if nextDepth > maxDepth { - throw ForyError.invalidData( - "dynamic Any nesting depth \(nextDepth) exceeds configured maxDepth \(maxDepth)" - ) - } - dynamicAnyDepth = nextDepth + static let knownContainerBudgetSlackBytes = 64 * 1024 + static let unknownContainerBudgetBytes = 128 * 1024 * 1024 + static let containerFixedBytes = 32 + static let arrayHeaderBytes = 24 + static let referenceBytes = 4 + static let collectionEntryOverheadBytes = 16 + static let mapEntryOverheadBytes = 24 + private static let maxKnownContainerRootBytes = (Int.max - knownContainerBudgetSlackBytes) / 8 + + public let buffer: ByteBuffer + let typeResolver: TypeResolver + public let trackRef: Bool + public let compatible: Bool + public let checkClassVersion: Bool + public let maxDepth: Int + public let refReader: RefReader + private let compatibleTypeDefTypeInfos = ReusableArray(defaultValue: nil, reserve: 2) + private let metaStrings = ReusableArray(defaultValue: nil, reserve: 16) + private var dynamicAnyDepth = 0 + + private var typeInfoStack = UInt64Map(initialCapacity: 8) + private var typeInfoScopeStack: [(typeKey: UInt64, previousTypeInfo: TypeInfo?)] = [] + private var lastTypeInfo = TypeInfo.uncached + private let config: Config + private let maxContainerMemoryBytes: Int + private var remainingContainerMemoryBytes = Int.max + + init( + buffer: ByteBuffer, + typeResolver: TypeResolver, + config: Config + ) { + self.buffer = buffer + self.typeResolver = typeResolver + self.trackRef = config.trackRef + self.compatible = config.compatible + self.checkClassVersion = config.checkClassVersion + self.maxDepth = config.maxDepth + self.config = config + self.maxContainerMemoryBytes = Int(config.maxContainerMemoryBytes) + self.refReader = RefReader() + } + + @inline(__always) + func initContainerMemoryBudgetKnown(rootBytes: Int) throws { + var limit = maxContainerMemoryBytes + if limit < 0 { + if rootBytes > Self.maxKnownContainerRootBytes { + try throwContainerMemoryOverflow() + } + limit = rootBytes * 8 + Self.knownContainerBudgetSlackBytes } - - @inline(__always) - func leaveDynamicAnyDepth() { - if dynamicAnyDepth > 0 { - dynamicAnyDepth -= 1 - } + remainingContainerMemoryBytes = limit + } + + @inline(__always) + func reserveArrayMemory(_ type: Element.Type, count: Int) throws { + try reserveArrayMemory(count: count, elementBytes: containerElementBytes(type)) + } + + @inline(__always) + func reserveFieldArrayMemory( + _ codec: ElementCodec.Type, + count: Int + ) throws { + try reserveArrayMemory(count: count, elementBytes: fieldElementBytes(codec)) + } + + @inline(__always) + func reserveReferenceArrayMemory(count: Int) throws { + try reserveArrayMemory(count: count, elementBytes: Self.referenceBytes) + } + + @inline(__always) + func reserveSetMemory(_ type: Element.Type, count: Int) throws { + try reserveSetMemory(count: count, elementBytes: containerElementBytes(type)) + } + + @inline(__always) + func reserveFieldSetMemory( + _ codec: ElementCodec.Type, + count: Int + ) throws { + try reserveSetMemory(count: count, elementBytes: fieldElementBytes(codec)) + } + + @inline(__always) + func reserveMapMemory( + key _: Key.Type, + value _: Value.Type, + count: Int + ) throws { + try reserveMapMemory( + count: count, + keyBytes: containerElementBytes(Key.self), + valueBytes: containerElementBytes(Value.self) + ) + } + + @inline(__always) + func reserveFieldMapMemory( + key _: KeyCodec.Type, + value _: ValueCodec.Type, + count: Int + ) throws { + try reserveMapMemory( + count: count, + keyBytes: fieldElementBytes(KeyCodec.self), + valueBytes: fieldElementBytes(ValueCodec.self) + ) + } + + @inline(__always) + func reserveReferenceMapMemory(count: Int) throws { + try reserveMapMemory(count: count, keyBytes: Self.referenceBytes, valueBytes: Self.referenceBytes) + } + + @inline(__always) + private func reserveArrayMemory(count: Int, elementBytes: Int) throws { + if count == 0 { + try reserveContainerMemory(Self.containerFixedBytes) + return } - - @inline(__always) - func ensureCollectionLength(_ length: Int, label: String) throws { - if length < 0 { - throw ForyError.invalidData("\(label) length is negative") - } + try reserveCountedContainerMemory( + count: count, + fixedBytes: Self.containerFixedBytes + Self.arrayHeaderBytes, + elementBytes: elementBytes + ) + } + + @inline(__always) + private func reserveSetMemory(count: Int, elementBytes: Int) throws { + if count == 0 { + try reserveContainerMemory(Self.containerFixedBytes) + return } - - @inline(__always) - func ensureRemainingBytes(_ byteCount: Int, label: String) throws { - if byteCount < 0 { - throw ForyError.invalidData("\(label) size is negative") - } - let remainingBytes = buffer.remaining - if byteCount > remainingBytes { - throw ForyError.invalidData( - "\(label) requires \(byteCount) bytes but only \(remainingBytes) remain in buffer" - ) - } + try reserveCountedContainerMemory( + count: count, + fixedBytes: Self.containerFixedBytes + Self.arrayHeaderBytes, + elementBytes: elementBytes + Self.collectionEntryOverheadBytes + Self.referenceBytes * 2 + ) + } + + @inline(__always) + private func reserveMapMemory(count: Int, keyBytes: Int, valueBytes: Int) throws { + if count == 0 { + try reserveContainerMemory(Self.containerFixedBytes) + return } - - @inline(__always) - func typeInfo(for type: T.Type) throws -> TypeInfo { - let typeID = ObjectIdentifier(type) - if lastTypeInfo.swiftTypeID == typeID { - return lastTypeInfo - } - let info = try typeResolver.requireTypeInfo(for: type) - lastTypeInfo = info - return info + try reserveCountedContainerMemory( + count: count, + fixedBytes: Self.containerFixedBytes + Self.arrayHeaderBytes * 2, + elementBytes: keyBytes + valueBytes + Self.mapEntryOverheadBytes + Self.referenceBytes + ) + } + + @inline(__always) + private func reserveContainerMemory(_ bytes: Int) throws { + if bytes > remainingContainerMemoryBytes { + try throwContainerMemoryExceeded(bytes: bytes) + } + remainingContainerMemoryBytes -= bytes + } + + @inline(__always) + private func reserveCountedContainerMemory( + count: Int, + fixedBytes: Int, + elementBytes: Int + ) throws { + if count > (Int.max - fixedBytes) / elementBytes { + try throwContainerMemoryOverflow() + } + try reserveContainerMemory(count * elementBytes + fixedBytes) + } + + @inline(__always) + private func containerElementBytes(_ type: Element.Type) -> Int { + type.isRefType ? Self.referenceBytes : max(1, MemoryLayout.stride) + } + + @inline(__always) + private func fieldElementBytes(_ codec: ElementCodec.Type) -> Int { + codec.isRefType ? Self.referenceBytes : max(1, MemoryLayout.stride) + } + + @inline(never) + private func throwContainerMemoryOverflow() throws -> Never { + throw ForyError.invalidData("container memory estimate overflows") + } + + @inline(never) + private func throwContainerMemoryExceeded(bytes: Int) throws -> Never { + let message = + "estimated container memory request \(bytes) bytes exceeds maxContainerMemoryBytes " + + "remaining budget \(remainingContainerMemoryBytes) bytes" + throw ForyError.invalidData(message) + } + + @inline(__always) + func enterDynamicAnyDepth() throws { + if maxDepth < 0 { + throw ForyError.invalidData("configured maxDepth \(maxDepth) is negative") } + let nextDepth = dynamicAnyDepth + 1 + if nextDepth > maxDepth { + throw ForyError.invalidData( + "dynamic Any nesting depth \(nextDepth) exceeds configured maxDepth \(maxDepth)" + ) + } + dynamicAnyDepth = nextDepth + } - @inline(__always) - func readStaticTypeInfo(_ typeID: TypeId) throws -> TypeInfo? { - let rawTypeID = UInt32(try buffer.readUInt8()) - guard let actualTypeID = TypeId(rawValue: rawTypeID) else { - throw ForyError.invalidData("unknown type id \(rawTypeID)") - } - if actualTypeID != typeID { - throw ForyError.typeMismatch(expected: typeID.rawValue, actual: rawTypeID) - } - return nil + @inline(__always) + func leaveDynamicAnyDepth() { + if dynamicAnyDepth > 0 { + dynamicAnyDepth -= 1 } + } - func readTypeInfo() throws -> TypeInfo { - let rawTypeID = UInt32(try buffer.readUInt8()) - guard let wireTypeID = TypeId(rawValue: rawTypeID) else { - throw ForyError.invalidData("unknown dynamic type id \(rawTypeID)") - } + @inline(__always) + func ensureCollectionLength(_ length: Int, label: String) throws { + if length < 0 { + throw ForyError.invalidData("\(label) length is negative") + } + } - switch wireTypeID { - case .compatibleStruct, .namedCompatibleStruct: - return try readCompatibleTypeInfo() - case .namedEnum, .namedStruct, .namedExt, .namedUnion: - if compatible { - return try readCompatibleTypeInfo() - } - let namespace = try readMetaString( - context: self, - decoder: .namespace, - encodings: namespaceMetaStringEncodings - ) - let typeName = try readMetaString( - context: self, - decoder: .typeName, - encodings: typeNameMetaStringEncodings - ) - return try typeResolver.requireTypeInfo(namespace: namespace.value, typeName: typeName.value) - case .structType, .enumType, .ext, .typedUnion, .union: - let userTypeID = try buffer.readVarUInt32() - return try typeResolver.requireTypeInfo(userTypeID: userTypeID) - default: - return typeResolver.builtinTypeInfo(for: wireTypeID) - } + @inline(__always) + func ensureRemainingBytes(_ byteCount: Int, label: String) throws { + if byteCount < 0 { + throw ForyError.invalidData("\(label) size is negative") + } + let remainingBytes = buffer.remaining + if byteCount > remainingBytes { + throw ForyError.invalidData( + "\(label) requires \(byteCount) bytes but only \(remainingBytes) remain in buffer" + ) } + } - func readTypeInfo(for type: T.Type) throws -> TypeInfo? { - let rawTypeID = UInt32(try buffer.readUInt8()) - guard let typeID = TypeId(rawValue: rawTypeID) else { - throw ForyError.invalidData("unknown type id \(rawTypeID)") - } + @inline(__always) + func typeInfo(for type: T.Type) throws -> TypeInfo { + let typeID = ObjectIdentifier(type) + if lastTypeInfo.swiftTypeID == typeID { + return lastTypeInfo + } + let info = try typeResolver.requireTypeInfo(for: type) + lastTypeInfo = info + return info + } + + @inline(__always) + func readStaticTypeInfo(_ typeID: TypeId) throws -> TypeInfo? { + let rawTypeID = UInt32(try buffer.readUInt8()) + guard let actualTypeID = TypeId(rawValue: rawTypeID) else { + throw ForyError.invalidData("unknown type id \(rawTypeID)") + } + if actualTypeID != typeID { + throw ForyError.typeMismatch(expected: typeID.rawValue, actual: rawTypeID) + } + return nil + } - guard T.staticTypeId.isUserTypeKind else { - if typeID != T.staticTypeId { - throw ForyError.typeMismatch(expected: T.staticTypeId.rawValue, actual: rawTypeID) - } - return nil - } + func readTypeInfo() throws -> TypeInfo { + let rawTypeID = UInt32(try buffer.readUInt8()) + guard let wireTypeID = TypeId(rawValue: rawTypeID) else { + throw ForyError.invalidData("unknown dynamic type id \(rawTypeID)") + } - let localTypeInfo = try typeInfo(for: type) - let expectedWireTypeID = localTypeInfo.wireTypeID(compatible: compatible) - if !isAllowedRegisteredWireTypeID( - typeID, - declaredTypeID: localTypeInfo.typeID, - registerByName: localTypeInfo.registerByName, - compatible: compatible, - evolving: localTypeInfo.evolving - ) { - throw ForyError.typeMismatch(expected: expectedWireTypeID.rawValue, actual: rawTypeID) - } + switch wireTypeID { + case .compatibleStruct, .namedCompatibleStruct: + return try readCompatibleTypeInfo() + case .namedEnum, .namedStruct, .namedExt, .namedUnion: + if compatible { + return try readCompatibleTypeInfo() + } + let namespace = try readMetaString( + context: self, + decoder: .namespace, + encodings: namespaceMetaStringEncodings + ) + let typeName = try readMetaString( + context: self, + decoder: .typeName, + encodings: typeNameMetaStringEncodings + ) + return try typeResolver.requireTypeInfo(namespace: namespace.value, typeName: typeName.value) + case .structType, .enumType, .ext, .typedUnion, .union: + let userTypeID = try buffer.readVarUInt32() + return try typeResolver.requireTypeInfo(userTypeID: userTypeID) + default: + return typeResolver.builtinTypeInfo(for: wireTypeID) + } + } - switch typeID { - case .compatibleStruct, .namedCompatibleStruct: - return try readCompatibleTypeInfoIfNeeded( - for: localTypeInfo, - wireTypeID: typeID - ) - case .namedEnum, .namedStruct, .namedExt, .namedUnion: - if compatible { - _ = try readCompatibleTypeInfoIfNeeded( - for: localTypeInfo, - wireTypeID: typeID - ) - } else { - let namespace = try readMetaString( - context: self, - decoder: .namespace, - encodings: namespaceMetaStringEncodings - ) - let typeName = try readMetaString( - context: self, - decoder: .typeName, - encodings: typeNameMetaStringEncodings - ) - guard localTypeInfo.registerByName else { - throw ForyError.invalidData( - "received name-registered type info for id-registered local type") - } - if namespace.value != localTypeInfo.namespace.value - || typeName.value != localTypeInfo.typeName.value - { - let expectedTypeName = "\(localTypeInfo.namespace.value)::\(localTypeInfo.typeName.value)" - let actualTypeName = "\(namespace.value)::\(typeName.value)" - throw ForyError.invalidData( - "type name mismatch: expected \(expectedTypeName), got \(actualTypeName)" - ) - } - } - default: - if !localTypeInfo.registerByName && registeredWireTypeNeedsUserTypeID(typeID) { - guard let localUserTypeID = localTypeInfo.userTypeID else { - throw ForyError.invalidData("missing user type id for id-registered type") - } - let remoteUserTypeID = try buffer.readVarUInt32() - if remoteUserTypeID != localUserTypeID { - throw ForyError.typeMismatch(expected: localUserTypeID, actual: remoteUserTypeID) - } - } - } - return nil - } - - @inline(__always) - private func readCompatibleTypeInfoIfNeeded( - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws -> TypeInfo? { - let buffer = self.buffer - let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos - if !checkClassVersion, - compatibleTypeDefTypeInfos.isEmpty, - !localTypeInfo.typeDefHasUserTypeFields, - let localTypeDefHeader = localTypeInfo.typeDefHeader - { - let indexMarker = try buffer.readVarUInt32() - if indexMarker == 0 { - let headerStart = buffer.getCursor() - let header = try buffer.readUInt64() - var bodySize = Int(header & UInt64(typeMetaSizeMask)) - if bodySize == typeMetaSizeMask { - bodySize += Int(try buffer.readVarUInt32()) - } - if header == localTypeDefHeader { - // The declared local type owns this exact metadata header, so this is a - // local-schema hit rather than a remote cache publish. Keep it allocation-free: - // skip the body, add the local type to the per-read table, and do not parse/hash. - try buffer.skip(bodySize) - compatibleTypeDefTypeInfos.push(localTypeInfo) - return nil - } - if let cached = typeResolver.getTypeInfo(forHeader: header) { - // Header-cache hits intentionally skip without rehashing. Entries reach this cache only - // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add - // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. - try buffer.skip(bodySize) - compatibleTypeDefTypeInfos.push(cached) - return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) - } - let cachedTypeInfo = try readTypeInfoBody( - start: headerStart, - header: header, - for: localTypeInfo, - wireTypeID: wireTypeID) - compatibleTypeDefTypeInfos.push(cachedTypeInfo) - if cachedTypeInfo === localTypeInfo { - return nil - } - return try validateCompatibleTypeInfo( - cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) - } - return try readCompatibleTypeInfo( - afterMarker: indexMarker, - for: localTypeInfo, - wireTypeID: wireTypeID) - } - return try readCompatibleTypeInfo( - for: localTypeInfo, - wireTypeID: wireTypeID - ) + func readTypeInfo(for type: T.Type) throws -> TypeInfo? { + let rawTypeID = UInt32(try buffer.readUInt8()) + guard let typeID = TypeId(rawValue: rawTypeID) else { + throw ForyError.invalidData("unknown type id \(rawTypeID)") } - private func readCompatibleTypeInfo() throws -> TypeInfo { - let indexMarker = try buffer.readVarUInt32() - return try readCompatibleTypeInfo(afterMarker: indexMarker) + guard T.staticTypeId.isUserTypeKind else { + if typeID != T.staticTypeId { + throw ForyError.typeMismatch(expected: T.staticTypeId.rawValue, actual: rawTypeID) + } + return nil } - private func readCompatibleTypeInfo(afterMarker indexMarker: UInt32) throws -> TypeInfo { - let buffer = self.buffer - let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos - let isRef = (indexMarker & 1) == 1 - let index = Int(indexMarker >> 1) - if isRef { - guard let typeInfo = compatibleTypeDefTypeInfos.get(index) else { - throw ForyError.invalidData("unknown compatible type definition ref index \(index)") - } - return typeInfo - } + let localTypeInfo = try typeInfo(for: type) + let expectedWireTypeID = localTypeInfo.wireTypeID(compatible: compatible) + if !isAllowedRegisteredWireTypeID( + typeID, + declaredTypeID: localTypeInfo.typeID, + registerByName: localTypeInfo.registerByName, + compatible: compatible, + evolving: localTypeInfo.evolving + ) { + throw ForyError.typeMismatch(expected: expectedWireTypeID.rawValue, actual: rawTypeID) + } - let typeMetaStart = buffer.getCursor() + switch typeID { + case .compatibleStruct, .namedCompatibleStruct: + return try readCompatibleTypeInfoIfNeeded( + for: localTypeInfo, + wireTypeID: typeID + ) + case .namedEnum, .namedStruct, .namedExt, .namedUnion: + if compatible { + _ = try readCompatibleTypeInfoIfNeeded( + for: localTypeInfo, + wireTypeID: typeID + ) + } else { + let namespace = try readMetaString( + context: self, + decoder: .namespace, + encodings: namespaceMetaStringEncodings + ) + let typeName = try readMetaString( + context: self, + decoder: .typeName, + encodings: typeNameMetaStringEncodings + ) + guard localTypeInfo.registerByName else { + throw ForyError.invalidData( + "received name-registered type info for id-registered local type") + } + if namespace.value != localTypeInfo.namespace.value + || typeName.value != localTypeInfo.typeName.value { + let expectedTypeName = "\(localTypeInfo.namespace.value)::\(localTypeInfo.typeName.value)" + let actualTypeName = "\(namespace.value)::\(typeName.value)" + throw ForyError.invalidData( + "type name mismatch: expected \(expectedTypeName), got \(actualTypeName)" + ) + } + } + default: + if !localTypeInfo.registerByName && registeredWireTypeNeedsUserTypeID(typeID) { + guard let localUserTypeID = localTypeInfo.userTypeID else { + throw ForyError.invalidData("missing user type id for id-registered type") + } + let remoteUserTypeID = try buffer.readVarUInt32() + if remoteUserTypeID != localUserTypeID { + throw ForyError.typeMismatch(expected: localUserTypeID, actual: remoteUserTypeID) + } + } + } + return nil + } + + @inline(__always) + private func readCompatibleTypeInfoIfNeeded( + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws -> TypeInfo? { + let buffer = self.buffer + let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos + if !checkClassVersion, + compatibleTypeDefTypeInfos.isEmpty, + !localTypeInfo.typeDefHasUserTypeFields, + let localTypeDefHeader = localTypeInfo.typeDefHeader { + let indexMarker = try buffer.readVarUInt32() + if indexMarker == 0 { + let headerStart = buffer.getCursor() let header = try buffer.readUInt64() var bodySize = Int(header & UInt64(typeMetaSizeMask)) if bodySize == typeMetaSizeMask { - bodySize += Int(try buffer.readVarUInt32()) + bodySize += Int(try buffer.readVarUInt32()) + } + if header == localTypeDefHeader { + // The declared local type owns this exact metadata header, so this is a + // local-schema hit rather than a remote cache publish. Keep it allocation-free: + // skip the body, add the local type to the per-read table, and do not parse/hash. + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(localTypeInfo) + return nil } if let cached = typeResolver.getTypeInfo(forHeader: header) { - // Header-cache hits intentionally skip without rehashing. Entries reach this cache only - // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add - // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. - try buffer.skip(bodySize) - compatibleTypeDefTypeInfos.push(cached) - return cached + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add + // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(cached) + return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) } - - let cachedTypeInfo = try readTypeInfoBody(start: typeMetaStart, header: header) + let cachedTypeInfo = try readTypeInfoBody( + start: headerStart, + header: header, + for: localTypeInfo, + wireTypeID: wireTypeID) compatibleTypeDefTypeInfos.push(cachedTypeInfo) - return cachedTypeInfo - } - - @inline(never) - private func readCompatibleTypeInfo( - afterMarker indexMarker: UInt32, - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws -> TypeInfo { - let buffer = self.buffer - let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos - let isRef = (indexMarker & 1) == 1 - let index = Int(indexMarker >> 1) - if isRef { - guard let typeInfo = compatibleTypeDefTypeInfos.get(index) else { - throw ForyError.invalidData("unknown compatible type definition ref index \(index)") - } - return try validateCompatibleTypeInfo(typeInfo, for: localTypeInfo, wireTypeID: wireTypeID) - } + if cachedTypeInfo === localTypeInfo { + return nil + } + return try validateCompatibleTypeInfo( + cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) + } + return try readCompatibleTypeInfo( + afterMarker: indexMarker, + for: localTypeInfo, + wireTypeID: wireTypeID) + } + return try readCompatibleTypeInfo( + for: localTypeInfo, + wireTypeID: wireTypeID + ) + } + + private func readCompatibleTypeInfo() throws -> TypeInfo { + let indexMarker = try buffer.readVarUInt32() + return try readCompatibleTypeInfo(afterMarker: indexMarker) + } + + private func readCompatibleTypeInfo(afterMarker indexMarker: UInt32) throws -> TypeInfo { + let buffer = self.buffer + let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos + let isRef = (indexMarker & 1) == 1 + let index = Int(indexMarker >> 1) + if isRef { + guard let typeInfo = compatibleTypeDefTypeInfos.get(index) else { + throw ForyError.invalidData("unknown compatible type definition ref index \(index)") + } + return typeInfo + } + + let typeMetaStart = buffer.getCursor() + let header = try buffer.readUInt64() + var bodySize = Int(header & UInt64(typeMetaSizeMask)) + if bodySize == typeMetaSizeMask { + bodySize += Int(try buffer.readVarUInt32()) + } + if let cached = typeResolver.getTypeInfo(forHeader: header) { + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add + // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(cached) + return cached + } + + let cachedTypeInfo = try readTypeInfoBody(start: typeMetaStart, header: header) + compatibleTypeDefTypeInfos.push(cachedTypeInfo) + return cachedTypeInfo + } + + @inline(never) + private func readCompatibleTypeInfo( + afterMarker indexMarker: UInt32, + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws -> TypeInfo { + let buffer = self.buffer + let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos + let isRef = (indexMarker & 1) == 1 + let index = Int(indexMarker >> 1) + if isRef { + guard let typeInfo = compatibleTypeDefTypeInfos.get(index) else { + throw ForyError.invalidData("unknown compatible type definition ref index \(index)") + } + return try validateCompatibleTypeInfo(typeInfo, for: localTypeInfo, wireTypeID: wireTypeID) + } + + let typeMetaStart = buffer.getCursor() + let header = try buffer.readUInt64() + var bodySize = Int(header & UInt64(typeMetaSizeMask)) + if bodySize == typeMetaSizeMask { + bodySize += Int(try buffer.readVarUInt32()) + } + if let cached = typeResolver.getTypeInfo(forHeader: header) { + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add + // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(cached) + return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) + } - let typeMetaStart = buffer.getCursor() + let cachedTypeInfo = try readTypeInfoBody( + start: typeMetaStart, + header: header, + for: localTypeInfo, + wireTypeID: wireTypeID) + compatibleTypeDefTypeInfos.push(cachedTypeInfo) + return try validateCompatibleTypeInfo(cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) + } + + @inline(__always) + private func readCompatibleTypeInfo( + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws -> TypeInfo { + let buffer = self.buffer + let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos + if compatibleTypeDefTypeInfos.isEmpty, + let localTypeDefHeader = localTypeInfo.typeDefHeader { + let indexMarker = try buffer.readVarUInt32() + if indexMarker != 0 { + return try readCompatibleTypeInfo( + afterMarker: indexMarker, + for: localTypeInfo, + wireTypeID: wireTypeID) + } else { + let headerStart = buffer.getCursor() let header = try buffer.readUInt64() var bodySize = Int(header & UInt64(typeMetaSizeMask)) if bodySize == typeMetaSizeMask { - bodySize += Int(try buffer.readVarUInt32()) + bodySize += Int(try buffer.readVarUInt32()) } - if let cached = typeResolver.getTypeInfo(forHeader: header) { - // Header-cache hits intentionally skip without rehashing. Entries reach this cache only - // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add - // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. - try buffer.skip(bodySize) - compatibleTypeDefTypeInfos.push(cached) - return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) + + if header == localTypeDefHeader { + // The declared local type owns this exact metadata header, so this is a + // local-schema hit rather than a remote cache publish. Keep it allocation-free: + // skip the body, add the local type to the per-read table, and do not parse/hash. + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(localTypeInfo) + return localTypeInfo } - let cachedTypeInfo = try readTypeInfoBody( - start: typeMetaStart, + if let cached = typeResolver.getTypeInfo(forHeader: header) { + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add + // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(cached) + return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) + } else { + let remoteTypeInfo = try readTypeInfoBody( + start: headerStart, header: header, for: localTypeInfo, wireTypeID: wireTypeID) - compatibleTypeDefTypeInfos.push(cachedTypeInfo) - return try validateCompatibleTypeInfo(cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) - } - - @inline(__always) - private func readCompatibleTypeInfo( - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws -> TypeInfo { - let buffer = self.buffer - let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos - if compatibleTypeDefTypeInfos.isEmpty, - let localTypeDefHeader = localTypeInfo.typeDefHeader - { - let indexMarker = try buffer.readVarUInt32() - if indexMarker != 0 { - return try readCompatibleTypeInfo( - afterMarker: indexMarker, - for: localTypeInfo, - wireTypeID: wireTypeID) - } else { - let headerStart = buffer.getCursor() - let header = try buffer.readUInt64() - var bodySize = Int(header & UInt64(typeMetaSizeMask)) - if bodySize == typeMetaSizeMask { - bodySize += Int(try buffer.readVarUInt32()) - } - - if header == localTypeDefHeader { - // The declared local type owns this exact metadata header, so this is a - // local-schema hit rather than a remote cache publish. Keep it allocation-free: - // skip the body, add the local type to the per-read table, and do not parse/hash. - try buffer.skip(bodySize) - compatibleTypeDefTypeInfos.push(localTypeInfo) - return localTypeInfo - } - - if let cached = typeResolver.getTypeInfo(forHeader: header) { - // Header-cache hits intentionally skip without rehashing. Entries reach this cache only - // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add - // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. - try buffer.skip(bodySize) - compatibleTypeDefTypeInfos.push(cached) - return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) - } else { - let remoteTypeInfo = try readTypeInfoBody( - start: headerStart, - header: header, - for: localTypeInfo, - wireTypeID: wireTypeID) - compatibleTypeDefTypeInfos.push(remoteTypeInfo) - return try validateCompatibleTypeInfo( - remoteTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) - } - } + compatibleTypeDefTypeInfos.push(remoteTypeInfo) + return try validateCompatibleTypeInfo( + remoteTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) } - let indexMarker = try buffer.readVarUInt32() - return try readCompatibleTypeInfo( - afterMarker: indexMarker, - for: localTypeInfo, - wireTypeID: wireTypeID) + } } - - @inline(never) - private func readTypeInfoBody(start: Int, header: UInt64) throws -> TypeInfo { - buffer.setCursor(start) - let decoded = try TypeMeta.decode( - buffer, - maxTypeFields: config.maxTypeFields, - maxTypeMetaBytes: config.maxTypeMetaBytes) - let typeMetaEnd = buffer.getCursor() - let localTypeInfo = try typeResolver.requireTypeInfo(for: decoded) - return try typeResolver.cacheTypeInfo( - decoded, - forHeader: header, - localTypeInfo: localTypeInfo, - exactLocal: try matchesLocalTypeDefBytes( - localTypeInfo: localTypeInfo, - typeMeta: decoded, - start: start, - end: typeMetaEnd), - config: config - ) + let indexMarker = try buffer.readVarUInt32() + return try readCompatibleTypeInfo( + afterMarker: indexMarker, + for: localTypeInfo, + wireTypeID: wireTypeID) + } + + @inline(never) + private func readTypeInfoBody(start: Int, header: UInt64) throws -> TypeInfo { + buffer.setCursor(start) + let decoded = try TypeMeta.decode( + buffer, + maxTypeFields: config.maxTypeFields, + maxTypeMetaBytes: config.maxTypeMetaBytes) + let typeMetaEnd = buffer.getCursor() + let localTypeInfo = try typeResolver.requireTypeInfo(for: decoded) + return try typeResolver.cacheTypeInfo( + decoded, + forHeader: header, + localTypeInfo: localTypeInfo, + exactLocal: try matchesLocalTypeDefBytes( + localTypeInfo: localTypeInfo, + typeMeta: decoded, + start: start, + end: typeMetaEnd), + config: config + ) + } + + @inline(never) + private func readTypeInfoBody( + start: Int, + header: UInt64, + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws -> TypeInfo { + buffer.setCursor(start) + let decoded = try TypeMeta.decode( + buffer, + maxTypeFields: config.maxTypeFields, + maxTypeMetaBytes: config.maxTypeMetaBytes) + let typeMetaEnd = buffer.getCursor() + try validateCompatibleTypeMeta(decoded, for: localTypeInfo, wireTypeID: wireTypeID) + // The typed path is owned by the declared local type. After identity validation, the + // decoded metadata must describe this same TypeInfo; do not resolve another owner here. + return try typeResolver.cacheTypeInfo( + decoded, + forHeader: header, + localTypeInfo: localTypeInfo, + exactLocal: try matchesLocalTypeDefBytes( + localTypeInfo: localTypeInfo, + typeMeta: decoded, + start: start, + end: typeMetaEnd), + config: config + ) + } + + @inline(never) + private func matchesLocalTypeDefBytes( + localTypeInfo: TypeInfo, + typeMeta: TypeMeta, + start: Int, + end: Int + ) throws -> Bool { + guard typeMeta.typeID != nil else { + return false } - - @inline(never) - private func readTypeInfoBody( - start: Int, - header: UInt64, - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws -> TypeInfo { - buffer.setCursor(start) - let decoded = try TypeMeta.decode( - buffer, - maxTypeFields: config.maxTypeFields, - maxTypeMetaBytes: config.maxTypeMetaBytes) - let typeMetaEnd = buffer.getCursor() - try validateCompatibleTypeMeta(decoded, for: localTypeInfo, wireTypeID: wireTypeID) - // The typed path is owned by the declared local type. After identity validation, the - // decoded metadata must describe this same TypeInfo; do not resolve another owner here. - return try typeResolver.cacheTypeInfo( - decoded, - forHeader: header, - localTypeInfo: localTypeInfo, - exactLocal: try matchesLocalTypeDefBytes( - localTypeInfo: localTypeInfo, - typeMeta: decoded, - start: start, - end: typeMetaEnd), - config: config - ) + guard let localTypeDefBytes = localTypeInfo.typeDefBytes, + end - start == localTypeDefBytes.count else { + return false } - - @inline(never) - private func matchesLocalTypeDefBytes( - localTypeInfo: TypeInfo, - typeMeta: TypeMeta, - start: Int, - end: Int - ) throws -> Bool { - guard typeMeta.typeID != nil else { - return false - } - guard let localTypeDefBytes = localTypeInfo.typeDefBytes, - end - start == localTypeDefBytes.count - else { - return false - } - return buffer.matchesBytes(start: start, bytes: localTypeDefBytes) + return buffer.matchesBytes(start: start, bytes: localTypeDefBytes) + } + + private func validateCompatibleTypeInfo( + _ remoteTypeInfo: TypeInfo, + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws -> TypeInfo { + guard let remoteTypeMeta = remoteTypeInfo.compatibleTypeMeta else { + throw ForyError.invalidData("compatible type metadata is required") } - - private func validateCompatibleTypeInfo( - _ remoteTypeInfo: TypeInfo, - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws -> TypeInfo { - guard let remoteTypeMeta = remoteTypeInfo.compatibleTypeMeta else { - throw ForyError.invalidData("compatible type metadata is required") - } - try validateCompatibleTypeMeta(remoteTypeMeta, for: localTypeInfo, wireTypeID: wireTypeID) - return remoteTypeInfo - } - - @inline(__always) - private func validateCompatibleTypeMeta( - _ remoteTypeMeta: TypeMeta, - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws { - if let localTypeMeta = localTypeInfo.typeMeta, - remoteTypeMeta === localTypeMeta - { - return - } - if remoteTypeMeta.registerByName { - guard localTypeInfo.registerByName else { - throw ForyError.invalidData( - "received name-registered compatible metadata for id-registered local type") - } - if remoteTypeMeta.namespace.value != localTypeInfo.namespace.value { - throw ForyError.invalidData( - "namespace mismatch: expected \(localTypeInfo.namespace.value), got \(remoteTypeMeta.namespace.value)" - ) - } - if remoteTypeMeta.typeName.value != localTypeInfo.typeName.value { - throw ForyError.invalidData( - "type name mismatch: expected \(localTypeInfo.typeName.value), got \(remoteTypeMeta.typeName.value)" - ) - } - } else { - guard !localTypeInfo.registerByName else { - throw ForyError.invalidData( - "received id-registered compatible metadata for name-registered local type") - } - guard let remoteUserTypeID = remoteTypeMeta.userTypeID else { - throw ForyError.invalidData("missing user type id in compatible type metadata") - } - guard let localUserTypeID = localTypeInfo.userTypeID else { - throw ForyError.invalidData("missing local user type id metadata for id-registered type") - } - if remoteUserTypeID != localUserTypeID { - throw ForyError.typeMismatch(expected: localUserTypeID, actual: remoteUserTypeID) - } - } - - if let remoteTypeID = remoteTypeMeta.typeID, - let remoteWireTypeID = TypeId(rawValue: remoteTypeID), - !isAllowedRegisteredWireTypeID( - remoteWireTypeID, - declaredTypeID: localTypeInfo.typeID, - registerByName: localTypeInfo.registerByName, - compatible: compatible, - evolving: localTypeInfo.evolving - ) - { - throw ForyError.typeMismatch(expected: wireTypeID.rawValue, actual: remoteTypeID) - } + try validateCompatibleTypeMeta(remoteTypeMeta, for: localTypeInfo, wireTypeID: wireTypeID) + return remoteTypeInfo + } + + @inline(__always) + private func validateCompatibleTypeMeta( + _ remoteTypeMeta: TypeMeta, + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws { + if let localTypeMeta = localTypeInfo.typeMeta, + remoteTypeMeta === localTypeMeta { + return } - - func readAnyValue(typeInfo: TypeInfo) throws -> Any { - try enterDynamicAnyDepth() - defer { leaveDynamicAnyDepth() } - - let value: Any - switch typeInfo.typeID { - case .bool: - value = try Bool.foryRead(self, refMode: .none, readTypeInfo: false) - case .int8: - value = try Int8.foryRead(self, refMode: .none, readTypeInfo: false) - case .int16: - value = try Int16.foryRead(self, refMode: .none, readTypeInfo: false) - case .int32: - value = try buffer.readInt32() - case .varint32: - value = try Int32.foryRead(self, refMode: .none, readTypeInfo: false) - case .int64: - value = try buffer.readInt64() - case .varint64: - value = try Int64.foryRead(self, refMode: .none, readTypeInfo: false) - case .taggedInt64: - value = try buffer.readTaggedInt64() - case .uint8: - value = try UInt8.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint16: - value = try UInt16.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint32: - value = try buffer.readUInt32() - case .varUInt32: - value = try UInt32.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint64: - value = try buffer.readUInt64() - case .varUInt64: - value = try UInt64.foryRead(self, refMode: .none, readTypeInfo: false) - case .taggedUInt64: - value = try buffer.readTaggedUInt64() - case .float16: - value = try Float16.foryRead(self, refMode: .none, readTypeInfo: false) - case .bfloat16: - value = try BFloat16.foryRead(self, refMode: .none, readTypeInfo: false) - case .float32: - value = try Float.foryRead(self, refMode: .none, readTypeInfo: false) - case .float64: - value = try Double.foryRead(self, refMode: .none, readTypeInfo: false) - case .string: - value = try String.foryRead(self, refMode: .none, readTypeInfo: false) - case .duration: - value = try Duration.foryRead(self, refMode: .none, readTypeInfo: false) - case .timestamp: - value = try Date.foryRead(self, refMode: .none, readTypeInfo: false) - case .date: - value = try LocalDate.foryRead(self, refMode: .none, readTypeInfo: false) - case .decimal: - value = try Decimal.foryRead(self, refMode: .none, readTypeInfo: false) - case .binary: - value = try Data.foryRead(self, refMode: .none, readTypeInfo: false) - case .boolArray: - value = try readPrimitiveArray(self) as [Bool] - case .int8Array: - value = try readPrimitiveArray(self) as [Int8] - case .int16Array: - value = try readPrimitiveArray(self) as [Int16] - case .int32Array: - value = try readPrimitiveArray(self) as [Int32] - case .int64Array: - value = try readPrimitiveArray(self) as [Int64] - case .uint8Array: - value = try readPrimitiveArray(self) as [UInt8] - case .uint16Array: - value = try readPrimitiveArray(self) as [UInt16] - case .uint32Array: - value = try readPrimitiveArray(self) as [UInt32] - case .uint64Array: - value = try readPrimitiveArray(self) as [UInt64] - case .float16Array: - value = try readPrimitiveArray(self) as [Float16] - case .bfloat16Array: - value = try readPrimitiveArray(self) as [BFloat16] - case .float32Array: - value = try readPrimitiveArray(self) as [Float] - case .float64Array: - value = try readPrimitiveArray(self) as [Double] - case .array, .list: - value = try readListOfAny(refMode: .none) ?? [] - case .set: - value = try Set.foryRead(self, refMode: .none, readTypeInfo: false) - case .map: - value = try readDynamicAnyMapValue(context: self) - case .none: - value = ForyAnyNullValue() - default: - if typeInfo.typeID.isUserTypeKind { - value = try typeInfo.read(self) - } else { - throw ForyError.invalidData("unsupported dynamic type id \(typeInfo.typeID)") - } - } - return value + if remoteTypeMeta.registerByName { + guard localTypeInfo.registerByName else { + throw ForyError.invalidData( + "received name-registered compatible metadata for id-registered local type") + } + if remoteTypeMeta.namespace.value != localTypeInfo.namespace.value { + throw ForyError.invalidData( + "namespace mismatch: expected \(localTypeInfo.namespace.value), got \(remoteTypeMeta.namespace.value)" + ) + } + if remoteTypeMeta.typeName.value != localTypeInfo.typeName.value { + throw ForyError.invalidData( + "type name mismatch: expected \(localTypeInfo.typeName.value), got \(remoteTypeMeta.typeName.value)" + ) + } + } else { + guard !localTypeInfo.registerByName else { + throw ForyError.invalidData( + "received id-registered compatible metadata for name-registered local type") + } + guard let remoteUserTypeID = remoteTypeMeta.userTypeID else { + throw ForyError.invalidData("missing user type id in compatible type metadata") + } + guard let localUserTypeID = localTypeInfo.userTypeID else { + throw ForyError.invalidData("missing local user type id metadata for id-registered type") + } + if remoteUserTypeID != localUserTypeID { + throw ForyError.typeMismatch(expected: localUserTypeID, actual: remoteUserTypeID) + } } - @inline(__always) - func getTypeInfo(for type: T.Type) -> TypeInfo? { - typeInfoStack.value(for: UInt64(UInt(bitPattern: ObjectIdentifier(type)))) + if let remoteTypeID = remoteTypeMeta.typeID, + let remoteWireTypeID = TypeId(rawValue: remoteTypeID), + !isAllowedRegisteredWireTypeID( + remoteWireTypeID, + declaredTypeID: localTypeInfo.typeID, + registerByName: localTypeInfo.registerByName, + compatible: compatible, + evolving: localTypeInfo.evolving + ) { + throw ForyError.typeMismatch(expected: wireTypeID.rawValue, actual: remoteTypeID) + } + } + + func readAnyValue(typeInfo: TypeInfo) throws -> Any { + try enterDynamicAnyDepth() + defer { leaveDynamicAnyDepth() } + + let value: Any + switch typeInfo.typeID { + case .bool: + value = try Bool.foryRead(self, refMode: .none, readTypeInfo: false) + case .int8: + value = try Int8.foryRead(self, refMode: .none, readTypeInfo: false) + case .int16: + value = try Int16.foryRead(self, refMode: .none, readTypeInfo: false) + case .int32: + value = try buffer.readInt32() + case .varint32: + value = try Int32.foryRead(self, refMode: .none, readTypeInfo: false) + case .int64: + value = try buffer.readInt64() + case .varint64: + value = try Int64.foryRead(self, refMode: .none, readTypeInfo: false) + case .taggedInt64: + value = try buffer.readTaggedInt64() + case .uint8: + value = try UInt8.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint16: + value = try UInt16.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint32: + value = try buffer.readUInt32() + case .varUInt32: + value = try UInt32.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint64: + value = try buffer.readUInt64() + case .varUInt64: + value = try UInt64.foryRead(self, refMode: .none, readTypeInfo: false) + case .taggedUInt64: + value = try buffer.readTaggedUInt64() + case .float16: + value = try Float16.foryRead(self, refMode: .none, readTypeInfo: false) + case .bfloat16: + value = try BFloat16.foryRead(self, refMode: .none, readTypeInfo: false) + case .float32: + value = try Float.foryRead(self, refMode: .none, readTypeInfo: false) + case .float64: + value = try Double.foryRead(self, refMode: .none, readTypeInfo: false) + case .string: + value = try String.foryRead(self, refMode: .none, readTypeInfo: false) + case .duration: + value = try Duration.foryRead(self, refMode: .none, readTypeInfo: false) + case .timestamp: + value = try Date.foryRead(self, refMode: .none, readTypeInfo: false) + case .date: + value = try LocalDate.foryRead(self, refMode: .none, readTypeInfo: false) + case .decimal: + value = try Decimal.foryRead(self, refMode: .none, readTypeInfo: false) + case .binary: + value = try Data.foryRead(self, refMode: .none, readTypeInfo: false) + case .boolArray: + value = try readPrimitiveArray(self) as [Bool] + case .int8Array: + value = try readPrimitiveArray(self) as [Int8] + case .int16Array: + value = try readPrimitiveArray(self) as [Int16] + case .int32Array: + value = try readPrimitiveArray(self) as [Int32] + case .int64Array: + value = try readPrimitiveArray(self) as [Int64] + case .uint8Array: + value = try readPrimitiveArray(self) as [UInt8] + case .uint16Array: + value = try readPrimitiveArray(self) as [UInt16] + case .uint32Array: + value = try readPrimitiveArray(self) as [UInt32] + case .uint64Array: + value = try readPrimitiveArray(self) as [UInt64] + case .float16Array: + value = try readPrimitiveArray(self) as [Float16] + case .bfloat16Array: + value = try readPrimitiveArray(self) as [BFloat16] + case .float32Array: + value = try readPrimitiveArray(self) as [Float] + case .float64Array: + value = try readPrimitiveArray(self) as [Double] + case .array, .list: + value = try readListOfAny(refMode: .none) ?? [] + case .set: + value = try Set.foryRead(self, refMode: .none, readTypeInfo: false) + case .map: + value = try readDynamicAnyMapValue(context: self) + case .none: + value = ForyAnyNullValue() + default: + if typeInfo.typeID.isUserTypeKind { + value = try typeInfo.read(self) + } else { + throw ForyError.invalidData("unsupported dynamic type id \(typeInfo.typeID)") + } + } + return value + } + + @inline(__always) + func getTypeInfo(for type: T.Type) -> TypeInfo? { + typeInfoStack.value(for: UInt64(UInt(bitPattern: ObjectIdentifier(type)))) + } + + func withTypeInfo( + _ typeInfo: TypeInfo?, + for type: T.Type, + _ body: () throws -> R + ) rethrows -> R { + guard let typeInfo else { + return try body() } - func withTypeInfo( - _ typeInfo: TypeInfo?, - for type: T.Type, - _ body: () throws -> R - ) rethrows -> R { - guard let typeInfo else { - return try body() - } - - let typeKey = UInt64(UInt(bitPattern: ObjectIdentifier(type))) - let previousTypeInfo = typeInfoStack.value(for: typeKey) - typeInfoScopeStack.append((typeKey: typeKey, previousTypeInfo: previousTypeInfo)) - typeInfoStack.set(typeInfo, for: typeKey) - defer { - if let scope = typeInfoScopeStack.popLast() { - if let previousTypeInfo = scope.previousTypeInfo { - typeInfoStack.set(previousTypeInfo, for: scope.typeKey) - } else { - _ = typeInfoStack.removeValue(for: scope.typeKey) - } - } else { - assertionFailure("type info scope stack underflow") - } + let typeKey = UInt64(UInt(bitPattern: ObjectIdentifier(type))) + let previousTypeInfo = typeInfoStack.value(for: typeKey) + typeInfoScopeStack.append((typeKey: typeKey, previousTypeInfo: previousTypeInfo)) + typeInfoStack.set(typeInfo, for: typeKey) + defer { + if let scope = typeInfoScopeStack.popLast() { + if let previousTypeInfo = scope.previousTypeInfo { + typeInfoStack.set(previousTypeInfo, for: scope.typeKey) + } else { + _ = typeInfoStack.removeValue(for: scope.typeKey) } - return try body() + } else { + assertionFailure("type info scope stack underflow") + } } - - @inline(__always) - func getReadMetaString(at index: Int) -> MetaString? { - metaStrings.get(index) + return try body() + } + + @inline(__always) + func getReadMetaString(at index: Int) -> MetaString? { + metaStrings.get(index) + } + + @inline(__always) + func appendReadMetaString(_ value: MetaString) { + metaStrings.push(value) + } + + func reset() { + if dynamicAnyDepth != 0 { + dynamicAnyDepth = 0 } - - @inline(__always) - func appendReadMetaString(_ value: MetaString) { - metaStrings.push(value) + refReader.reset() + if !typeInfoStack.isEmpty { + typeInfoStack.clear() } - - func reset() { - if dynamicAnyDepth != 0 { - dynamicAnyDepth = 0 - } - refReader.reset() - if !typeInfoStack.isEmpty { - typeInfoStack.clear() - } - if !typeInfoScopeStack.isEmpty { - typeInfoScopeStack.removeAll(keepingCapacity: true) - } - compatibleTypeDefTypeInfos.reset() - metaStrings.reset() + if !typeInfoScopeStack.isEmpty { + typeInfoScopeStack.removeAll(keepingCapacity: true) } + compatibleTypeDefTypeInfos.reset() + metaStrings.reset() + } } extension ReadContext { - public func readAny( - refMode: RefMode, - readTypeInfo: Bool = true - ) throws -> Any? { - try SerializableAny.foryRead(self, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() - } - - public func readListOfAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [Any]? { - let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - return wrapped?.map { $0.anyValueForCollection() } + public func readAny( + refMode: RefMode, + readTypeInfo: Bool = true + ) throws -> Any? { + try SerializableAny.foryRead(self, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() + } + + public func readListOfAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [Any]? { + let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + return wrapped?.map { $0.anyValueForCollection() } + } + + public func readMapStringToAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [String: Any]? { + let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil } - - public func readMapStringToAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [String: Any]? { - let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - var map: [String: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + var map: [String: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() } - - public func readMapInt32ToAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [Int32: Any]? { - let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - var map: [Int32: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + return map + } + + public func readMapInt32ToAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [Int32: Any]? { + let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil } - - public func readMapAnyHashableToAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [AnyHashable: Any]? { - let wrapped: [AnyHashable: SerializableAny]? = try [AnyHashable: SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - var map: [AnyHashable: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + var map: [Int32: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map + } + + public func readMapAnyHashableToAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [AnyHashable: Any]? { + let wrapped: [AnyHashable: SerializableAny]? = try [AnyHashable: SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + var map: [AnyHashable: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() } + return map + } } diff --git a/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift b/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift new file mode 100644 index 0000000000..3650f55f48 --- /dev/null +++ b/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import Foundation +import Testing +@testable import Fory + +@ForyStruct +private final class BudgetNode { + var id: Int32 = 0 + + required init() {} + + init(id: Int32) { + self.id = id + } +} + +@ForyStruct +private struct BudgetSiblings { + var left: [BudgetNode] = [] + var right: [BudgetNode] = [] +} + +@ForyStruct +private struct BudgetDenseHolder: Equatable { + var text: String = "" + var data: Data = Data() + @ArrayField(element: .int32()) + var dense: [Int32] = [] +} + +private func makeBudgetFory(maxContainerMemoryBytes: Int64 = -1) -> Fory { + let fory = Fory(config: .init( + trackRef: false, + compatible: false, + maxContainerMemoryBytes: maxContainerMemoryBytes + )) + fory.register(BudgetNode.self, id: 9801) + fory.register(BudgetSiblings.self, id: 9802) + fory.register(BudgetDenseHolder.self, id: 9803) + return fory +} + +private func elementBytes(_ type: Element.Type) -> Int { + type.isRefType ? ReadContext.referenceBytes : max(1, MemoryLayout.stride) +} + +private func arrayBudget(_ type: Element.Type, count: Int) -> Int { + if count == 0 { + return ReadContext.containerFixedBytes + } + return ReadContext.containerFixedBytes + ReadContext.arrayHeaderBytes + + count * elementBytes(type) +} + +private func mapBudget( + key: Key.Type, + value: Value.Type, + count: Int +) -> Int { + if count == 0 { + return ReadContext.containerFixedBytes + } + return ReadContext.containerFixedBytes + ReadContext.arrayHeaderBytes * 2 + + count * ( + elementBytes(key) + elementBytes(value) + + ReadContext.mapEntryOverheadBytes + ReadContext.referenceBytes + ) +} + +private func expectInvalidData(_ body: () throws -> Void) { + do { + try body() + Issue.record("expected invalid data") + } catch ForyError.invalidData { + } catch { + Issue.record("expected invalid data, got \(error)") + } +} + +@Test +func knownLengthAutoBudgetRejectsNestedEmptyArrays() throws { + let count = 6_000 + let value = Array(repeating: [String](), count: count) + let bytes = try makeBudgetFory().serialize(value) + let autoLimit = bytes.count * 8 + ReadContext.knownContainerBudgetSlackBytes + let required = arrayBudget([String].self, count: count) + + count * arrayBudget(String.self, count: 0) + #expect(required > autoLimit) + + expectInvalidData { + let _: [[String]] = try makeBudgetFory().deserialize(bytes) + } + + let decoded: [[String]] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required)).deserialize(bytes) + #expect(decoded.count == count) +} + +@Test +func byteBufferRootUsesKnownLengthAutoBudget() throws { + let count = 6_000 + let value = Array(repeating: [String](), count: count) + let bytes = try makeBudgetFory().serialize(value) + let buffer = ByteBuffer(data: bytes) + + expectInvalidData { + let _: [[String]] = try makeBudgetFory().deserialize(from: buffer) + } +} + +@Test +func explicitConfigOverridesAutoBudget() throws { + let values = (0..<16).map(Int32.init) + let bytes = try makeBudgetFory().serialize(values) + let required = arrayBudget(Int32.self, count: values.count) + + expectInvalidData { + let _: [Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required - 1)).deserialize(bytes) + } + let decoded: [Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required)).deserialize(bytes) + #expect(decoded == values) +} + +@Test +func siblingContainersShareOneBudget() throws { + let value = BudgetSiblings( + left: (0..<16).map { BudgetNode(id: Int32($0)) }, + right: (16..<32).map { BudgetNode(id: Int32($0)) } + ) + let bytes = try makeBudgetFory().serialize(value) + let oneList = arrayBudget(BudgetNode.self, count: 16) + + expectInvalidData { + let _: BudgetSiblings = try makeBudgetFory(maxContainerMemoryBytes: Int64(oneList)).deserialize(bytes) + } + let decoded: BudgetSiblings = try makeBudgetFory(maxContainerMemoryBytes: Int64(oneList * 2)).deserialize(bytes) + #expect(decoded.left.count == 16) + #expect(decoded.right.count == 16) +} + +@Test +func mapBudgetIsCharged() throws { + let value: [String: Int32] = ["a": 1, "b": 2, "c": 3] + let bytes = try makeBudgetFory().serialize(value) + let required = mapBudget(key: String.self, value: Int32.self, count: value.count) + + expectInvalidData { + let _: [String: Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required - 1)).deserialize(bytes) + } + let decoded: [String: Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required)).deserialize(bytes) + #expect(decoded == value) +} + +@Test +func referenceAndInlineValueArraysAreCharged() throws { + let nodes = (0..<4).map { BudgetNode(id: Int32($0)) } + let nodeBytes = try makeBudgetFory().serialize(nodes) + let nodeBudget = arrayBudget(BudgetNode.self, count: nodes.count) + expectInvalidData { + let _: [BudgetNode] = try makeBudgetFory(maxContainerMemoryBytes: Int64(nodeBudget - 1)).deserialize(nodeBytes) + } + let decodedNodes: [BudgetNode] = try makeBudgetFory(maxContainerMemoryBytes: Int64(nodeBudget)).deserialize(nodeBytes) + #expect(decodedNodes.count == nodes.count) + + let ints: [Int32] = [1, 2, 3, 4] + let intBytes = try makeBudgetFory().serialize(ints) + let intBudget = arrayBudget(Int32.self, count: ints.count) + expectInvalidData { + let _: [Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(intBudget - 1)).deserialize(intBytes) + } + #expect(try makeBudgetFory(maxContainerMemoryBytes: Int64(intBudget)).deserialize(intBytes) as [Int32] == ints) +} + +@Test +func stringBinaryAndPrimitiveDenseArrayOwnersAreSkipped() throws { + let value = BudgetDenseHolder( + text: "budget", + data: Data([1, 2, 3]), + dense: [1, 2, 3] + ) + let bytes = try makeBudgetFory().serialize(value) + + let decoded: BudgetDenseHolder = try makeBudgetFory(maxContainerMemoryBytes: 1).deserialize(bytes) + #expect(decoded == value) +} + +@Test +func dynamicAnyEmptyMapChargesFixedCost() throws { + let value = [:] as [AnyHashable: Any] + let bytes = try makeBudgetFory().serialize(value as Any) + let required = ReadContext.containerFixedBytes * 3 + + expectInvalidData { + let _: Any = try makeBudgetFory(maxContainerMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: Any = try makeBudgetFory(maxContainerMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect((decoded as? [String: Any])?.isEmpty == true) +} + +@Test +func byteAvailabilityCheckStillRejectsLargeLength() throws { + let buffer = ByteBuffer() + buffer.writeVarUInt32(64) + buffer.writeUInt8(CollectionHeader.sameType | CollectionHeader.declaredElementType) + let config = Config(trackRef: false, compatible: false) + let context = ReadContext( + buffer: buffer, + typeResolver: TypeResolver(config: config), + config: config + ) + + expectInvalidData { + let _: [String] = try [String].foryReadData(context) + } +} diff --git a/swift/Tests/ForyTests/ForySwiftTests.swift b/swift/Tests/ForyTests/ForySwiftTests.swift index 897c881968..c81b70b30f 100644 --- a/swift/Tests/ForyTests/ForySwiftTests.swift +++ b/swift/Tests/ForyTests/ForySwiftTests.swift @@ -22,1713 +22,1716 @@ import Testing @ForyStruct struct Address: Equatable { - var street: String - var zip: Int32 + var street: String + var zip: Int32 } @ForyStruct struct Person: Equatable { - var id: Int64 - var name: String - var nickname: String? - var scores: [Int32] - var tags: Set - var addresses: [Address] - var metadata: [Int8: Int32?] + var id: Int64 + var name: String + var nickname: String? + var scores: [Int32] + var tags: Set + var addresses: [Address] + var metadata: [Int8: Int32?] } @ForyStruct struct FieldOrder: Equatable { - var textTail: String - var longValue: Int64 - var shortValue: Int16 - var intValue: Int32 + var textTail: String + var longValue: Int64 + var shortValue: Int16 + var intValue: Int32 } @ForyStruct struct TaggedFieldOrder: Equatable { - @ForyField(id: 1) - var textTail: String + @ForyField(id: 1) + var textTail: String - @ForyField(id: 10) - var intValue: Int32 + @ForyField(id: 10) + var intValue: Int32 } @ForyStruct struct NonPrimitiveFieldOrder: Equatable { - @ForyField(id: 20) - var stringValue: String + @ForyField(id: 20) + var stringValue: String - @ForyField(id: 10) - var mapValue: [String: Int32] + @ForyField(id: 10) + var mapValue: [String: Int32] - var binaryValue: Data - var addressValue: Address - var intValue: Int32 + var binaryValue: Data + var addressValue: Address + var intValue: Int32 } @ForyStruct struct EncodedNumberFields: Equatable { - @ForyField(encoding: .fixed) - var u32Fixed: UInt32 + @ForyField(encoding: .fixed) + var u32Fixed: UInt32 - @ForyField(encoding: .tagged) - var u64Tagged: UInt64 + @ForyField(encoding: .tagged) + var u64Tagged: UInt64 } @ForyStruct struct ReducedPrecisionMacroFields: Equatable { - var float16Value: Float16 - var bfloat16Value: BFloat16 - @ArrayField(element: .float16) - var float16Array: [Float16] - @ArrayField(element: .bfloat16) - var bfloat16Array: [BFloat16] + var float16Value: Float16 + var bfloat16Value: BFloat16 + @ArrayField(element: .float16) + var float16Array: [Float16] + @ArrayField(element: .bfloat16) + var bfloat16Array: [BFloat16] } @ForyStruct struct FieldIdConfigured: Equatable { - @ForyField(id: 2) - var stableID: Int32 + @ForyField(id: 2) + var stableID: Int32 - @ForyField(id: 5, encoding: .fixed) - var fixedValue: Int32 + @ForyField(id: 5, encoding: .fixed) + var fixedValue: Int32 } @ForyStruct struct FieldIdSource: Equatable { - @ForyField(id: 1) - var value: Int32 + @ForyField(id: 1) + var value: Int32 - @ForyField(id: 4) - var label: String + @ForyField(id: 4) + var label: String } @ForyStruct struct FieldIdTarget: Equatable { - @ForyField(id: 1) - var renamedValue: Int32 + @ForyField(id: 1) + var renamedValue: Int32 - @ForyField(id: 4) - var renamedLabel: String + @ForyField(id: 4) + var renamedLabel: String } @ForyEnum enum SparseStatus: Int32, CaseIterable { - case unknown = 4096 - case ok = 8192 + case unknown = 4096 + case ok = 8192 } @ForyStruct struct EvolvingOverrideValue: Equatable { - var f1: String = "" + var f1: String = "" } @ForyStruct(evolving: false) struct FixedOverrideValue: Equatable { - var f1: String = "" + var f1: String = "" } @ForyUnion enum FieldIdUnionSource: Equatable { - @ForyUnknownCase - case unknown(UnknownCase) + @ForyUnknownCase + case unknown(UnknownCase) - @ForyCase(id: 3) - case number(Int32) + @ForyCase(id: 3) + case number(Int32) - @ForyCase(id: 9) - case text(String) + @ForyCase(id: 9) + case text(String) } @ForyUnion enum FieldIdUnionTarget: Equatable { - @ForyUnknownCase - case unknown(UnknownCase) + @ForyUnknownCase + case unknown(UnknownCase) - @ForyCase(id: 3) - case renamedNumber(Int32) + @ForyCase(id: 3) + case renamedNumber(Int32) - @ForyCase(id: 9) - case renamedText(String) + @ForyCase(id: 9) + case renamedText(String) } @ForyStruct struct CompatibleNestedItem: Equatable { - var id: Int32 - var name: String + var id: Int32 + var name: String } @ForyStruct struct CompatibleNestedArrayHolder: Equatable { - var items: [CompatibleNestedItem] + var items: [CompatibleNestedItem] } @ForyStruct struct CompatibleNestedOptionalArrayHolder: Equatable { - var items: [CompatibleNestedItem?] + var items: [CompatibleNestedItem?] } @ForyStruct struct CompatibleNestedMapHolder: Equatable { - var items: [Int32: CompatibleNestedItem] + var items: [Int32: CompatibleNestedItem] } struct LateMetaExt: Serializer, Equatable { - var value: Int32 = 0 + var value: Int32 = 0 - static func foryDefault() -> LateMetaExt { - LateMetaExt() - } + static func foryDefault() -> LateMetaExt { + LateMetaExt() + } - static var staticTypeId: TypeId { - .ext - } + static var staticTypeId: TypeId { + .ext + } - func foryWriteData(_ context: WriteContext, hasGenerics _: Bool) throws { - context.buffer.writeVarInt32(value) - } + func foryWriteData(_ context: WriteContext, hasGenerics _: Bool) throws { + context.buffer.writeVarInt32(value) + } - static func foryReadData(_ context: ReadContext) throws -> LateMetaExt { - LateMetaExt(value: try context.buffer.readVarInt32()) - } + static func foryReadData(_ context: ReadContext) throws -> LateMetaExt { + LateMetaExt(value: try context.buffer.readVarInt32()) + } } @ForyStruct struct LateMetaHolder: Equatable { - var ext: LateMetaExt + var ext: LateMetaExt } @ForyStruct final class Node { - var value: Int32 = 0 - var next: Node? + var value: Int32 = 0 + var next: Node? - required init() {} + required init() {} - init(value: Int32, next: Node? = nil) { - self.value = value - self.next = next - } + init(value: Int32, next: Node? = nil) { + self.value = value + self.next = next + } } @ForyStruct final class WeakNode { - var value: Int32 = 0 - weak var next: WeakNode? + var value: Int32 = 0 + weak var next: WeakNode? - required init() {} + required init() {} - init(value: Int32, next: WeakNode? = nil) { - self.value = value - self.next = next - } + init(value: Int32, next: WeakNode? = nil) { + self.value = value + self.next = next + } } @ForyStruct struct AnyObjectHolder { - var value: AnyObject - var optionalValue: AnyObject? - var items: [AnyObject] + var value: AnyObject + var optionalValue: AnyObject? + var items: [AnyObject] } @ForyStruct struct AnySerializerHolder { - var value: any Serializer - var items: [any Serializer] - var map: [String: any Serializer] + var value: any Serializer + var items: [any Serializer] + var map: [String: any Serializer] } @ForyStruct struct AnyFieldHolder { - var value: Any - var optionalValue: Any? - var list: [Any] - var stringMap: [String: Any] - var int32Map: [Int32: Any] + var value: Any + var optionalValue: Any? + var list: [Any] + var stringMap: [String: Any] + var int32Map: [Int32: Any] } @Test func primitiveRoundTrip() throws { - let fory = Fory() + let fory = Fory() - let boolData = try fory.serialize(true) - let boolValue: Bool = try fory.deserialize(boolData) - #expect(boolValue == true) + let boolData = try fory.serialize(true) + let boolValue: Bool = try fory.deserialize(boolData) + #expect(boolValue == true) - let int32Data = try fory.serialize(Int32(-123456)) - let int32Value: Int32 = try fory.deserialize(int32Data) - #expect(int32Value == -123456) + let int32Data = try fory.serialize(Int32(-123456)) + let int32Value: Int32 = try fory.deserialize(int32Data) + #expect(int32Value == -123456) - let int64Data = try fory.serialize(Int64(9_223_372_036_854_775_000)) - let int64Value: Int64 = try fory.deserialize(int64Data) - #expect(int64Value == 9_223_372_036_854_775_000) + let int64Data = try fory.serialize(Int64(9_223_372_036_854_775_000)) + let int64Value: Int64 = try fory.deserialize(int64Data) + #expect(int64Value == 9_223_372_036_854_775_000) - let uint32Data = try fory.serialize(UInt32(123456)) - let uint32Value: UInt32 = try fory.deserialize(uint32Data) - #expect(uint32Value == 123456) + let uint32Data = try fory.serialize(UInt32(123456)) + let uint32Value: UInt32 = try fory.deserialize(uint32Data) + #expect(uint32Value == 123456) - let uint64Data = try fory.serialize(UInt64(9_223_372_036_854_775_000)) - let uint64Value: UInt64 = try fory.deserialize(uint64Data) - #expect(uint64Value == 9_223_372_036_854_775_000) + let uint64Data = try fory.serialize(UInt64(9_223_372_036_854_775_000)) + let uint64Value: UInt64 = try fory.deserialize(uint64Data) + #expect(uint64Value == 9_223_372_036_854_775_000) - let floatData = try fory.serialize(Float(3.25)) - let floatValue: Float = try fory.deserialize(floatData) - #expect(floatValue == 3.25) + let floatData = try fory.serialize(Float(3.25)) + let floatValue: Float = try fory.deserialize(floatData) + #expect(floatValue == 3.25) - let doubleData = try fory.serialize(Double(3.1415926)) - let doubleValue: Double = try fory.deserialize(doubleData) - #expect(doubleValue == 3.1415926) + let doubleData = try fory.serialize(Double(3.1415926)) + let doubleValue: Double = try fory.deserialize(doubleData) + #expect(doubleValue == 3.1415926) - let stringData = try fory.serialize("hello_fory") - let stringValue: String = try fory.deserialize(stringData) - #expect(stringValue == "hello_fory") + let stringData = try fory.serialize("hello_fory") + let stringValue: String = try fory.deserialize(stringData) + #expect(stringValue == "hello_fory") - let binary = Data([0x01, 0x02, 0x03, 0xFF]) - let binaryData = try fory.serialize(binary) - let binaryValue: Data = try fory.deserialize(binaryData) - #expect(binaryValue == binary) + let binary = Data([0x01, 0x02, 0x03, 0xFF]) + let binaryData = try fory.serialize(binary) + let binaryValue: Data = try fory.deserialize(binaryData) + #expect(binaryValue == binary) } @Test func extendedWireTypesRoundTrip() throws { - let fory = Fory() + let fory = Fory() - let float16Value = Float16(3.5) - let float16Data = try fory.serialize(float16Value) - let float16Decoded: Float16 = try fory.deserialize(float16Data) - #expect(float16Decoded.bitPattern == float16Value.bitPattern) + let float16Value = Float16(3.5) + let float16Data = try fory.serialize(float16Value) + let float16Decoded: Float16 = try fory.deserialize(float16Data) + #expect(float16Decoded.bitPattern == float16Value.bitPattern) - let bfloatValue = BFloat16(rawValue: 0x3F80) - let bfloatData = try fory.serialize(bfloatValue) - let bfloatDecoded: BFloat16 = try fory.deserialize(bfloatData) - #expect(bfloatDecoded == bfloatValue) + let bfloatValue = BFloat16(rawValue: 0x3F80) + let bfloatData = try fory.serialize(bfloatValue) + let bfloatDecoded: BFloat16 = try fory.deserialize(bfloatData) + #expect(bfloatDecoded == bfloatValue) - let durationValue = Duration.seconds(-2) + Duration.nanoseconds(123_456_789) - let durationData = try fory.serialize(durationValue) - let durationDecoded: Duration = try fory.deserialize(durationData) - #expect(durationDecoded == durationValue) + let durationValue = Duration.seconds(-2) + Duration.nanoseconds(123_456_789) + let durationData = try fory.serialize(durationValue) + let durationDecoded: Duration = try fory.deserialize(durationData) + #expect(durationDecoded == durationValue) - let float16Array: [Float16] = [Float16(1), Float16(-2), Float16(4.5)] - let float16ArrayData = try fory.serialize(float16Array) - let float16ArrayDecoded: [Float16] = try fory.deserialize(float16ArrayData) - #expect(float16ArrayDecoded.map(\.bitPattern) == float16Array.map(\.bitPattern)) + let float16Array: [Float16] = [Float16(1), Float16(-2), Float16(4.5)] + let float16ArrayData = try fory.serialize(float16Array) + let float16ArrayDecoded: [Float16] = try fory.deserialize(float16ArrayData) + #expect(float16ArrayDecoded.map(\.bitPattern) == float16Array.map(\.bitPattern)) } @Test func floatingSpecialsRoundTrip() throws { - let fory = Fory() - - let floatValues: [Float] = [ - 0.0, - -0.0, - .infinity, - -.infinity, - .leastNonzeroMagnitude, - .greatestFiniteMagnitude, - Float(bitPattern: 0x7FC0_1234) - ] - for value in floatValues { - let decoded: Float = try fory.deserialize(try fory.serialize(value)) - #expect(decoded.bitPattern == value.bitPattern) - } - - let doubleValues: [Double] = [ - 0.0, - -0.0, - .infinity, - -.infinity, - .leastNonzeroMagnitude, - .greatestFiniteMagnitude, - Double(bitPattern: 0x7FF8_0000_0000_1234) - ] - for value in doubleValues { - let decoded: Double = try fory.deserialize(try fory.serialize(value)) - #expect(decoded.bitPattern == value.bitPattern) - } - - let float16Values: [Float16] = [ - .init(bitPattern: 0x0000), - .init(bitPattern: 0x8000), - .init(bitPattern: 0x7C00), - .init(bitPattern: 0xFC00), - .init(bitPattern: 0x0001), - .init(bitPattern: 0x7BFF), - .init(bitPattern: 0x7E11) - ] - for value in float16Values { - let decoded: Float16 = try fory.deserialize(try fory.serialize(value)) - #expect(decoded.bitPattern == value.bitPattern) - } - - let bfloat16Values: [BFloat16] = [ - .init(rawValue: 0x0000), - .init(rawValue: 0x8000), - .init(rawValue: 0x7F80), - .init(rawValue: 0xFF80), - .init(rawValue: 0x0001), - .init(rawValue: 0x7FC1) - ] - for value in bfloat16Values { - let decoded: BFloat16 = try fory.deserialize(try fory.serialize(value)) - #expect(decoded.rawValue == value.rawValue) - } + let fory = Fory() + + let floatValues: [Float] = [ + 0.0, + -0.0, + .infinity, + -.infinity, + .leastNonzeroMagnitude, + .greatestFiniteMagnitude, + Float(bitPattern: 0x7FC0_1234) + ] + for value in floatValues { + let decoded: Float = try fory.deserialize(try fory.serialize(value)) + #expect(decoded.bitPattern == value.bitPattern) + } + + let doubleValues: [Double] = [ + 0.0, + -0.0, + .infinity, + -.infinity, + .leastNonzeroMagnitude, + .greatestFiniteMagnitude, + Double(bitPattern: 0x7FF8_0000_0000_1234) + ] + for value in doubleValues { + let decoded: Double = try fory.deserialize(try fory.serialize(value)) + #expect(decoded.bitPattern == value.bitPattern) + } + + let float16Values: [Float16] = [ + .init(bitPattern: 0x0000), + .init(bitPattern: 0x8000), + .init(bitPattern: 0x7C00), + .init(bitPattern: 0xFC00), + .init(bitPattern: 0x0001), + .init(bitPattern: 0x7BFF), + .init(bitPattern: 0x7E11) + ] + for value in float16Values { + let decoded: Float16 = try fory.deserialize(try fory.serialize(value)) + #expect(decoded.bitPattern == value.bitPattern) + } + + let bfloat16Values: [BFloat16] = [ + .init(rawValue: 0x0000), + .init(rawValue: 0x8000), + .init(rawValue: 0x7F80), + .init(rawValue: 0xFF80), + .init(rawValue: 0x0001), + .init(rawValue: 0x7FC1) + ] + for value in bfloat16Values { + let decoded: BFloat16 = try fory.deserialize(try fory.serialize(value)) + #expect(decoded.rawValue == value.rawValue) + } } @Test func namedInitializerBuildsConfig() { - let defaultConfig = Fory() - #expect(defaultConfig.config.trackRef == false) - #expect(defaultConfig.config.compatible == true) - #expect(defaultConfig.config.checkClassVersion == false) - #expect(defaultConfig.config.maxDepth == 5) - #expect(defaultConfig.config.maxTypeFields == 512) - #expect(defaultConfig.config.maxTypeMetaBytes == 4096) - #expect(defaultConfig.config.maxSchemaVersionsPerType == 10) - #expect(defaultConfig.config.maxAverageSchemaVersionsPerType == 3) - - let explicitConfig = Fory( - ref: true, - compatible: true, - maxDepth: 7, - maxTypeFields: 31, - maxTypeMetaBytes: 1234, - maxSchemaVersionsPerType: 12, - maxAverageSchemaVersionsPerType: 4 - ) - #expect(explicitConfig.config.trackRef == true) - #expect(explicitConfig.config.compatible == true) - #expect(explicitConfig.config.checkClassVersion == false) - #expect(explicitConfig.config.maxDepth == 7) - #expect(explicitConfig.config.maxTypeFields == 31) - #expect(explicitConfig.config.maxTypeMetaBytes == 1234) - #expect(explicitConfig.config.maxSchemaVersionsPerType == 12) - #expect(explicitConfig.config.maxAverageSchemaVersionsPerType == 4) - - let configInit = Fory( - config: .init( - trackRef: false, - compatible: true, - maxDepth: 9, - maxTypeFields: 41, - maxTypeMetaBytes: 2048, - maxSchemaVersionsPerType: 14, - maxAverageSchemaVersionsPerType: 5 - )) - #expect(configInit.config.trackRef == false) - #expect(configInit.config.compatible == true) - #expect(configInit.config.checkClassVersion == false) - #expect(configInit.config.maxDepth == 9) - #expect(configInit.config.maxTypeFields == 41) - #expect(configInit.config.maxTypeMetaBytes == 2048) - #expect(configInit.config.maxSchemaVersionsPerType == 14) - #expect(configInit.config.maxAverageSchemaVersionsPerType == 5) - - let schemaConsistentDirect = Fory(ref: true, compatible: false) - let schemaConsistentViaConfig = Fory(config: Config(trackRef: true, compatible: false)) - #expect(schemaConsistentDirect.config.checkClassVersion == true) - #expect(schemaConsistentViaConfig.config.checkClassVersion == true) + let defaultConfig = Fory() + #expect(defaultConfig.config.trackRef == false) + #expect(defaultConfig.config.compatible == true) + #expect(defaultConfig.config.checkClassVersion == false) + #expect(defaultConfig.config.maxDepth == 5) + #expect(defaultConfig.config.maxContainerMemoryBytes == -1) + #expect(defaultConfig.config.maxTypeFields == 512) + #expect(defaultConfig.config.maxTypeMetaBytes == 4096) + #expect(defaultConfig.config.maxSchemaVersionsPerType == 10) + #expect(defaultConfig.config.maxAverageSchemaVersionsPerType == 3) + + let explicitConfig = Fory( + ref: true, + compatible: true, + maxDepth: 7, + maxContainerMemoryBytes: 65_536, + maxTypeFields: 31, + maxTypeMetaBytes: 1234, + maxSchemaVersionsPerType: 12, + maxAverageSchemaVersionsPerType: 4 + ) + #expect(explicitConfig.config.trackRef == true) + #expect(explicitConfig.config.compatible == true) + #expect(explicitConfig.config.checkClassVersion == false) + #expect(explicitConfig.config.maxDepth == 7) + #expect(explicitConfig.config.maxContainerMemoryBytes == 65_536) + #expect(explicitConfig.config.maxTypeFields == 31) + #expect(explicitConfig.config.maxTypeMetaBytes == 1234) + #expect(explicitConfig.config.maxSchemaVersionsPerType == 12) + #expect(explicitConfig.config.maxAverageSchemaVersionsPerType == 4) + + let configInit = Fory( + config: .init( + trackRef: false, + compatible: true, + maxDepth: 9, + maxContainerMemoryBytes: 131_072, + maxTypeFields: 41, + maxTypeMetaBytes: 2048, + maxSchemaVersionsPerType: 14, + maxAverageSchemaVersionsPerType: 5 + )) + #expect(configInit.config.trackRef == false) + #expect(configInit.config.compatible == true) + #expect(configInit.config.checkClassVersion == false) + #expect(configInit.config.maxDepth == 9) + #expect(configInit.config.maxContainerMemoryBytes == 131_072) + #expect(configInit.config.maxTypeFields == 41) + #expect(configInit.config.maxTypeMetaBytes == 2048) + #expect(configInit.config.maxSchemaVersionsPerType == 14) + #expect(configInit.config.maxAverageSchemaVersionsPerType == 5) + + let schemaConsistentDirect = Fory(ref: true, compatible: false) + let schemaConsistentViaConfig = Fory(config: Config(trackRef: true, compatible: false)) + #expect(schemaConsistentDirect.config.checkClassVersion == true) + #expect(schemaConsistentViaConfig.config.checkClassVersion == true) } @Test func structEvolvingOverrideUsesSmallerCompatiblePayload() throws { - let fory = Fory(compatible: true) - fory.register(EvolvingOverrideValue.self, id: 1001) - fory.register(FixedOverrideValue.self, id: 1002) + let fory = Fory(compatible: true) + fory.register(EvolvingOverrideValue.self, id: 1001) + fory.register(FixedOverrideValue.self, id: 1002) - let evolving = EvolvingOverrideValue(f1: "payload") - let fixed = FixedOverrideValue(f1: "payload") + let evolving = EvolvingOverrideValue(f1: "payload") + let fixed = FixedOverrideValue(f1: "payload") - let evolvingData = try fory.serialize(evolving) - let fixedData = try fory.serialize(fixed) + let evolvingData = try fory.serialize(evolving) + let fixedData = try fory.serialize(fixed) - #expect(fixedData.count < evolvingData.count) - let decodedEvolving: EvolvingOverrideValue = try fory.deserialize(evolvingData) - let decodedFixed: FixedOverrideValue = try fory.deserialize(fixedData) - #expect(decodedEvolving == evolving) - #expect(decodedFixed == fixed) + #expect(fixedData.count < evolvingData.count) + let decodedEvolving: EvolvingOverrideValue = try fory.deserialize(evolvingData) + let decodedFixed: FixedOverrideValue = try fory.deserialize(fixedData) + #expect(decodedEvolving == evolving) + #expect(decodedFixed == fixed) } @Test func deserializeRejectsTrailingBytes() throws { - let fory = Fory() - let payload = try fory.serialize(Int32(7)) - var bytes = [UInt8](payload) - bytes.append(0xFF) - let withTrailing = Data(bytes) + let fory = Fory() + let payload = try fory.serialize(Int32(7)) + var bytes = [UInt8](payload) + bytes.append(0xFF) + let withTrailing = Data(bytes) - do { - let _: Int32 = try fory.deserialize(withTrailing) - #expect(Bool(false)) - } catch {} + do { + let _: Int32 = try fory.deserialize(withTrailing) + #expect(Bool(false)) + } catch {} } @Test func optionalRoundTrip() throws { - let fory = Fory() + let fory = Fory() - let some: String? = "present" - let someData = try fory.serialize(some) - let someValue: String? = try fory.deserialize(someData) - #expect(someValue == "present") + let some: String? = "present" + let someData = try fory.serialize(some) + let someValue: String? = try fory.deserialize(someData) + #expect(someValue == "present") - let none: String? = nil - let noneData = try fory.serialize(none) - let noneValue: String? = try fory.deserialize(noneData) - #expect(noneValue == nil) + let none: String? = nil + let noneData = try fory.serialize(none) + let noneValue: String? = try fory.deserialize(noneData) + #expect(noneValue == nil) } @Test func collectionsRoundTrip() throws { - let fory = Fory() + let fory = Fory() - let list: [String?] = ["a", nil, "b"] - let listData = try fory.serialize(list) - let listValue: [String?] = try fory.deserialize(listData) - #expect(listValue == list) + let list: [String?] = ["a", nil, "b"] + let listData = try fory.serialize(list) + let listValue: [String?] = try fory.deserialize(listData) + #expect(listValue == list) - let intArray: [Int32] = [1, 2, 3, 4] - let intArrayData = try fory.serialize(intArray) - let intArrayValue: [Int32] = try fory.deserialize(intArrayData) - #expect(intArrayValue == intArray) + let intArray: [Int32] = [1, 2, 3, 4] + let intArrayData = try fory.serialize(intArray) + let intArrayValue: [Int32] = try fory.deserialize(intArrayData) + #expect(intArrayValue == intArray) - let uint8Array: [UInt8] = [1, 2, 3, 250] - let uint8ArrayData = try fory.serialize(uint8Array) - let uint8ArrayValue: [UInt8] = try fory.deserialize(uint8ArrayData) - #expect(uint8ArrayValue == uint8Array) + let uint8Array: [UInt8] = [1, 2, 3, 250] + let uint8ArrayData = try fory.serialize(uint8Array) + let uint8ArrayValue: [UInt8] = try fory.deserialize(uint8ArrayData) + #expect(uint8ArrayValue == uint8Array) - let set: Set = [1, 5, 8] - let setData = try fory.serialize(set) - let setValue: Set = try fory.deserialize(setData) - #expect(setValue == set) + let set: Set = [1, 5, 8] + let setData = try fory.serialize(set) + let setValue: Set = try fory.deserialize(setData) + #expect(setValue == set) - let map: [Int8: Int32?] = [1: 100, 2: nil, 3: -7] - let mapData = try fory.serialize(map) - let mapValue: [Int8: Int32?] = try fory.deserialize(mapData) - #expect(mapValue == map) + let map: [Int8: Int32?] = [1: 100, 2: nil, 3: -7] + let mapData = try fory.serialize(map) + let mapValue: [Int8: Int32?] = try fory.deserialize(mapData) + #expect(mapValue == map) - let nullableKeyMap: [Int8?: Int32?] = [1: 10, nil: nil] - let nullableMapData = try fory.serialize(nullableKeyMap) - let nullableMapValue: [Int8?: Int32?] = try fory.deserialize(nullableMapData) - #expect(nullableMapValue == nullableKeyMap) + let nullableKeyMap: [Int8?: Int32?] = [1: 10, nil: nil] + let nullableMapData = try fory.serialize(nullableKeyMap) + let nullableMapValue: [Int8?: Int32?] = try fory.deserialize(nullableMapData) + #expect(nullableMapValue == nullableKeyMap) } @Test func primitiveArrayTypeIDs() throws { - let fory = Fory() + let fory = Fory() - let int32Data = try fory.serialize([Int32(7), 9]) - let int32Bytes = [UInt8](int32Data) - #expect(int32Bytes[0] == ForyHeaderFlag.isXlang) - #expect(Int8(bitPattern: int32Bytes[1]) == RefFlag.notNullValue.rawValue) - #expect(UInt32(int32Bytes[2]) == TypeId.list.rawValue) + let int32Data = try fory.serialize([Int32(7), 9]) + let int32Bytes = [UInt8](int32Data) + #expect(int32Bytes[0] == ForyHeaderFlag.isXlang) + #expect(Int8(bitPattern: int32Bytes[1]) == RefFlag.notNullValue.rawValue) + #expect(UInt32(int32Bytes[2]) == TypeId.list.rawValue) - let uint8Data = try fory.serialize([UInt8(1), 2, 3]) - let uint8Bytes = [UInt8](uint8Data) - #expect(UInt32(uint8Bytes[2]) == TypeId.list.rawValue) + let uint8Data = try fory.serialize([UInt8(1), 2, 3]) + let uint8Bytes = [UInt8](uint8Data) + #expect(UInt32(uint8Bytes[2]) == TypeId.list.rawValue) } @Test func typeMetaFieldLimitRejectsLargeStruct() throws { - let fieldType = TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) - let meta = try TypeMeta( - typeID: TypeId.structType.rawValue, - userTypeID: 901, - namespace: .empty(specialChar1: ".", specialChar2: "_"), - typeName: .empty(specialChar1: "$", specialChar2: "_"), - registerByName: false, - fields: [ - TypeMeta.FieldInfo(fieldID: nil, fieldName: "first", fieldType: fieldType), - TypeMeta.FieldInfo(fieldID: nil, fieldName: "second", fieldType: fieldType) - ] - ) - let encoded = try meta.encode() + let fieldType = TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + let meta = try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 901, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo(fieldID: nil, fieldName: "first", fieldType: fieldType), + TypeMeta.FieldInfo(fieldID: nil, fieldName: "second", fieldType: fieldType) + ] + ) + let encoded = try meta.encode() - #expect(throws: (any Error).self) { - _ = try TypeMeta.decode(encoded, maxTypeFields: 1) - } + #expect(throws: (any Error).self) { + _ = try TypeMeta.decode(encoded, maxTypeFields: 1) + } } @Test func typeMetaBodyLimitRejectsLargeMetadata() throws { - let meta = try TypeMeta( - typeID: TypeId.structType.rawValue, - userTypeID: 901, - namespace: .empty(specialChar1: ".", specialChar2: "_"), - typeName: .empty(specialChar1: "$", specialChar2: "_"), - registerByName: false, - fields: [ - TypeMeta.FieldInfo( - fieldID: nil, - fieldName: "value", - fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false)) - ] - ) - let encoded = try meta.encode() + let meta = try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 901, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: "value", + fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false)) + ] + ) + let encoded = try meta.encode() - #expect(throws: (any Error).self) { - _ = try TypeMeta.decode(encoded, maxTypeMetaBytes: 1) - } + #expect(throws: (any Error).self) { + _ = try TypeMeta.decode(encoded, maxTypeMetaBytes: 1) + } } @Test func schemaLimitTracksStructTypesSeparately() throws { - let config = Config(maxSchemaVersionsPerType: 1) - let resolver = TypeResolver(config: config) - resolver.register(Person.self, id: 901) - resolver.register(Address.self, id: 902) - try resolver.finishRegistration() - - func remoteTypeMeta(userTypeID: UInt32, fieldName: String) throws -> TypeMeta { - try TypeMeta( - typeID: TypeId.structType.rawValue, - userTypeID: userTypeID, - namespace: .empty(specialChar1: ".", specialChar2: "_"), - typeName: .empty(specialChar1: "$", specialChar2: "_"), - registerByName: false, - fields: [ - TypeMeta.FieldInfo( - fieldID: nil, - fieldName: fieldName, - fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) - ) - ] + let config = Config(maxSchemaVersionsPerType: 1) + let resolver = TypeResolver(config: config) + resolver.register(Person.self, id: 901) + resolver.register(Address.self, id: 902) + try resolver.finishRegistration() + + func remoteTypeMeta(userTypeID: UInt32, fieldName: String) throws -> TypeMeta { + try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: userTypeID, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: fieldName, + fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) ) - } + ] + ) + } - func cache(_ typeMeta: TypeMeta) throws { - let encoded = try typeMeta.encode() - let headerReader = ByteBuffer(bytes: encoded) - let header = try headerReader.readUInt64() - let buffer = ByteBuffer(bytes: encoded) - let decoded = try TypeMeta.decode(buffer) - let localTypeInfo = try resolver.requireTypeInfo(for: decoded) - _ = try resolver.cacheTypeInfo( - decoded, - forHeader: header, - localTypeInfo: localTypeInfo, - exactLocal: false, - config: config - ) - } + func cache(_ typeMeta: TypeMeta) throws { + let encoded = try typeMeta.encode() + let headerReader = ByteBuffer(bytes: encoded) + let header = try headerReader.readUInt64() + let buffer = ByteBuffer(bytes: encoded) + let decoded = try TypeMeta.decode(buffer) + let localTypeInfo = try resolver.requireTypeInfo(for: decoded) + _ = try resolver.cacheTypeInfo( + decoded, + forHeader: header, + localTypeInfo: localTypeInfo, + exactLocal: false, + config: config + ) + } - try cache(remoteTypeMeta(userTypeID: 901, fieldName: "remoteA")) - try cache(remoteTypeMeta(userTypeID: 902, fieldName: "remoteA")) - #expect(throws: (any Error).self) { - try cache(remoteTypeMeta(userTypeID: 901, fieldName: "remoteB")) - } + try cache(remoteTypeMeta(userTypeID: 901, fieldName: "remoteA")) + try cache(remoteTypeMeta(userTypeID: 902, fieldName: "remoteA")) + #expect(throws: (any Error).self) { + try cache(remoteTypeMeta(userTypeID: 901, fieldName: "remoteB")) + } } @Test func nonStructTypeMetaUsesSchemaLimit() throws { - let config = Config(maxSchemaVersionsPerType: 1) - let resolver = TypeResolver(config: config) - try resolver.register(SparseStatus.self, name: "example.SharedEnum") - try resolver.finishRegistration() - let namespace = try MetaStringEncoder.namespace.encode("example") - let typeName = try MetaStringEncoder.typeName.encode("SharedEnum") - - func remoteTypeMeta(_ typeID: TypeId) throws -> TypeMeta { - try TypeMeta( - typeID: typeID.rawValue, - userTypeID: nil, - namespace: namespace, - typeName: typeName, - registerByName: true, - fields: [] - ) - } - - func cache(_ typeMeta: TypeMeta) throws { - let encoded = try typeMeta.encode() - let headerReader = ByteBuffer(bytes: encoded) - let header = try headerReader.readUInt64() - let buffer = ByteBuffer(bytes: encoded) - let decoded = try TypeMeta.decode(buffer) - let localTypeInfo = try resolver.requireTypeInfo(for: decoded) - _ = try resolver.cacheTypeInfo( - decoded, - forHeader: header, - localTypeInfo: localTypeInfo, - exactLocal: false, - config: config - ) - } - - try cache(remoteTypeMeta(.namedExt)) - #expect(throws: (any Error).self) { - try cache(remoteTypeMeta(.namedUnion)) - } -} - -@Test -func exactLocalNonStructTypeMetaBypassesSchemaLimit() throws { - let config = Config(compatible: true, maxSchemaVersionsPerType: 1) - let resolver = TypeResolver(config: config) - try resolver.register(SparseStatus.self, name: "example.SharedEnum") - try resolver.finishRegistration() - let localTypeInfo = try resolver.requireTypeInfo(for: SparseStatus.self) - let namespace = try MetaStringEncoder.namespace.encode("example") - let typeName = try MetaStringEncoder.typeName.encode("SharedEnum") - - let exactBuffer = ByteBuffer() - exactBuffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.namedEnum.rawValue)) - exactBuffer.writeUInt8(0) - exactBuffer.writeBytes(localTypeInfo.typeDefBytes!) - let exactContext = ReadContext(buffer: exactBuffer, typeResolver: resolver, config: config) - _ = try exactContext.readTypeInfo(for: SparseStatus.self) - - let remote = try TypeMeta( - typeID: TypeId.namedExt.rawValue, - userTypeID: nil, - namespace: namespace, - typeName: typeName, - registerByName: true, - fields: [] + let config = Config(maxSchemaVersionsPerType: 1) + let resolver = TypeResolver(config: config) + try resolver.register(SparseStatus.self, name: "example.SharedEnum") + try resolver.finishRegistration() + let namespace = try MetaStringEncoder.namespace.encode("example") + let typeName = try MetaStringEncoder.typeName.encode("SharedEnum") + + func remoteTypeMeta(_ typeID: TypeId) throws -> TypeMeta { + try TypeMeta( + typeID: typeID.rawValue, + userTypeID: nil, + namespace: namespace, + typeName: typeName, + registerByName: true, + fields: [] ) - let encoded = try remote.encode() + } + + func cache(_ typeMeta: TypeMeta) throws { + let encoded = try typeMeta.encode() let headerReader = ByteBuffer(bytes: encoded) let header = try headerReader.readUInt64() let buffer = ByteBuffer(bytes: encoded) let decoded = try TypeMeta.decode(buffer) - let resolved = try resolver.requireTypeInfo(for: decoded) + let localTypeInfo = try resolver.requireTypeInfo(for: decoded) _ = try resolver.cacheTypeInfo( - decoded, - forHeader: header, - localTypeInfo: resolved, - exactLocal: false, - config: config + decoded, + forHeader: header, + localTypeInfo: localTypeInfo, + exactLocal: false, + config: config ) + } + + try cache(remoteTypeMeta(.namedExt)) + #expect(throws: (any Error).self) { + try cache(remoteTypeMeta(.namedUnion)) + } +} + +@Test +func exactLocalNonStructTypeMetaBypassesSchemaLimit() throws { + let config = Config(compatible: true, maxSchemaVersionsPerType: 1) + let resolver = TypeResolver(config: config) + try resolver.register(SparseStatus.self, name: "example.SharedEnum") + try resolver.finishRegistration() + let localTypeInfo = try resolver.requireTypeInfo(for: SparseStatus.self) + let namespace = try MetaStringEncoder.namespace.encode("example") + let typeName = try MetaStringEncoder.typeName.encode("SharedEnum") + + let exactBuffer = ByteBuffer() + exactBuffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.namedEnum.rawValue)) + exactBuffer.writeUInt8(0) + exactBuffer.writeBytes(localTypeInfo.typeDefBytes!) + let exactContext = ReadContext(buffer: exactBuffer, typeResolver: resolver, config: config) + _ = try exactContext.readTypeInfo(for: SparseStatus.self) + + let remote = try TypeMeta( + typeID: TypeId.namedExt.rawValue, + userTypeID: nil, + namespace: namespace, + typeName: typeName, + registerByName: true, + fields: [] + ) + let encoded = try remote.encode() + let headerReader = ByteBuffer(bytes: encoded) + let header = try headerReader.readUInt64() + let buffer = ByteBuffer(bytes: encoded) + let decoded = try TypeMeta.decode(buffer) + let resolved = try resolver.requireTypeInfo(for: decoded) + _ = try resolver.cacheTypeInfo( + decoded, + forHeader: header, + localTypeInfo: resolved, + exactLocal: false, + config: config + ) } @Test func typeMetaUsesFinalRegistration() throws { - func holderTypeDefBytes(registerFieldTypeFirst: Bool) throws -> [UInt8] { - let resolver = TypeResolver(config: Config(compatible: true)) - if registerFieldTypeFirst { - try resolver.register(LateMetaExt.self, name: "example.LateMetaExt") - try resolver.register(LateMetaHolder.self, name: "example.LateMetaHolder") - } else { - try resolver.register(LateMetaHolder.self, name: "example.LateMetaHolder") - try resolver.register(LateMetaExt.self, name: "example.LateMetaExt") - } - try resolver.finishRegistration() - return try resolver.requireTypeInfo(for: LateMetaHolder.self).typeDefBytes! + func holderTypeDefBytes(registerFieldTypeFirst: Bool) throws -> [UInt8] { + let resolver = TypeResolver(config: Config(compatible: true)) + if registerFieldTypeFirst { + try resolver.register(LateMetaExt.self, name: "example.LateMetaExt") + try resolver.register(LateMetaHolder.self, name: "example.LateMetaHolder") + } else { + try resolver.register(LateMetaHolder.self, name: "example.LateMetaHolder") + try resolver.register(LateMetaExt.self, name: "example.LateMetaExt") } + try resolver.finishRegistration() + return try resolver.requireTypeInfo(for: LateMetaHolder.self).typeDefBytes! + } - let fieldFirst = try holderTypeDefBytes(registerFieldTypeFirst: true) - let holderFirst = try holderTypeDefBytes(registerFieldTypeFirst: false) - #expect(fieldFirst == holderFirst) + let fieldFirst = try holderTypeDefBytes(registerFieldTypeFirst: true) + let holderFirst = try holderTypeDefBytes(registerFieldTypeFirst: false) + #expect(fieldFirst == holderFirst) - let typeMeta = try TypeMeta.decode(ByteBuffer(bytes: holderFirst)) - #expect(typeMeta.fields.count == 1) - #expect(typeMeta.fields[0].fieldType.typeID == TypeId.namedExt.rawValue) + let typeMeta = try TypeMeta.decode(ByteBuffer(bytes: holderFirst)) + #expect(typeMeta.fields.count == 1) + #expect(typeMeta.fields[0].fieldType.typeID == TypeId.namedExt.rawValue) } @Test func failedSchemaDoesNotConsumeLimit() throws { - let config = Config(maxSchemaVersionsPerType: 1) - let resolver = TypeResolver(config: config) - resolver.register(Person.self, id: 901) - resolver.register(Address.self, id: 902) - try resolver.finishRegistration() - - func remoteTypeMeta(fieldName: String, fieldType: TypeMeta.FieldType) throws -> TypeMeta { - try TypeMeta( - typeID: TypeId.structType.rawValue, - userTypeID: 901, - namespace: .empty(specialChar1: ".", specialChar2: "_"), - typeName: .empty(specialChar1: "$", specialChar2: "_"), - registerByName: false, - fields: [ - TypeMeta.FieldInfo( - fieldID: nil, - fieldName: fieldName, - fieldType: fieldType - ) - ] + let config = Config(maxSchemaVersionsPerType: 1) + let resolver = TypeResolver(config: config) + resolver.register(Person.self, id: 901) + resolver.register(Address.self, id: 902) + try resolver.finishRegistration() + + func remoteTypeMeta(fieldName: String, fieldType: TypeMeta.FieldType) throws -> TypeMeta { + try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 901, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: fieldName, + fieldType: fieldType ) - } - - func cache(_ typeMeta: TypeMeta) throws { - let encoded = try typeMeta.encode() - let headerReader = ByteBuffer(bytes: encoded) - let header = try headerReader.readUInt64() - let buffer = ByteBuffer(bytes: encoded) - let decoded = try TypeMeta.decode(buffer) - let localTypeInfo = try resolver.requireTypeInfo(for: decoded) - _ = try resolver.cacheTypeInfo( - decoded, - forHeader: header, - localTypeInfo: localTypeInfo, - exactLocal: false, - config: config - ) - } + ] + ) + } - #expect(throws: (any Error).self) { - try cache( - remoteTypeMeta( - fieldName: "id", - fieldType: TypeMeta.FieldType( - typeID: TypeId.map.rawValue, - nullable: false, - generics: [ - TypeMeta.FieldType(typeID: TypeId.string.rawValue, nullable: false), - TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) - ] - ) - )) - } - try cache( - remoteTypeMeta( - fieldName: "remoteA", - fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) - )) + func cache(_ typeMeta: TypeMeta) throws { + let encoded = try typeMeta.encode() + let headerReader = ByteBuffer(bytes: encoded) + let header = try headerReader.readUInt64() + let buffer = ByteBuffer(bytes: encoded) + let decoded = try TypeMeta.decode(buffer) + let localTypeInfo = try resolver.requireTypeInfo(for: decoded) + _ = try resolver.cacheTypeInfo( + decoded, + forHeader: header, + localTypeInfo: localTypeInfo, + exactLocal: false, + config: config + ) + } + + #expect(throws: (any Error).self) { + try cache(remoteTypeMeta( + fieldName: "id", + fieldType: TypeMeta.FieldType( + typeID: TypeId.map.rawValue, + nullable: false, + generics: [ + TypeMeta.FieldType(typeID: TypeId.string.rawValue, nullable: false), + TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + ] + ) + )) + } + try cache(remoteTypeMeta( + fieldName: "remoteA", + fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + )) } @Test func staticTypeRejectsWrongMetaOwner() throws { - let config = Config(compatible: true) - let resolver = TypeResolver(config: config) - resolver.register(Person.self, id: 901) - resolver.register(Address.self, id: 902) - try resolver.finishRegistration() - let wrongTypeMeta = try TypeMeta( - typeID: TypeId.compatibleStruct.rawValue, - userTypeID: 901, - namespace: .empty(specialChar1: ".", specialChar2: "_"), - typeName: .empty(specialChar1: "$", specialChar2: "_"), - registerByName: false, - fields: [] - ) - let wrongBytes = try wrongTypeMeta.encode() - let wrongHeader = try ByteBuffer(bytes: wrongBytes).readUInt64() - let buffer = ByteBuffer() - buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.compatibleStruct.rawValue)) - buffer.writeUInt8(0) - buffer.writeBytes(wrongBytes) - let context = ReadContext(buffer: buffer, typeResolver: resolver, config: config) - - #expect(throws: (any Error).self) { - _ = try context.readTypeInfo(for: Address.self) - } - #expect(resolver.getTypeInfo(forHeader: wrongHeader) == nil) - - let addressInfo = try resolver.requireTypeInfo(for: Address.self) - let addressBytes = try #require(addressInfo.typeDefBytes) - let addressHeader = try ByteBuffer(bytes: addressBytes).readUInt64() - let exactBuffer = ByteBuffer() - exactBuffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.compatibleStruct.rawValue)) - exactBuffer.writeUInt8(0) - exactBuffer.writeBytes(addressBytes) - let exactContext = ReadContext(buffer: exactBuffer, typeResolver: resolver, config: config) - _ = try exactContext.readTypeInfo(for: Address.self) - #expect(resolver.getTypeInfo(forHeader: addressHeader) == nil) + let config = Config(compatible: true) + let resolver = TypeResolver(config: config) + resolver.register(Person.self, id: 901) + resolver.register(Address.self, id: 902) + try resolver.finishRegistration() + let wrongTypeMeta = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 901, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [] + ) + let wrongBytes = try wrongTypeMeta.encode() + let wrongHeader = try ByteBuffer(bytes: wrongBytes).readUInt64() + let buffer = ByteBuffer() + buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.compatibleStruct.rawValue)) + buffer.writeUInt8(0) + buffer.writeBytes(wrongBytes) + let context = ReadContext(buffer: buffer, typeResolver: resolver, config: config) + + #expect(throws: (any Error).self) { + _ = try context.readTypeInfo(for: Address.self) + } + #expect(resolver.getTypeInfo(forHeader: wrongHeader) == nil) + + let addressInfo = try resolver.requireTypeInfo(for: Address.self) + let addressBytes = try #require(addressInfo.typeDefBytes) + let addressHeader = try ByteBuffer(bytes: addressBytes).readUInt64() + let exactBuffer = ByteBuffer() + exactBuffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.compatibleStruct.rawValue)) + exactBuffer.writeUInt8(0) + exactBuffer.writeBytes(addressBytes) + let exactContext = ReadContext(buffer: exactBuffer, typeResolver: resolver, config: config) + _ = try exactContext.readTypeInfo(for: Address.self) + #expect(resolver.getTypeInfo(forHeader: addressHeader) == nil) } @Test func failedStaticMetaDoesNotCount() throws { - let config = Config(compatible: true, maxSchemaVersionsPerType: 1) - let resolver = TypeResolver(config: config) - resolver.register(Person.self, id: 901) - resolver.register(Address.self, id: 902) - try resolver.finishRegistration() - - func typeMeta(userTypeID: UInt32, fieldName: String) throws -> TypeMeta { - try TypeMeta( - typeID: TypeId.compatibleStruct.rawValue, - userTypeID: userTypeID, - namespace: .empty(specialChar1: ".", specialChar2: "_"), - typeName: .empty(specialChar1: "$", specialChar2: "_"), - registerByName: false, - fields: [ - TypeMeta.FieldInfo( - fieldID: nil, - fieldName: fieldName, - fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) - ) - ] + let config = Config(compatible: true, maxSchemaVersionsPerType: 1) + let resolver = TypeResolver(config: config) + resolver.register(Person.self, id: 901) + resolver.register(Address.self, id: 902) + try resolver.finishRegistration() + + func typeMeta(userTypeID: UInt32, fieldName: String) throws -> TypeMeta { + try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: userTypeID, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: fieldName, + fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) ) - } - - func writeTypeInfo(_ buffer: ByteBuffer, marker: UInt8, typeMeta: TypeMeta) throws { - buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.compatibleStruct.rawValue)) - buffer.writeUInt8(marker) - buffer.writeBytes(try typeMeta.encode()) - } + ] + ) + } - let failedBuffer = ByteBuffer() - try writeTypeInfo(failedBuffer, marker: 0, typeMeta: typeMeta(userTypeID: 902, fieldName: "zip2")) - try writeTypeInfo(failedBuffer, marker: 2, typeMeta: typeMeta(userTypeID: 901, fieldName: "id2")) - let failedContext = ReadContext(buffer: failedBuffer, typeResolver: resolver, config: config) + func writeTypeInfo(_ buffer: ByteBuffer, marker: UInt8, typeMeta: TypeMeta) throws { + buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.compatibleStruct.rawValue)) + buffer.writeUInt8(marker) + buffer.writeBytes(try typeMeta.encode()) + } + + let failedBuffer = ByteBuffer() + try writeTypeInfo(failedBuffer, marker: 0, typeMeta: typeMeta(userTypeID: 902, fieldName: "zip2")) + try writeTypeInfo(failedBuffer, marker: 2, typeMeta: typeMeta(userTypeID: 901, fieldName: "id2")) + let failedContext = ReadContext(buffer: failedBuffer, typeResolver: resolver, config: config) + _ = try failedContext.readTypeInfo(for: Address.self) + #expect(throws: (any Error).self) { _ = try failedContext.readTypeInfo(for: Address.self) - #expect(throws: (any Error).self) { - _ = try failedContext.readTypeInfo(for: Address.self) - } + } - let validBuffer = ByteBuffer() - try writeTypeInfo(validBuffer, marker: 0, typeMeta: typeMeta(userTypeID: 901, fieldName: "id3")) - let validContext = ReadContext(buffer: validBuffer, typeResolver: resolver, config: config) - _ = try validContext.readTypeInfo(for: Person.self) + let validBuffer = ByteBuffer() + try writeTypeInfo(validBuffer, marker: 0, typeMeta: typeMeta(userTypeID: 901, fieldName: "id3")) + let validContext = ReadContext(buffer: validBuffer, typeResolver: resolver, config: config) + _ = try validContext.readTypeInfo(for: Person.self) } @Test func macroStructRoundTrip() throws { - let fory = Fory() - fory.register(Address.self, id: 100) - fory.register(Person.self, id: 101) - - let person = Person( - id: 42, - name: "Alice", - nickname: nil, - scores: [10, 20, 30], - tags: ["swift", "xlang"], - addresses: [Address(street: "Main", zip: 94107)], - metadata: [1: 100, 2: nil] - ) + let fory = Fory() + fory.register(Address.self, id: 100) + fory.register(Person.self, id: 101) + + let person = Person( + id: 42, + name: "Alice", + nickname: nil, + scores: [10, 20, 30], + tags: ["swift", "xlang"], + addresses: [Address(street: "Main", zip: 94107)], + metadata: [1: 100, 2: nil] + ) - let data = try fory.serialize(person) - let decoded: Person = try fory.deserialize(data) - #expect(decoded == person) + let data = try fory.serialize(person) + let decoded: Person = try fory.deserialize(data) + #expect(decoded == person) } @Test func macroClassRefTracking() throws { - let fory = Fory(config: .init(trackRef: true, compatible: false)) - fory.register(Node.self, id: 200) + let fory = Fory(config: .init(trackRef: true, compatible: false)) + fory.register(Node.self, id: 200) - let node = Node(value: 7) - node.next = node + let node = Node(value: 7) + node.next = node - let data = try fory.serialize(node) - let decoded: Node = try fory.deserialize(data) + let data = try fory.serialize(node) + let decoded: Node = try fory.deserialize(data) - #expect(decoded.value == 7) - #expect(decoded.next === decoded) + #expect(decoded.value == 7) + #expect(decoded.next === decoded) } @Test func macroClassWeakRefTracking() throws { - let fory = Fory(config: .init(trackRef: true, compatible: false)) - fory.register(WeakNode.self, id: 201) + let fory = Fory(config: .init(trackRef: true, compatible: false)) + fory.register(WeakNode.self, id: 201) - let node = WeakNode(value: 13) - node.next = node + let node = WeakNode(value: 13) + node.next = node - let data = try fory.serialize(node) - let decoded: WeakNode = try fory.deserialize(data) + let data = try fory.serialize(node) + let decoded: WeakNode = try fory.deserialize(data) - #expect(decoded.value == 13) - #expect(decoded.next === decoded) + #expect(decoded.value == 13) + #expect(decoded.next === decoded) } @Test func topLevelAnyRoundTrip() throws { - let fory = Fory() - fory.register(Address.self, id: 209) + let fory = Fory() + fory.register(Address.self, id: 209) - let value: Any = Address(street: "AnyTop", zip: 8080) - let data = try fory.serialize(value) - let decoded: Any = try fory.deserialize(data) - #expect(decoded as? Address == Address(street: "AnyTop", zip: 8080)) + let value: Any = Address(street: "AnyTop", zip: 8080) + let data = try fory.serialize(value) + let decoded: Any = try fory.deserialize(data) + #expect(decoded as? Address == Address(street: "AnyTop", zip: 8080)) - var buffer = Data() - try fory.serialize(value, to: &buffer) - let decodedFrom: Any = try fory.deserialize(from: ByteBuffer(data: buffer)) - #expect(decodedFrom as? Address == Address(street: "AnyTop", zip: 8080)) + var buffer = Data() + try fory.serialize(value, to: &buffer) + let decodedFrom: Any = try fory.deserialize(from: ByteBuffer(data: buffer)) + #expect(decodedFrom as? Address == Address(street: "AnyTop", zip: 8080)) - let nullAny: Any = Optional.none as Any - let nullData = try fory.serialize(nullAny) - let nullDecoded: Any = try fory.deserialize(nullData) - #expect(nullDecoded is ForyAnyNullValue) + let nullAny: Any = Optional.none as Any + let nullData = try fory.serialize(nullAny) + let nullDecoded: Any = try fory.deserialize(nullData) + #expect(nullDecoded is ForyAnyNullValue) } @Test func dynamicUserTypesDecodeByID() throws { - let fory = Fory() - fory.register(Address.self, id: 600) - try fory.register(Person.self, name: "demo.person") + let fory = Fory() + fory.register(Address.self, id: 600) + try fory.register(Person.self, name: "demo.person") - let value: Any = Address(street: "mixed", zip: 7788) - let data = try fory.serialize(value) - let decoded: Any = try fory.deserialize(data) - #expect(decoded as? Address == Address(street: "mixed", zip: 7788)) + let value: Any = Address(street: "mixed", zip: 7788) + let data = try fory.serialize(value) + let decoded: Any = try fory.deserialize(data) + #expect(decoded as? Address == Address(street: "mixed", zip: 7788)) } @Test func duplicateNameRegistrationIsRejected() throws { - let resolver = TypeResolver(config: Config(trackRef: false)) - try resolver.register(Address.self, namespace: "demo", typeName: "entity") + let resolver = TypeResolver(config: Config(trackRef: false)) + try resolver.register(Address.self, namespace: "demo", typeName: "entity") - do { - try resolver.register(Person.self, namespace: "demo", typeName: "entity") - #expect(Bool(false)) - } catch {} + do { + try resolver.register(Person.self, namespace: "demo", typeName: "entity") + #expect(Bool(false)) + } catch {} } @Test func nameRegistrationSplitsLastDot() throws { - let resolver = TypeResolver(config: Config(trackRef: false)) - try resolver.register(Address.self, name: "com.example.Address") + let resolver = TypeResolver(config: Config(trackRef: false)) + try resolver.register(Address.self, name: "com.example.Address") - let info = try resolver.requireTypeInfo(namespace: "com.example", typeName: "Address") - #expect(info.namespace.value == "com.example") - #expect(info.typeName.value == "Address") + let info = try resolver.requireTypeInfo(namespace: "com.example", typeName: "Address") + #expect(info.namespace.value == "com.example") + #expect(info.typeName.value == "Address") } @Test func nameRegistrationAllowsSimpleName() throws { - let resolver = TypeResolver(config: Config(trackRef: false)) - try resolver.register(Address.self, name: "Address") + let resolver = TypeResolver(config: Config(trackRef: false)) + try resolver.register(Address.self, name: "Address") - let info = try resolver.requireTypeInfo(namespace: "", typeName: "Address") - #expect(info.namespace.value == "") - #expect(info.typeName.value == "Address") + let info = try resolver.requireTypeInfo(namespace: "", typeName: "Address") + #expect(info.namespace.value == "") + #expect(info.typeName.value == "Address") } @Test func nameRegistrationRejectsEmptyName() throws { - let fory = Fory() + let fory = Fory() - #expect(throws: ForyError.self) { - try fory.register(Address.self, name: "") - } + #expect(throws: ForyError.self) { + try fory.register(Address.self, name: "") + } } @Test func nameRegistrationRejectsTrailingDot() throws { - let fory = Fory() + let fory = Fory() - #expect(throws: ForyError.self) { - try fory.register(Address.self, name: "com.example.") - } + #expect(throws: ForyError.self) { + try fory.register(Address.self, name: "com.example.") + } } @Test func splitNameRegistrationRejectsDottedTypeName() throws { - let resolver = TypeResolver(config: Config(trackRef: false)) + let resolver = TypeResolver(config: Config(trackRef: false)) - #expect(throws: ForyError.self) { - try resolver.register(Address.self, namespace: "com", typeName: "example.Address") - } + #expect(throws: ForyError.self) { + try resolver.register(Address.self, namespace: "com", typeName: "example.Address") + } } @Test func registrationIsRejectedAfterFirstTopLevelUse() throws { - let fory = Fory() - _ = try fory.serialize(Int32(7)) - - do { - try fory.register(Address.self, name: "demo.address") - #expect(Bool(false)) - } catch { - #expect("\(error)".contains("cannot register more types")) - } + let fory = Fory() + _ = try fory.serialize(Int32(7)) + + do { + try fory.register(Address.self, name: "demo.address") + #expect(Bool(false)) + } catch { + #expect("\(error)".contains("cannot register more types")) + } } @Test func serializeToAppendsRoots() throws { - let fory = Fory() - let first = Int32(7) - let second = "swift-buffer" - let third: String? = nil + let fory = Fory() + let first = Int32(7) + let second = "swift-buffer" + let third: String? = nil - let firstData = try fory.serialize(first) - let secondData = try fory.serialize(second) - let thirdData = try fory.serialize(third) + let firstData = try fory.serialize(first) + let secondData = try fory.serialize(second) + let thirdData = try fory.serialize(third) - var stream = Data() - try fory.serialize(first, to: &stream) - try fory.serialize(second, to: &stream) - try fory.serialize(third, to: &stream) + var stream = Data() + try fory.serialize(first, to: &stream) + try fory.serialize(second, to: &stream) + try fory.serialize(third, to: &stream) - var expected = Data() - expected.append(firstData) - expected.append(secondData) - expected.append(thirdData) - #expect(stream == expected) + var expected = Data() + expected.append(firstData) + expected.append(secondData) + expected.append(thirdData) + #expect(stream == expected) - let buffer = ByteBuffer(data: stream) - let decodedFirst: Int32 = try fory.deserialize(from: buffer) - #expect(decodedFirst == first) - #expect(buffer.getCursor() == firstData.count) + let buffer = ByteBuffer(data: stream) + let decodedFirst: Int32 = try fory.deserialize(from: buffer) + #expect(decodedFirst == first) + #expect(buffer.getCursor() == firstData.count) - let decodedSecond: String = try fory.deserialize(from: buffer) - #expect(decodedSecond == second) - #expect(buffer.getCursor() == firstData.count + secondData.count) + let decodedSecond: String = try fory.deserialize(from: buffer) + #expect(decodedSecond == second) + #expect(buffer.getCursor() == firstData.count + secondData.count) - let decodedThird: String? = try fory.deserialize(from: buffer) - #expect(decodedThird == nil) - #expect(buffer.remaining == 0) + let decodedThird: String? = try fory.deserialize(from: buffer) + #expect(decodedThird == nil) + #expect(buffer.remaining == 0) } @Test func rootBufferHonorsCursor() throws { - let fory = Fory() - let prefix: [UInt8] = [0xAA, 0xBB, 0xCC] - let payload = try fory.serialize("offset") + let fory = Fory() + let prefix: [UInt8] = [0xAA, 0xBB, 0xCC] + let payload = try fory.serialize("offset") - let buffer = ByteBuffer() - buffer.writeBytes(prefix) - buffer.writeBytes(Array(payload)) - buffer.setCursor(prefix.count) + let buffer = ByteBuffer() + buffer.writeBytes(prefix) + buffer.writeBytes(Array(payload)) + buffer.setCursor(prefix.count) - let decoded: String = try fory.deserialize(from: buffer) - #expect(decoded == "offset") - #expect(buffer.getCursor() == buffer.count) - #expect(Array(buffer.storage.prefix(prefix.count)) == prefix) + let decoded: String = try fory.deserialize(from: buffer) + #expect(decoded == "offset") + #expect(buffer.getCursor() == buffer.count) + #expect(Array(buffer.storage.prefix(prefix.count)) == prefix) } @Test func topLevelAnyObjectRoundTrip() throws { - let fory = Fory(config: .init(trackRef: true, compatible: false)) - fory.register(Node.self, id: 210) + let fory = Fory(config: .init(trackRef: true, compatible: false)) + fory.register(Node.self, id: 210) - let value: AnyObject = Node(value: 123) - let data = try fory.serialize(value) - let decoded: AnyObject = try fory.deserialize(data) + let value: AnyObject = Node(value: 123) + let data = try fory.serialize(value) + let decoded: AnyObject = try fory.deserialize(data) - let node = decoded as? Node - #expect(node != nil) - #expect(node?.value == 123) + let node = decoded as? Node + #expect(node != nil) + #expect(node?.value == 123) - var buffer = Data() - try fory.serialize(value, to: &buffer) - let decodedFrom: AnyObject = try fory.deserialize(from: ByteBuffer(data: buffer)) - #expect((decodedFrom as? Node)?.value == 123) + var buffer = Data() + try fory.serialize(value, to: &buffer) + let decodedFrom: AnyObject = try fory.deserialize(from: ByteBuffer(data: buffer)) + #expect((decodedFrom as? Node)?.value == 123) } @Test func topLevelAnySerializerRoundTrip() throws { - let fory = Fory() - fory.register(Address.self, id: 211) + let fory = Fory() + fory.register(Address.self, id: 211) - let value: any Serializer = Address(street: "AnyStreet", zip: 9090) - let data = try fory.serialize(value) - let decoded: any Serializer = try fory.deserialize(data) + let value: any Serializer = Address(street: "AnyStreet", zip: 9090) + let data = try fory.serialize(value) + let decoded: any Serializer = try fory.deserialize(data) - let address = decoded as? Address - #expect(address == Address(street: "AnyStreet", zip: 9090)) + let address = decoded as? Address + #expect(address == Address(street: "AnyStreet", zip: 9090)) - var buffer = Data() - try fory.serialize(value, to: &buffer) - let decodedFrom: any Serializer = try fory.deserialize(from: ByteBuffer(data: buffer)) - #expect(decodedFrom as? Address == Address(street: "AnyStreet", zip: 9090)) + var buffer = Data() + try fory.serialize(value, to: &buffer) + let decodedFrom: any Serializer = try fory.deserialize(from: ByteBuffer(data: buffer)) + #expect(decodedFrom as? Address == Address(street: "AnyStreet", zip: 9090)) } @Test func macroDynamicAnyObjectAndAnySerializerFieldsRoundTrip() throws { - let fory = Fory(config: .init(trackRef: true, compatible: false)) - fory.register(Node.self, id: 220) - fory.register(Address.self, id: 221) - fory.register(AnyObjectHolder.self, id: 222) - fory.register(AnySerializerHolder.self, id: 223) - - let sharedNode = Node(value: 77) - let objectHolder = AnyObjectHolder( - value: sharedNode, - optionalValue: nil, - items: [sharedNode, NSNull()] - ) - let objectData = try fory.serialize(objectHolder) - let objectDecoded: AnyObjectHolder = try fory.deserialize(objectData) - #expect((objectDecoded.value as? Node)?.value == 77) - #expect(objectDecoded.optionalValue == nil) - #expect(objectDecoded.items.count == 2) - #expect((objectDecoded.items[0] as? Node)?.value == 77) - #expect(objectDecoded.items[1] is NSNull) - - let serializerHolder = AnySerializerHolder( - value: Address(street: "Root", zip: 10001), - items: [Int32(11), Address(street: "Nested", zip: 10002)], - map: [ - "age": Int64(19), - "address": Address(street: "Mapped", zip: 10003) - ] - ) - let serializerData = try fory.serialize(serializerHolder) - let serializerDecoded: AnySerializerHolder = try fory.deserialize(serializerData) + let fory = Fory(config: .init(trackRef: true, compatible: false)) + fory.register(Node.self, id: 220) + fory.register(Address.self, id: 221) + fory.register(AnyObjectHolder.self, id: 222) + fory.register(AnySerializerHolder.self, id: 223) + + let sharedNode = Node(value: 77) + let objectHolder = AnyObjectHolder( + value: sharedNode, + optionalValue: nil, + items: [sharedNode, NSNull()] + ) + let objectData = try fory.serialize(objectHolder) + let objectDecoded: AnyObjectHolder = try fory.deserialize(objectData) + #expect((objectDecoded.value as? Node)?.value == 77) + #expect(objectDecoded.optionalValue == nil) + #expect(objectDecoded.items.count == 2) + #expect((objectDecoded.items[0] as? Node)?.value == 77) + #expect(objectDecoded.items[1] is NSNull) + + let serializerHolder = AnySerializerHolder( + value: Address(street: "Root", zip: 10001), + items: [Int32(11), Address(street: "Nested", zip: 10002)], + map: [ + "age": Int64(19), + "address": Address(street: "Mapped", zip: 10003) + ] + ) + let serializerData = try fory.serialize(serializerHolder) + let serializerDecoded: AnySerializerHolder = try fory.deserialize(serializerData) - #expect(serializerDecoded.value as? Address == Address(street: "Root", zip: 10001)) - #expect(serializerDecoded.items.count == 2) - #expect(serializerDecoded.items[0] as? Int32 == 11) - #expect(serializerDecoded.items[1] as? Address == Address(street: "Nested", zip: 10002)) - #expect(serializerDecoded.map["age"] as? Int64 == 19) - #expect(serializerDecoded.map["address"] as? Address == Address(street: "Mapped", zip: 10003)) + #expect(serializerDecoded.value as? Address == Address(street: "Root", zip: 10001)) + #expect(serializerDecoded.items.count == 2) + #expect(serializerDecoded.items[0] as? Int32 == 11) + #expect(serializerDecoded.items[1] as? Address == Address(street: "Nested", zip: 10002)) + #expect(serializerDecoded.map["age"] as? Int64 == 19) + #expect(serializerDecoded.map["address"] as? Address == Address(street: "Mapped", zip: 10003)) } @Test func dynamicAnySerializerTracksRefs() throws { - let fory = Fory(config: .init(trackRef: true, compatible: false)) - fory.register(Node.self, id: 226) - fory.register(AnySerializerHolder.self, id: 227) - - let shared = Node(value: 88) - shared.next = shared - let value = AnySerializerHolder( - value: shared, - items: [shared], - map: ["shared": shared] - ) + let fory = Fory(config: .init(trackRef: true, compatible: false)) + fory.register(Node.self, id: 226) + fory.register(AnySerializerHolder.self, id: 227) - let decoded: AnySerializerHolder = try fory.deserialize(try fory.serialize(value)) - let root = decoded.value as? Node - let item = decoded.items.first as? Node - let mapped = decoded.map["shared"] as? Node + let shared = Node(value: 88) + shared.next = shared + let value = AnySerializerHolder( + value: shared, + items: [shared], + map: ["shared": shared] + ) - #expect(root != nil) - #expect(root === item) - #expect(item === mapped) - #expect(root?.next === root) + let decoded: AnySerializerHolder = try fory.deserialize(try fory.serialize(value)) + let root = decoded.value as? Node + let item = decoded.items.first as? Node + let mapped = decoded.map["shared"] as? Node + + #expect(root != nil) + #expect(root === item) + #expect(item === mapped) + #expect(root?.next === root) } @Test func macroAnyFieldsRoundTrip() throws { - let fory = Fory() - fory.register(Address.self, id: 224) - fory.register(AnyFieldHolder.self, id: 225) - - let value = AnyFieldHolder( - value: Address(street: "AnyRoot", zip: 11001), - optionalValue: nil, - list: [Int32(7), "hello", Address(street: "AnyList", zip: 11002), NSNull()], - stringMap: [ - "count": Int64(3), - "name": "map", - "address": Address(street: "AnyMap", zip: 11003), - "empty": NSNull() - ], - int32Map: [ - 1: Int32(-9), - 2: "v2", - 3: Address(street: "AnyIntMap", zip: 11004), - 4: NSNull() - ] - ) - let data = try fory.serialize(value) - let decoded: AnyFieldHolder = try fory.deserialize(data) - - #expect(decoded.value as? Address == Address(street: "AnyRoot", zip: 11001)) - #expect(decoded.optionalValue == nil) - #expect(decoded.list.count == 4) - #expect(decoded.list[0] as? Int32 == 7) - #expect(decoded.list[1] as? String == "hello") - #expect(decoded.list[2] as? Address == Address(street: "AnyList", zip: 11002)) - #expect(decoded.list[3] is NSNull) - #expect(decoded.stringMap["count"] as? Int64 == 3) - #expect(decoded.stringMap["name"] as? String == "map") - #expect(decoded.stringMap["address"] as? Address == Address(street: "AnyMap", zip: 11003)) - #expect(decoded.stringMap["empty"] is NSNull) - #expect(decoded.int32Map[1] as? Int32 == -9) - #expect(decoded.int32Map[2] as? String == "v2") - #expect(decoded.int32Map[3] as? Address == Address(street: "AnyIntMap", zip: 11004)) - #expect(decoded.int32Map[4] is NSNull) + let fory = Fory() + fory.register(Address.self, id: 224) + fory.register(AnyFieldHolder.self, id: 225) + + let value = AnyFieldHolder( + value: Address(street: "AnyRoot", zip: 11001), + optionalValue: nil, + list: [Int32(7), "hello", Address(street: "AnyList", zip: 11002), NSNull()], + stringMap: [ + "count": Int64(3), + "name": "map", + "address": Address(street: "AnyMap", zip: 11003), + "empty": NSNull() + ], + int32Map: [ + 1: Int32(-9), + 2: "v2", + 3: Address(street: "AnyIntMap", zip: 11004), + 4: NSNull() + ] + ) + let data = try fory.serialize(value) + let decoded: AnyFieldHolder = try fory.deserialize(data) + + #expect(decoded.value as? Address == Address(street: "AnyRoot", zip: 11001)) + #expect(decoded.optionalValue == nil) + #expect(decoded.list.count == 4) + #expect(decoded.list[0] as? Int32 == 7) + #expect(decoded.list[1] as? String == "hello") + #expect(decoded.list[2] as? Address == Address(street: "AnyList", zip: 11002)) + #expect(decoded.list[3] is NSNull) + #expect(decoded.stringMap["count"] as? Int64 == 3) + #expect(decoded.stringMap["name"] as? String == "map") + #expect(decoded.stringMap["address"] as? Address == Address(street: "AnyMap", zip: 11003)) + #expect(decoded.stringMap["empty"] is NSNull) + #expect(decoded.int32Map[1] as? Int32 == -9) + #expect(decoded.int32Map[2] as? String == "v2") + #expect(decoded.int32Map[3] as? Address == Address(street: "AnyIntMap", zip: 11004)) + #expect(decoded.int32Map[4] is NSNull) } @Test func collectionAndMapRefTracking() throws { - let fory = Fory(config: .init(trackRef: true, compatible: false)) - fory.register(Node.self, id: 200) - - let shared = Node(value: 11) - let list: [Node?] = [shared, shared, nil] - let listData = try fory.serialize(list) - let listReader = ByteBuffer(data: listData) - _ = try fory.readHead(buffer: listReader) - _ = try listReader.readInt8() - _ = try listReader.readVarUInt32() - _ = try listReader.readVarUInt32() - let listHeader = try listReader.readUInt8() - #expect((listHeader & 0b0000_0001) != 0) - - let decodedList: [Node?] = try fory.deserialize(listData) - #expect(decodedList.count == 3) - #expect(decodedList[0] === decodedList[1]) - #expect(decodedList[2] == nil) - - let sharedValue = Node(value: 21) - let map: [Int8: Node?] = [1: sharedValue, 2: sharedValue] - let mapData = try fory.serialize(map) - let mapReader = ByteBuffer(data: mapData) - _ = try fory.readHead(buffer: mapReader) - _ = try mapReader.readInt8() - _ = try mapReader.readVarUInt32() - _ = try mapReader.readVarUInt32() - let mapChunkHeader = try mapReader.readUInt8() - #expect((mapChunkHeader & 0b0000_1000) != 0) - - let decodedMap: [Int8: Node?] = try fory.deserialize(mapData) - let v1 = decodedMap[1] ?? nil - let v2 = decodedMap[2] ?? nil - #expect(v1 != nil) - #expect(v1 === v2) + let fory = Fory(config: .init(trackRef: true, compatible: false)) + fory.register(Node.self, id: 200) + + let shared = Node(value: 11) + let list: [Node?] = [shared, shared, nil] + let listData = try fory.serialize(list) + let listReader = ByteBuffer(data: listData) + _ = try fory.readHead(buffer: listReader) + _ = try listReader.readInt8() + _ = try listReader.readVarUInt32() + _ = try listReader.readVarUInt32() + let listHeader = try listReader.readUInt8() + #expect((listHeader & 0b0000_0001) != 0) + + let decodedList: [Node?] = try fory.deserialize(listData) + #expect(decodedList.count == 3) + #expect(decodedList[0] === decodedList[1]) + #expect(decodedList[2] == nil) + + let sharedValue = Node(value: 21) + let map: [Int8: Node?] = [1: sharedValue, 2: sharedValue] + let mapData = try fory.serialize(map) + let mapReader = ByteBuffer(data: mapData) + _ = try fory.readHead(buffer: mapReader) + _ = try mapReader.readInt8() + _ = try mapReader.readVarUInt32() + _ = try mapReader.readVarUInt32() + let mapChunkHeader = try mapReader.readUInt8() + #expect((mapChunkHeader & 0b0000_1000) != 0) + + let decodedMap: [Int8: Node?] = try fory.deserialize(mapData) + let v1 = decodedMap[1] ?? nil + let v2 = decodedMap[2] ?? nil + #expect(v1 != nil) + #expect(v1 === v2) } @Test func macroFieldOrderFollowsForyRules() throws { - let fory = Fory(compatible: false) - fory.register(FieldOrder.self, id: 300) + let fory = Fory(compatible: false) + fory.register(FieldOrder.self, id: 300) - let value = FieldOrder(textTail: "tail", longValue: 123_456_789, shortValue: 17, intValue: 99) - let data = try fory.serialize(value) + let value = FieldOrder(textTail: "tail", longValue: 123_456_789, shortValue: 17, intValue: 99) + let data = try fory.serialize(value) - let buffer = ByteBuffer(data: data) - _ = try fory.readHead(buffer: buffer) - _ = try buffer.readInt8() // root ref flag - _ = try buffer.readVarUInt32() // type id - _ = try buffer.readVarUInt32() // user type id - _ = try buffer.readInt32() // schema hash + let buffer = ByteBuffer(data: data) + _ = try fory.readHead(buffer: buffer) + _ = try buffer.readInt8() // root ref flag + _ = try buffer.readVarUInt32() // type id + _ = try buffer.readVarUInt32() // user type id + _ = try buffer.readInt32() // schema hash - let first = try buffer.readInt16() - let second = try buffer.readVarInt64() - let third = try buffer.readVarInt32() + let first = try buffer.readInt16() + let second = try buffer.readVarInt64() + let third = try buffer.readVarInt32() - let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) - let fourth = try String.foryReadData(tailContext) + let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) + let fourth = try String.foryReadData(tailContext) - #expect(first == value.shortValue) - #expect(second == value.longValue) - #expect(third == value.intValue) - #expect(fourth == value.textTail) + #expect(first == value.shortValue) + #expect(second == value.longValue) + #expect(third == value.intValue) + #expect(fourth == value.textTail) } @Test func macroTaggedFieldsKeepGroupedPayloadOrder() throws { - let fory = Fory(compatible: false) - fory.register(TaggedFieldOrder.self, id: 303) + let fory = Fory(compatible: false) + fory.register(TaggedFieldOrder.self, id: 303) - let fields = TaggedFieldOrder.foryFieldsInfo(trackRef: false) - #expect(fields.map(\.fieldName) == ["intValue", "textTail"]) - #expect(fields.map(\.fieldID) == [10, 1]) + let fields = TaggedFieldOrder.foryFieldsInfo(trackRef: false) + #expect(fields.map(\.fieldName) == ["intValue", "textTail"]) + #expect(fields.map(\.fieldID) == [10, 1]) - let value = TaggedFieldOrder(textTail: "tail", intValue: 99) - let data = try fory.serialize(value) - let buffer = ByteBuffer(data: data) - _ = try fory.readHead(buffer: buffer) - _ = try buffer.readInt8() - _ = try buffer.readVarUInt32() - _ = try buffer.readVarUInt32() - _ = try buffer.readInt32() + let value = TaggedFieldOrder(textTail: "tail", intValue: 99) + let data = try fory.serialize(value) + let buffer = ByteBuffer(data: data) + _ = try fory.readHead(buffer: buffer) + _ = try buffer.readInt8() + _ = try buffer.readVarUInt32() + _ = try buffer.readVarUInt32() + _ = try buffer.readInt32() - #expect(try buffer.readVarInt32() == value.intValue) - let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) - #expect(try String.foryReadData(tailContext) == value.textTail) + #expect(try buffer.readVarInt32() == value.intValue) + let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) + #expect(try String.foryReadData(tailContext) == value.textTail) } @Test func macroNonPrimitiveFieldsSortByFieldIdentifier() throws { - let fields = NonPrimitiveFieldOrder.foryFieldsInfo(trackRef: false) + let fields = NonPrimitiveFieldOrder.foryFieldsInfo(trackRef: false) - #expect( - fields.map(\.fieldName) == [ - "intValue", "mapValue", "stringValue", "addressValue", "binaryValue" - ]) - #expect(fields.map(\.fieldID) == [nil, 10, 20, nil, nil]) + #expect( + fields.map(\.fieldName) == [ + "intValue", "mapValue", "stringValue", "addressValue", "binaryValue" + ]) + #expect(fields.map(\.fieldID) == [nil, 10, 20, nil, nil]) } @Test func macroFieldEncodingOverridesForUnsignedTypes() throws { - let fory = Fory(compatible: false) - fory.register(EncodedNumberFields.self, id: 301) + let fory = Fory(compatible: false) + fory.register(EncodedNumberFields.self, id: 301) - let value = EncodedNumberFields( - u32Fixed: 0x1122_3344, - u64Tagged: UInt64(Int32.max) + 99 - ) - let data = try fory.serialize(value) - let decoded: EncodedNumberFields = try fory.deserialize(data) - #expect(decoded == value) + let value = EncodedNumberFields( + u32Fixed: 0x1122_3344, + u64Tagged: UInt64(Int32.max) + 99 + ) + let data = try fory.serialize(value) + let decoded: EncodedNumberFields = try fory.deserialize(data) + #expect(decoded == value) - let buffer = ByteBuffer(data: data) - _ = try fory.readHead(buffer: buffer) - _ = try buffer.readInt8() - _ = try buffer.readVarUInt32() - _ = try buffer.readVarUInt32() - _ = try buffer.readInt32() + let buffer = ByteBuffer(data: data) + _ = try fory.readHead(buffer: buffer) + _ = try buffer.readInt8() + _ = try buffer.readVarUInt32() + _ = try buffer.readVarUInt32() + _ = try buffer.readInt32() - #expect(try buffer.readUInt32() == value.u32Fixed) - #expect(try buffer.readTaggedUInt64() == value.u64Tagged) + #expect(try buffer.readUInt32() == value.u32Fixed) + #expect(try buffer.readTaggedUInt64() == value.u64Tagged) } @Test func macroEnumUsesExplicitIntegerRawValue() throws { - let fory = Fory(config: .init(trackRef: false, compatible: false)) - fory.register(SparseStatus.self, id: 302) + let fory = Fory(config: .init(trackRef: false, compatible: false)) + fory.register(SparseStatus.self, id: 302) - let data = try fory.serialize(SparseStatus.ok) - let buffer = ByteBuffer(data: data) - _ = try fory.readHead(buffer: buffer) - _ = try buffer.readInt8() - _ = try buffer.readVarUInt32() - _ = try buffer.readVarUInt32() - #expect(try buffer.readVarUInt32() == 8192) + let data = try fory.serialize(SparseStatus.ok) + let buffer = ByteBuffer(data: data) + _ = try fory.readHead(buffer: buffer) + _ = try buffer.readInt8() + _ = try buffer.readVarUInt32() + _ = try buffer.readVarUInt32() + #expect(try buffer.readVarUInt32() == 8192) - let decoded: SparseStatus = try fory.deserialize(data) - #expect(decoded == .ok) + let decoded: SparseStatus = try fory.deserialize(data) + #expect(decoded == .ok) } @Test func macroFieldEncodingOverridesCompatibleTypeMeta() throws { - let fields = EncodedNumberFields.foryFieldsInfo(trackRef: false) - #expect(fields.count == 2) - #expect(fields[0].fieldName == "u32Fixed") - #expect(fields[0].fieldType.typeID == TypeId.uint32.rawValue) - #expect(fields[1].fieldName == "u64Tagged") - #expect(fields[1].fieldType.typeID == TypeId.taggedUInt64.rawValue) + let fields = EncodedNumberFields.foryFieldsInfo(trackRef: false) + #expect(fields.count == 2) + #expect(fields[0].fieldName == "u32Fixed") + #expect(fields[0].fieldType.typeID == TypeId.uint32.rawValue) + #expect(fields[1].fieldName == "u64Tagged") + #expect(fields[1].fieldType.typeID == TypeId.taggedUInt64.rawValue) } @Test func macroReducedPrecisionFieldsUseXlangTypeIDs() { - let fields = ReducedPrecisionMacroFields.foryFieldsInfo(trackRef: false) - #expect(fields.count == 4) - #expect( - fields.map(\.fieldName) == ["float16Value", "bfloat16Value", "bfloat16Array", "float16Array"]) - #expect( - fields.map(\.fieldType.typeID) == [ - TypeId.float16.rawValue, - TypeId.bfloat16.rawValue, - TypeId.bfloat16Array.rawValue, - TypeId.float16Array.rawValue - ]) + let fields = ReducedPrecisionMacroFields.foryFieldsInfo(trackRef: false) + #expect(fields.count == 4) + #expect( + fields.map(\.fieldName) == ["float16Value", "bfloat16Value", "bfloat16Array", "float16Array"]) + #expect( + fields.map(\.fieldType.typeID) == [ + TypeId.float16.rawValue, + TypeId.bfloat16.rawValue, + TypeId.bfloat16Array.rawValue, + TypeId.float16Array.rawValue + ]) } @Test func macroFieldIDsPopulateCompatibleTypeMeta() { - let fields = FieldIdConfigured.foryFieldsInfo(trackRef: false) - #expect(fields.count == 2) - - var byID: [Int16: TypeMeta.FieldInfo] = [:] - for field in fields { - if let id = field.fieldID { - byID[id] = field - } + let fields = FieldIdConfigured.foryFieldsInfo(trackRef: false) + #expect(fields.count == 2) + + var byID: [Int16: TypeMeta.FieldInfo] = [:] + for field in fields { + if let id = field.fieldID { + byID[id] = field } + } - #expect(byID[2]?.fieldName == "stableID") - #expect(byID[2]?.fieldType.typeID == TypeId.varint32.rawValue) - #expect(byID[5]?.fieldName == "fixedValue") - #expect(byID[5]?.fieldType.typeID == TypeId.int32.rawValue) + #expect(byID[2]?.fieldName == "stableID") + #expect(byID[2]?.fieldType.typeID == TypeId.varint32.rawValue) + #expect(byID[5]?.fieldName == "fixedValue") + #expect(byID[5]?.fieldType.typeID == TypeId.int32.rawValue) } @Test func macroFieldIDsDriveCompatibleStructDecodeAcrossRenames() throws { - let writer = Fory(config: .init(trackRef: false, compatible: true)) - writer.register(FieldIdSource.self, id: 9101) + let writer = Fory(config: .init(trackRef: false, compatible: true)) + writer.register(FieldIdSource.self, id: 9101) - let reader = Fory(config: .init(trackRef: false, compatible: true)) - reader.register(FieldIdTarget.self, id: 9101) + let reader = Fory(config: .init(trackRef: false, compatible: true)) + reader.register(FieldIdTarget.self, id: 9101) - let source = FieldIdSource(value: 42, label: "alpha") - let bytes = try writer.serialize(source) - let decoded: FieldIdTarget = try reader.deserialize(bytes) + let source = FieldIdSource(value: 42, label: "alpha") + let bytes = try writer.serialize(source) + let decoded: FieldIdTarget = try reader.deserialize(bytes) - #expect(decoded.renamedValue == source.value) - #expect(decoded.renamedLabel == source.label) + #expect(decoded.renamedValue == source.value) + #expect(decoded.renamedLabel == source.label) - let roundTrip = try reader.serialize(decoded) - let back: FieldIdSource = try writer.deserialize(roundTrip) - #expect(back == source) + let roundTrip = try reader.serialize(decoded) + let back: FieldIdSource = try writer.deserialize(roundTrip) + #expect(back == source) } @Test func macroFieldIDsDriveTaggedUnionDecodeAcrossRenames() throws { - let writer = Fory(config: .init(trackRef: false, compatible: true)) - writer.register(FieldIdUnionSource.self, id: 9102) + let writer = Fory(config: .init(trackRef: false, compatible: true)) + writer.register(FieldIdUnionSource.self, id: 9102) - let reader = Fory(config: .init(trackRef: false, compatible: true)) - reader.register(FieldIdUnionTarget.self, id: 9102) + let reader = Fory(config: .init(trackRef: false, compatible: true)) + reader.register(FieldIdUnionTarget.self, id: 9102) - let source = FieldIdUnionSource.number(123) - let bytes = try writer.serialize(source) - let decoded: FieldIdUnionTarget = try reader.deserialize(bytes) + let source = FieldIdUnionSource.number(123) + let bytes = try writer.serialize(source) + let decoded: FieldIdUnionTarget = try reader.deserialize(bytes) - switch decoded { - case .renamedNumber(let value): - #expect(value == 123) - default: - #expect(Bool(false)) - } + switch decoded { + case .renamedNumber(let value): + #expect(value == 123) + default: + #expect(Bool(false)) + } } @Test func compatibleNestedStructArrayRoundTrip() throws { - let writer = Fory(config: .init(trackRef: false, compatible: true)) - writer.register(CompatibleNestedItem.self, id: 9103) - writer.register(CompatibleNestedArrayHolder.self, id: 9104) - - let reader = Fory(config: .init(trackRef: false, compatible: true)) - reader.register(CompatibleNestedItem.self, id: 9103) - reader.register(CompatibleNestedArrayHolder.self, id: 9104) - - let value = CompatibleNestedArrayHolder( - items: [ - CompatibleNestedItem(id: 1, name: "alpha"), - CompatibleNestedItem(id: 2, name: "beta") - ] - ) - let bytes = try writer.serialize(value) - let decoded: CompatibleNestedArrayHolder = try reader.deserialize(bytes) - #expect(decoded == value) + let writer = Fory(config: .init(trackRef: false, compatible: true)) + writer.register(CompatibleNestedItem.self, id: 9103) + writer.register(CompatibleNestedArrayHolder.self, id: 9104) + + let reader = Fory(config: .init(trackRef: false, compatible: true)) + reader.register(CompatibleNestedItem.self, id: 9103) + reader.register(CompatibleNestedArrayHolder.self, id: 9104) + + let value = CompatibleNestedArrayHolder( + items: [ + CompatibleNestedItem(id: 1, name: "alpha"), + CompatibleNestedItem(id: 2, name: "beta") + ] + ) + let bytes = try writer.serialize(value) + let decoded: CompatibleNestedArrayHolder = try reader.deserialize(bytes) + #expect(decoded == value) } @Test func compatibleNestedStructOptionalArrayRoundTrip() throws { - let writer = Fory(config: .init(trackRef: false, compatible: true)) - writer.register(CompatibleNestedItem.self, id: 9103) - writer.register(CompatibleNestedOptionalArrayHolder.self, id: 9105) - - let reader = Fory(config: .init(trackRef: false, compatible: true)) - reader.register(CompatibleNestedItem.self, id: 9103) - reader.register(CompatibleNestedOptionalArrayHolder.self, id: 9105) - - let value = CompatibleNestedOptionalArrayHolder( - items: [ - CompatibleNestedItem(id: 1, name: "alpha"), - nil, - CompatibleNestedItem(id: 2, name: "beta") - ] - ) - let bytes = try writer.serialize(value) - let decoded: CompatibleNestedOptionalArrayHolder = try reader.deserialize(bytes) - #expect(decoded == value) + let writer = Fory(config: .init(trackRef: false, compatible: true)) + writer.register(CompatibleNestedItem.self, id: 9103) + writer.register(CompatibleNestedOptionalArrayHolder.self, id: 9105) + + let reader = Fory(config: .init(trackRef: false, compatible: true)) + reader.register(CompatibleNestedItem.self, id: 9103) + reader.register(CompatibleNestedOptionalArrayHolder.self, id: 9105) + + let value = CompatibleNestedOptionalArrayHolder( + items: [ + CompatibleNestedItem(id: 1, name: "alpha"), + nil, + CompatibleNestedItem(id: 2, name: "beta") + ] + ) + let bytes = try writer.serialize(value) + let decoded: CompatibleNestedOptionalArrayHolder = try reader.deserialize(bytes) + #expect(decoded == value) } @Test func compatibleNestedStructMapRoundTrip() throws { - let writer = Fory(config: .init(trackRef: false, compatible: true)) - writer.register(CompatibleNestedItem.self, id: 9103) - writer.register(CompatibleNestedMapHolder.self, id: 9106) - - let reader = Fory(config: .init(trackRef: false, compatible: true)) - reader.register(CompatibleNestedItem.self, id: 9103) - reader.register(CompatibleNestedMapHolder.self, id: 9106) - - let value = CompatibleNestedMapHolder( - items: [ - 1: CompatibleNestedItem(id: 10, name: "first"), - 2: CompatibleNestedItem(id: 20, name: "second") - ] - ) - let bytes = try writer.serialize(value) - let decoded: CompatibleNestedMapHolder = try reader.deserialize(bytes) - #expect(decoded == value) + let writer = Fory(config: .init(trackRef: false, compatible: true)) + writer.register(CompatibleNestedItem.self, id: 9103) + writer.register(CompatibleNestedMapHolder.self, id: 9106) + + let reader = Fory(config: .init(trackRef: false, compatible: true)) + reader.register(CompatibleNestedItem.self, id: 9103) + reader.register(CompatibleNestedMapHolder.self, id: 9106) + + let value = CompatibleNestedMapHolder( + items: [ + 1: CompatibleNestedItem(id: 10, name: "first"), + 2: CompatibleNestedItem(id: 20, name: "second") + ] + ) + let bytes = try writer.serialize(value) + let decoded: CompatibleNestedMapHolder = try reader.deserialize(bytes) + #expect(decoded == value) } @Test func pvlVarInt64AndVarUInt64Extremes() throws { - let uintValues: [UInt64] = [ - 0, - 1, - 127, - 128, - 16_383, - 16_384, - 2_097_151, - 2_097_152, - 268_435_455, - 268_435_456, - 34_359_738_367, - 34_359_738_368, - 4_398_046_511_103, - 4_398_046_511_104, - 562_949_953_421_311, - 562_949_953_421_312, - 72_057_594_037_927_935, - 72_057_594_037_927_936, - UInt64(Int64.max), - UInt64.max - ] - let intValues: [Int64] = [ - Int64.min, - Int64.min + 1, - -1_000_000_000_000, - -1_000_000, - -1_000, - -128, - -1, - 0, - 1, - 127, - 1_000, - 1_000_000, - 1_000_000_000_000, - Int64.max - 1, - Int64.max - ] - - let writeBuffer = ByteBuffer() - for value in uintValues { - writeBuffer.writeVarUInt64(value) - } - for value in intValues { - writeBuffer.writeVarInt64(value) - } - let minBuffer = ByteBuffer() - minBuffer.writeVarInt64(Int64.min) - #expect(minBuffer.count == 9) - #expect(minBuffer.storage.prefix(minBuffer.count).allSatisfy { $0 == 0xFF }) - - let encoded = Array(writeBuffer.storage.prefix(writeBuffer.count)) - - let readBuffer = ByteBuffer(bytes: encoded) - for value in uintValues { - #expect(try readBuffer.readVarUInt64() == value) - } - for value in intValues { - #expect(try readBuffer.readVarInt64() == value) - } - #expect(readBuffer.remaining == 0) + let uintValues: [UInt64] = [ + 0, + 1, + 127, + 128, + 16_383, + 16_384, + 2_097_151, + 2_097_152, + 268_435_455, + 268_435_456, + 34_359_738_367, + 34_359_738_368, + 4_398_046_511_103, + 4_398_046_511_104, + 562_949_953_421_311, + 562_949_953_421_312, + 72_057_594_037_927_935, + 72_057_594_037_927_936, + UInt64(Int64.max), + UInt64.max + ] + let intValues: [Int64] = [ + Int64.min, + Int64.min + 1, + -1_000_000_000_000, + -1_000_000, + -1_000, + -128, + -1, + 0, + 1, + 127, + 1_000, + 1_000_000, + 1_000_000_000_000, + Int64.max - 1, + Int64.max + ] + + let writeBuffer = ByteBuffer() + for value in uintValues { + writeBuffer.writeVarUInt64(value) + } + for value in intValues { + writeBuffer.writeVarInt64(value) + } + let minBuffer = ByteBuffer() + minBuffer.writeVarInt64(Int64.min) + #expect(minBuffer.count == 9) + #expect(minBuffer.storage.prefix(minBuffer.count).allSatisfy { $0 == 0xFF }) + + let encoded = Array(writeBuffer.storage.prefix(writeBuffer.count)) + + let readBuffer = ByteBuffer(bytes: encoded) + for value in uintValues { + #expect(try readBuffer.readVarUInt64() == value) + } + for value in intValues { + #expect(try readBuffer.readVarInt64() == value) + } + #expect(readBuffer.remaining == 0) } @Test func metaStringEncodingRoundTrip() throws { - let encoder = MetaStringEncoder.fieldName - let decoder = MetaStringDecoder.fieldName + let encoder = MetaStringEncoder.fieldName + let decoder = MetaStringDecoder.fieldName - let lower = try encoder.encode("alpha_beta", encoding: .lowerSpecial) - #expect(lower.encoding == .lowerSpecial) - #expect(try decoder.decode(bytes: lower.bytes, encoding: lower.encoding).value == "alpha_beta") + let lower = try encoder.encode("alpha_beta", encoding: .lowerSpecial) + #expect(lower.encoding == .lowerSpecial) + #expect(try decoder.decode(bytes: lower.bytes, encoding: lower.encoding).value == "alpha_beta") - let firstLower = try encoder.encode("User_name", encoding: .firstToLowerSpecial) - #expect(firstLower.encoding == .firstToLowerSpecial) - #expect( - try decoder.decode(bytes: firstLower.bytes, encoding: firstLower.encoding).value == "User_name") + let firstLower = try encoder.encode("User_name", encoding: .firstToLowerSpecial) + #expect(firstLower.encoding == .firstToLowerSpecial) + #expect( + try decoder.decode(bytes: firstLower.bytes, encoding: firstLower.encoding).value == "User_name") - let allLower = try encoder.encode("MyHTTPType", encoding: .allToLowerSpecial) - #expect(allLower.encoding == .allToLowerSpecial) - #expect( - try decoder.decode(bytes: allLower.bytes, encoding: allLower.encoding).value == "MyHTTPType") + let allLower = try encoder.encode("MyHTTPType", encoding: .allToLowerSpecial) + #expect(allLower.encoding == .allToLowerSpecial) + #expect( + try decoder.decode(bytes: allLower.bytes, encoding: allLower.encoding).value == "MyHTTPType") - let lowerUpperDigit = try encoder.encode("userId2", encoding: .lowerUpperDigitSpecial) - #expect(lowerUpperDigit.encoding == .lowerUpperDigitSpecial) - #expect( - try decoder.decode(bytes: lowerUpperDigit.bytes, encoding: lowerUpperDigit.encoding).value - == "userId2") + let lowerUpperDigit = try encoder.encode("userId2", encoding: .lowerUpperDigitSpecial) + #expect(lowerUpperDigit.encoding == .lowerUpperDigitSpecial) + #expect( + try decoder.decode(bytes: lowerUpperDigit.bytes, encoding: lowerUpperDigit.encoding).value + == "userId2") - let autoUtf8 = try encoder.encode("naïve_meta") - #expect(autoUtf8.encoding == .utf8) - #expect( - try decoder.decode(bytes: autoUtf8.bytes, encoding: autoUtf8.encoding).value == "naïve_meta") + let autoUtf8 = try encoder.encode("naïve_meta") + #expect(autoUtf8.encoding == .utf8) + #expect( + try decoder.decode(bytes: autoUtf8.bytes, encoding: autoUtf8.encoding).value == "naïve_meta") } @Test func typeMetaRoundTripByName() throws { - let namespace = try MetaStringEncoder.namespace.encode("com.example") - let typeName = try MetaStringEncoder.typeName.encode("UserProfile") - - let fields: [TypeMeta.FieldInfo] = [ - .init( - fieldID: nil, - fieldName: "createdAt", - fieldType: .init(typeID: TypeId.varint64.rawValue, nullable: false) - ), - .init( - fieldID: nil, - fieldName: "tags", - fieldType: .init( - typeID: TypeId.list.rawValue, - nullable: false, - generics: [.init(typeID: TypeId.string.rawValue, nullable: true)] - ) - ), - .init( - fieldID: nil, - fieldName: "attributes", - fieldType: .init( - typeID: TypeId.map.rawValue, - nullable: true, - generics: [ - .init(typeID: TypeId.string.rawValue, nullable: false), - .init(typeID: TypeId.varint32.rawValue, nullable: true) - ] - ) - ), - .init( - fieldID: 7, - fieldName: "ignored_for_tag_mode", - fieldType: .init(typeID: TypeId.varint32.rawValue, nullable: false) - ) - ] - - let meta = try TypeMeta( - typeID: TypeId.namedStruct.rawValue, - userTypeID: nil, - namespace: namespace, - typeName: typeName, - registerByName: true, - fields: fields + let namespace = try MetaStringEncoder.namespace.encode("com.example") + let typeName = try MetaStringEncoder.typeName.encode("UserProfile") + + let fields: [TypeMeta.FieldInfo] = [ + .init( + fieldID: nil, + fieldName: "createdAt", + fieldType: .init(typeID: TypeId.varint64.rawValue, nullable: false) + ), + .init( + fieldID: nil, + fieldName: "tags", + fieldType: .init( + typeID: TypeId.list.rawValue, + nullable: false, + generics: [.init(typeID: TypeId.string.rawValue, nullable: true)] + ) + ), + .init( + fieldID: nil, + fieldName: "attributes", + fieldType: .init( + typeID: TypeId.map.rawValue, + nullable: true, + generics: [ + .init(typeID: TypeId.string.rawValue, nullable: false), + .init(typeID: TypeId.varint32.rawValue, nullable: true) + ] + ) + ), + .init( + fieldID: 7, + fieldName: "ignored_for_tag_mode", + fieldType: .init(typeID: TypeId.varint32.rawValue, nullable: false) ) - - let encoded = try meta.encode() - let decoded = try TypeMeta.decode(encoded) - - #expect(decoded.registerByName == true) - #expect(decoded.namespace.value == "com.example") - #expect(decoded.typeName.value == "UserProfile") - #expect(decoded.typeID == TypeId.namedStruct.rawValue) - #expect(decoded.userTypeID == nil) - #expect(decoded.fields.count == 4) - #expect(decoded.fields[0].fieldName == "created_at") - #expect(decoded.fields[3].fieldID == 7) + ] + + let meta = try TypeMeta( + typeID: TypeId.namedStruct.rawValue, + userTypeID: nil, + namespace: namespace, + typeName: typeName, + registerByName: true, + fields: fields + ) + + let encoded = try meta.encode() + let decoded = try TypeMeta.decode(encoded) + + #expect(decoded.registerByName == true) + #expect(decoded.namespace.value == "com.example") + #expect(decoded.typeName.value == "UserProfile") + #expect(decoded.typeID == TypeId.namedStruct.rawValue) + #expect(decoded.userTypeID == nil) + #expect(decoded.fields.count == 4) + #expect(decoded.fields[0].fieldName == "created_at") + #expect(decoded.fields[3].fieldID == 7) } @Test func typeMetaRoundTripByID() throws { - let emptyNamespace = MetaString.empty(specialChar1: ".", specialChar2: "_") - let emptyTypeName = MetaString.empty(specialChar1: "$", specialChar2: "_") - - let meta = try TypeMeta( - typeID: TypeId.structType.rawValue, - userTypeID: 101, - namespace: emptyNamespace, - typeName: emptyTypeName, - registerByName: false, - fields: [] - ) + let emptyNamespace = MetaString.empty(specialChar1: ".", specialChar2: "_") + let emptyTypeName = MetaString.empty(specialChar1: "$", specialChar2: "_") - let encoded = try meta.encode() - let decoded = try TypeMeta.decode(encoded) + let meta = try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 101, + namespace: emptyNamespace, + typeName: emptyTypeName, + registerByName: false, + fields: [] + ) - #expect(decoded.registerByName == false) - #expect(decoded.typeID == TypeId.structType.rawValue) - #expect(decoded.userTypeID == 101) - #expect(decoded.fields.isEmpty) + let encoded = try meta.encode() + let decoded = try TypeMeta.decode(encoded) + + #expect(decoded.registerByName == false) + #expect(decoded.typeID == TypeId.structType.rawValue) + #expect(decoded.userTypeID == 101) + #expect(decoded.fields.isEmpty) } @Test func typeMetaHeaderHashIncludesHeaderLowBits() throws { - let emptyNamespace = MetaString.empty(specialChar1: ".", specialChar2: "_") - let emptyTypeName = MetaString.empty(specialChar1: "$", specialChar2: "_") - - let meta = try TypeMeta( - typeID: TypeId.structType.rawValue, - userTypeID: 102, - namespace: emptyNamespace, - typeName: emptyTypeName, - registerByName: false, - fields: [] - ) - - var encoded = try meta.encode() - let header = try ByteBuffer(bytes: encoded).readUInt64() - let hashMask = UInt64.max << 12 - let bodyOnlyHash = bodyOnlyTypeMetaHeaderHash(Array(encoded.dropFirst(8))) - #expect((header & hashMask) != bodyOnlyHash) - let rewrittenHeader = bodyOnlyHash | (header & ~hashMask) - for index in 0..<8 { - encoded[index] = UInt8(truncatingIfNeeded: rewrittenHeader >> (index * 8)) - } - - #expect(throws: ForyError.self) { - _ = try TypeMeta.decode(encoded) - } + let emptyNamespace = MetaString.empty(specialChar1: ".", specialChar2: "_") + let emptyTypeName = MetaString.empty(specialChar1: "$", specialChar2: "_") + + let meta = try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 102, + namespace: emptyNamespace, + typeName: emptyTypeName, + registerByName: false, + fields: [] + ) + + var encoded = try meta.encode() + let header = try ByteBuffer(bytes: encoded).readUInt64() + let hashMask = UInt64.max << 12 + let bodyOnlyHash = bodyOnlyTypeMetaHeaderHash(Array(encoded.dropFirst(8))) + #expect((header & hashMask) != bodyOnlyHash) + let rewrittenHeader = bodyOnlyHash | (header & ~hashMask) + for index in 0..<8 { + encoded[index] = UInt8(truncatingIfNeeded: rewrittenHeader >> (index * 8)) + } + + #expect(throws: ForyError.self) { + _ = try TypeMeta.decode(encoded) + } } private func bodyOnlyTypeMetaHeaderHash(_ body: [UInt8]) -> UInt64 { - let shifted = MurmurHash3.x64_128(body, seed: 47).0 << 12 - let signed = Int64(bitPattern: shifted) - let absSigned = signed == Int64.min ? signed : Swift.abs(signed) - return UInt64(bitPattern: absSigned) & (UInt64.max << 12) + let shifted = MurmurHash3.x64_128(body, seed: 47).0 << 12 + let signed = Int64(bitPattern: shifted) + let absSigned = signed == Int64.min ? signed : Swift.abs(signed) + return UInt64(bitPattern: absSigned) & (UInt64.max << 12) } From c3086ec59ee18508fb63a65626f2bcb3208ff608 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Sat, 27 Jun 2026 02:56:59 +0800 Subject: [PATCH 02/54] fix: repair container memory budget CI --- go/fory/buffer.go | 11 +++++++++++ go/fory/container_memory_budget_test.go | 18 ++++++++++++++++++ go/fory/fory.go | 9 +++++++++ .../CompatibleDifferentSchemaExample.java | 2 +- python/pyfory/buffer.pxi | 9 +++++++++ python/pyfory/context.py | 3 +-- .../scala/ForySerializerDerivationTest.scala | 2 +- 7 files changed, 50 insertions(+), 4 deletions(-) diff --git a/go/fory/buffer.go b/go/fory/buffer.go index 89e29f938d..1a1a067b7e 100644 --- a/go/fory/buffer.go +++ b/go/fory/buffer.go @@ -482,6 +482,17 @@ func (b *ByteBuffer) ReaderIndex() int { return b.readerIndex } +func (b *ByteBuffer) readableBytes() int { + end := b.writerIndex + if len(b.data) > end { + end = len(b.data) + } + if b.readerIndex >= end { + return 0 + } + return end - b.readerIndex +} + func (b *ByteBuffer) SetReaderIndex(index int) { b.readerIndex = index } diff --git a/go/fory/container_memory_budget_test.go b/go/fory/container_memory_budget_test.go index 16959b3d0a..e83e8d7a03 100644 --- a/go/fory/container_memory_budget_test.go +++ b/go/fory/container_memory_budget_test.go @@ -85,6 +85,24 @@ func TestContainerMemoryBudgetKnownVsStreamRoot(t *testing.T) { require.Len(t, fromStream, len(values)) } +func TestContainerMemoryBudgetBufferRoots(t *testing.T) { + writer := New(WithCompatible(false)) + value := []string{"a", "b"} + data, err := writer.Serialize(value) + require.NoError(t, err) + + reader := New(WithCompatible(false)) + var fromCallback []string + err = reader.DeserializeWithCallbackBuffers(NewByteBuffer(data), &fromCallback, nil) + require.NoError(t, err) + require.Equal(t, value, fromCallback) + + var fromBuffer []string + err = reader.DeserializeFrom(NewByteBuffer(data), &fromBuffer) + require.NoError(t, err) + require.Equal(t, value, fromBuffer) +} + func TestContainerMemoryBudgetExplicitOverride(t *testing.T) { writer := New(WithCompatible(false)) values := make([]any, 12000) diff --git a/go/fory/fory.go b/go/fory/fory.go index 7bfb9867ef..3b6360aece 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -666,6 +666,11 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { // Temporarily swap buffer origBuffer := f.readCtx.buffer f.readCtx.buffer = buf + f.readCtx.initContainerMemoryBudget(buf.readableBytes(), false) + if f.readCtx.HasError() { + f.readCtx.buffer = origBuffer + return f.readCtx.TakeError() + } readHeader(f.readCtx) if f.readCtx.HasError() { @@ -761,6 +766,10 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers f.readCtx.buffer = nil f.readCtx.outOfBandBuffers = nil }() + f.readCtx.initContainerMemoryBudget(buffer.readableBytes(), false) + if f.readCtx.HasError() { + return f.readCtx.TakeError() + } // Set up out-of-band buffers if provided if buffers != nil { f.readCtx.outOfBandBuffers = buffers diff --git a/integration_tests/graalvm_tests/src/main/java/org/apache/fory/graalvm/CompatibleDifferentSchemaExample.java b/integration_tests/graalvm_tests/src/main/java/org/apache/fory/graalvm/CompatibleDifferentSchemaExample.java index 41ddae5cda..c70d016deb 100644 --- a/integration_tests/graalvm_tests/src/main/java/org/apache/fory/graalvm/CompatibleDifferentSchemaExample.java +++ b/integration_tests/graalvm_tests/src/main/java/org/apache/fory/graalvm/CompatibleDifferentSchemaExample.java @@ -89,7 +89,7 @@ private static Serializer readSerializerForTarget( MemoryBuffer buffer = MemoryUtils.wrap(bytes); buffer.readByte(); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false); + readContext.prepare(buffer, null, false, buffer.remaining(), false); try { readContext.getRefReader().tryPreserveRefId(buffer); TypeInfo typeInfo = fory.getTypeResolver().readTypeInfo(readContext, targetClass); diff --git a/python/pyfory/buffer.pxi b/python/pyfory/buffer.pxi index 4fa77e9e68..53ff441f25 100644 --- a/python/pyfory/buffer.pxi +++ b/python/pyfory/buffer.pxi @@ -335,6 +335,15 @@ cdef class Buffer: f"Address range {offset, offset + length} out of bound {0, size_}", ) + cpdef inline ensure_readable(self, int32_t length): + if length < 0: + raise_fory_error(CErrorCode.InvalidData, f"Readable byte count {length} is negative") + if length == 0: + return + if not self.c_buffer.ensure_readable(length, self._error): + if not self._error.ok(): + self._raise_if_error() + cpdef inline write_bool(self, c_bool value): self.c_buffer.write_uint8(value) diff --git a/python/pyfory/context.py b/python/pyfory/context.py index a923731c4b..8b620629da 100644 --- a/python/pyfory/context.py +++ b/python/pyfory/context.py @@ -527,8 +527,7 @@ def check_readable_bytes(self, length): raise ValueError(f"Readable byte count {length} is negative") if length == 0: return - reader_index = self.buffer.get_reader_index() - self.buffer.check_bound(reader_index, length) + self.buffer.ensure_readable(length) def prepare( self, diff --git a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala index 7cf598f381..22e7044d23 100644 --- a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala +++ b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala @@ -571,7 +571,7 @@ import org.apache.fory.scala.ForyScala buffer.readVarUInt32() shouldBe 0 buffer.readerIndex(0) val readContext = fory.getReadContext - readContext.prepare(buffer, null, false) + readContext.prepare(buffer, null, false, buffer.remaining(), false) try serializer.read(readContext) shouldBe value finally readContext.reset() } From 649b436df3ca12f814787d3651492d45ed1ef7c9 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 30 Jun 2026 11:38:30 +0800 Subject: [PATCH 03/54] fix(cpp): use portable container budget estimates --- .agents/languages/cpp.md | 12 +- AGENTS.md | 2 +- .../serialization/collection_serializer.h | 227 +++++++++++++----- .../container_memory_budget_test.cc | 122 ++++++++-- cpp/fory/serialization/map_serializer.h | 14 +- docs/guide/cpp/configuration.md | 11 +- docs/security/deserialization.md | 8 +- .../xlang_implementation_guide.md | 20 +- 8 files changed, 316 insertions(+), 100 deletions(-) diff --git a/.agents/languages/cpp.md b/.agents/languages/cpp.md index f640aa8552..6a77070570 100644 --- a/.agents/languages/cpp.md +++ b/.agents/languages/cpp.md @@ -22,8 +22,16 @@ Load this file when changing `cpp/`, Cython build plumbing, or C++ xlang behavio explicit limit; known byte roots use `inputBytes * 8 + 64 KiB`, while stream roots use fixed `128 MiB`. Reserve estimated container-owned memory before allocation but preserve existing byte-availability checks and their non-empty metadata ordering. Skip only dedicated string, - binary, primitive vector, and primitive dense-array owners; general `std::vector` for - non-primitive `T` is inline container storage and must be charged. + binary, primitive vector, and primitive dense-array owners; `std::vector` is the C++ + standard-container exception and should charge rounded packed-bit storage. General + `std::vector` for non-primitive `T` is inline container storage and must be charged. +- C++ container budget formulas must be portable lower-bound estimates, not STL heap-layout + accounting. Generic collection-like containers charge `count_or_capacity * sizeof(value_type)`, + map-like containers charge `count * (sizeof(key_type) + sizeof(mapped_type))`, and set-like + containers charge `count * sizeof(key_type)`. Do not add guessed node/header/debug-STL overhead, + red-black-tree fields, allocator probing, object-layout inspection, or generic per-entry pointer + overhead. Charge unordered bucket tables only if the owner path has a cheap pre-allocation bucket + count; otherwise leave them uncharged. ## Key Paths diff --git a/AGENTS.md b/AGENTS.md index c8346e0f4f..bee96b4fe4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -33,7 +33,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Check the spec before implementation. For wire behavior and xlang mapping, use the specs as the source of truth and never copy one runtime's bug into another runtime just to make tests pass. - Do not make assumptions about runtime behavior, ownership, registration, metadata construction, protocol semantics, or test coverage. Read the current code, owning docs/specs, and relevant tests before making a design judgment or implementation decision. If the evidence is incomplete, inspect more or state the uncertainty explicitly instead of filling gaps from memory or analogy with another runtime. - For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve backing/output capacity from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. Container memory-budget reservation is accounting only and may happen before that byte check, but it must not replace the byte check. -- Root deserialization container memory budgets are estimated container-owned memory, not exact heap accounting and not raw slots. Positive `maxContainerMemoryBytes` wins; `-1` auto uses `inputBytes * 8 + 64 KiB` for known-length roots and fixed `128 MiB` for true stream or unknown-length roots. Charge fixed container cost, backing/reference/inline storage, map table and entry overhead, and zero-size containers; skip only dedicated string, binary, primitive array, and primitive dense-array owners. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. +- Root deserialization container memory budgets are estimated container-owned memory, not exact heap accounting and not raw slots. Positive `maxContainerMemoryBytes` wins; `-1` auto uses `inputBytes * 8 + 64 KiB` for known-length roots and fixed `128 MiB` for true stream or unknown-length roots. Charge fixed container cost, backing/reference/inline storage, map table and entry overhead where the runtime has cheap reliable signals, and zero-size containers; skip only dedicated string, binary, primitive array, and primitive dense-array owners. Native runtimes such as C++ must prefer portable lower-bound formulas over non-portable STL or allocator layout guesses. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. - For remote TypeDef/TypeMeta reads, the checked metadata cache is the only owner of remote "already validated" state. Cache hit means the header was previously parsed, body/hash-validated, policy-checked, and published by that cache, so the hot path must skip the body and use cached metadata without extra validation, hashing, limit checks, exact-local checks, allocation, or policy work. A known expected local TypeDef/TypeMeta header/hash match is a local-schema hit, not a remote cache miss: it may skip the body and use the local TypeInfo/TypeMeta without schema-version counting or cache publish. Cache miss is the only path that parses and validates non-local metadata, enforces limits, performs exact-local byte comparison when needed, and publishes remote metadata to the cache. Do not add nullable accepted-header fields, sentinel headers, per-TypeInfo markers, pending metadata state, parallel header-low/header-high slots, or parallel acceptance state for this decision. If a runtime needs a metadata hit hint, cache the concrete checked metadata owner object, such as the TypeInfo, TypeDef, or TypeMeta used by that runtime, and compare its validated header identity directly. - When a user corrects a non-obvious invariant, encode it in the nearest source comment before continuing, and also update `AGENTS.md`, `.agents/**`, docs, or specs when the rule is reusable beyond one file. Do not rely only on chat history, task notes, commit messages, or benchmark logs for corrections that protect security, protocol behavior, ownership, naming, or hot-path performance. - Reject semantic hacks. Do not bypass broken semantics by deleting cases, simplifying callers, adding coercion hooks, or using workaround fallbacks; fix the underlying bug and prove it with focused tests. diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index 7c89e1a265..5aca15f728 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -22,6 +22,7 @@ #include "fory/serialization/array_serializer.h" #include "fory/serialization/serializer.h" #include +#include #include #include #include @@ -29,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -380,32 +382,12 @@ struct has_reserve inline constexpr bool has_reserve_v = has_reserve::value; -constexpr size_t kContainerEntryOverheadBytes = 16; -constexpr size_t kContainerReferenceBytes = sizeof(void *); - -template -struct is_std_vector_container : std::false_type {}; - -template -struct is_std_vector_container> : std::true_type {}; - -template -inline constexpr bool is_std_vector_container_v = - is_std_vector_container::value; - template constexpr size_t collection_element_memory_bytes() { using Elem = typename Container::value_type; - if constexpr (is_std_vector_container_v) { - return sizeof(Elem); - } else { - static_assert(sizeof(Elem) <= std::numeric_limits::max() - - kContainerEntryOverheadBytes - - kContainerReferenceBytes * 2, - "container element memory estimate overflows"); - return sizeof(Elem) + kContainerEntryOverheadBytes + - kContainerReferenceBytes * 2; - } + // Portable lower-bound estimate only: STL node/header/debug-layout details + // differ across implementations, so generic collections charge value storage. + return sizeof(Elem); } template @@ -431,6 +413,33 @@ inline bool reserve_collection(Container &result, ReadContext &ctx, return true; } +template +inline bool reserve_collection(std::vector &result, + ReadContext &ctx, uint32_t length) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return false; + } + constexpr size_t fixed_bytes = sizeof(std::vector); + constexpr size_t max_packed_bytes = + (static_cast(std::numeric_limits::max()) + CHAR_BIT - + 1) / + CHAR_BIT; + static_assert(fixed_bytes <= + std::numeric_limits::max() - max_packed_bytes, + "vector memory estimate overflows"); + const size_t packed_bytes = + (static_cast(length) + CHAR_BIT - 1) / CHAR_BIT; + if (FORY_PREDICT_FALSE( + !ctx.reserve_container_memory(fixed_bytes + packed_bytes))) { + return false; + } + if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { + return false; + } + result.reserve(length); + return true; +} + template inline bool reserve_empty_collection(ReadContext &ctx) { if (FORY_PREDICT_FALSE(ctx.has_error())) { @@ -579,6 +588,126 @@ inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { return result; } +/// Read forward_list data without a temporary vector so budget accounting only +/// covers the destination container's portable lower-bound storage. +template +inline std::forward_list +read_forward_list_data_slow(ReadContext &ctx, uint32_t length) { + std::forward_list result; + if (length == 0) { + if (FORY_PREDICT_FALSE( + (!reserve_empty_collection>(ctx)))) { + return result; + } + return result; + } + + constexpr bool elem_is_polymorphic = is_polymorphic_v; + + uint8_t bitmap = ctx.read_uint8(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return result; + } + + bool track_ref = (bitmap & COLL_TRACKING_REF) != 0; + bool has_null = (bitmap & COLL_HAS_NULL) != 0; + bool is_decl_type = (bitmap & COLL_DECL_ELEMENT_TYPE) != 0; + bool is_same_type = (bitmap & COLL_IS_SAME_TYPE) != 0; + + const TypeInfo *elem_type_info = nullptr; + if (is_same_type && !is_decl_type) { + elem_type_info = ctx.read_any_type_info(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return result; + } + } + + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } + + auto tail = result.before_begin(); + auto append = [&](T &&elem) { + tail = result.insert_after(tail, std::move(elem)); + }; + auto append_default = [&]() { tail = result.emplace_after(tail); }; + + if (is_same_type) { + if (track_ref) { + for (uint32_t i = 0; i < length; ++i) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return result; + } + if constexpr (elem_is_polymorphic) { + auto elem = Serializer::read_with_type_info(ctx, RefMode::Tracking, + *elem_type_info); + append(std::move(elem)); + } else { + auto elem = Serializer::read(ctx, RefMode::Tracking, false); + append(std::move(elem)); + } + } + } else if (!has_null) { + for (uint32_t i = 0; i < length; ++i) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return result; + } + if constexpr (elem_is_polymorphic) { + auto elem = Serializer::read_with_type_info(ctx, RefMode::None, + *elem_type_info); + append(std::move(elem)); + } else { + auto elem = Serializer::read(ctx, RefMode::None, false); + append(std::move(elem)); + } + } + } else { + for (uint32_t i = 0; i < length; ++i) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return result; + } + bool has_value = read_null_only_flag(ctx, RefMode::NullOnly); + if (!has_value) { + append_default(); + } else if constexpr (elem_is_polymorphic) { + auto elem = Serializer::read_with_type_info(ctx, RefMode::None, + *elem_type_info); + append(std::move(elem)); + } else { + auto elem = Serializer::read(ctx, RefMode::None, false); + append(std::move(elem)); + } + } + } + } else { + if (has_null && !track_ref) { + for (uint32_t i = 0; i < length; ++i) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return result; + } + bool has_value = read_null_only_flag(ctx, RefMode::NullOnly); + if (!has_value) { + append_default(); + } else { + auto elem = Serializer::read(ctx, RefMode::None, true); + append(std::move(elem)); + } + } + } else { + for (uint32_t i = 0; i < length; ++i) { + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return result; + } + auto elem = Serializer::read( + ctx, track_ref ? RefMode::Tracking : RefMode::None, true); + append(std::move(elem)); + } + } + } + + return result; +} + // ============================================================================ // std::vector serializer // ============================================================================ @@ -1204,11 +1333,11 @@ template struct Serializer> { return std::vector(); } - if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(size, ctx.error()))) { - return std::vector(); + std::vector result; + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; } - - std::vector result(size); + result.resize(size); Buffer &buffer = ctx.buffer(); if (size > 0) { const uint8_t *src = buffer.data() + buffer.reader_index(); @@ -1698,13 +1827,11 @@ struct Serializer> { } // Dispatch to slow path for polymorphic/shared-ref elements - // Read elements into a temporary vector then build forward_list - // (forward_list doesn't have push_back, only push_front) - std::vector temp; constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { - temp = read_collection_data_slow>(ctx, length); + return read_forward_list_data_slow(ctx, length); } else { + auto tail = result.before_begin(); // Fast path for non-polymorphic, non-shared-ref elements // Elements header bitmap (CollectionFlags) @@ -1737,54 +1864,45 @@ struct Serializer> { if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { return result; } - - if (FORY_PREDICT_FALSE(!reserve_collection(temp, ctx, length))) { - return result; - } // Fast path: no tracking, no nulls, elements have declared type if (!track_ref && !has_null && is_same_type) { for (uint32_t i = 0; i < length; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { - break; + return result; } auto elem = Serializer::read(ctx, RefMode::None, false); - temp.push_back(std::move(elem)); + tail = result.insert_after(tail, std::move(elem)); } } else { // General path: handle HAS_NULL and/or TRACKING_REF for (uint32_t i = 0; i < length; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { - break; + return result; } if (track_ref) { auto elem = Serializer::read(ctx, RefMode::Tracking, false); - temp.push_back(std::move(elem)); + tail = result.insert_after(tail, std::move(elem)); } else if (has_null) { bool has_value_elem = read_null_only_flag(ctx, RefMode::NullOnly); if (!has_value_elem) { - temp.emplace_back(); + tail = result.emplace_after(tail); } else { if constexpr (is_nullable_v) { using Inner = nullable_element_t; auto inner = Serializer::read(ctx, RefMode::None, false); - temp.emplace_back(std::move(inner)); + tail = result.emplace_after(tail, std::move(inner)); } else { auto elem = Serializer::read(ctx, RefMode::None, false); - temp.push_back(std::move(elem)); + tail = result.insert_after(tail, std::move(elem)); } } } else { auto elem = Serializer::read(ctx, RefMode::None, false); - temp.push_back(std::move(elem)); + tail = result.insert_after(tail, std::move(elem)); } } } } - - // Build forward_list in reverse order using push_front - for (auto it = temp.rbegin(); it != temp.rend(); ++it) { - result.push_front(std::move(*it)); - } return result; } @@ -2072,20 +2190,13 @@ struct Serializer> { if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { return result; } - std::vector temp; - if (FORY_PREDICT_FALSE(!reserve_collection(temp, ctx, size))) { - return result; - } + auto tail = result.before_begin(); for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { - break; + return result; } auto elem = Serializer::read_data(ctx); - temp.push_back(std::move(elem)); - } - // Build forward_list in reverse order - for (auto it = temp.rbegin(); it != temp.rend(); ++it) { - result.push_front(std::move(*it)); + tail = result.insert_after(tail, std::move(elem)); } return result; } diff --git a/cpp/fory/serialization/container_memory_budget_test.cc b/cpp/fory/serialization/container_memory_budget_test.cc index 781e9312bc..296f74027a 100644 --- a/cpp/fory/serialization/container_memory_budget_test.cc +++ b/cpp/fory/serialization/container_memory_budget_test.cc @@ -19,10 +19,18 @@ #include "fory/serialization/fory.h" #include "gtest/gtest.h" +#include +#include #include +#include +#include +#include #include +#include #include #include +#include +#include #include #include @@ -54,6 +62,17 @@ struct BudgetSiblings { FORY_STRUCT(BudgetSiblings, left, right); }; +struct BudgetFixedArrayOwner { + std::array prefix{}; + std::vector items; + + bool operator==(const BudgetFixedArrayOwner &other) const { + return prefix == other.prefix && items == other.items; + } + + FORY_STRUCT(BudgetFixedArrayOwner, prefix, items); +}; + template auto with_fory(int64_t max_container_memory_bytes, Fn &&fn) { auto fory = Fory::builder() @@ -64,6 +83,7 @@ auto with_fory(int64_t max_container_memory_bytes, Fn &&fn) { .build(); fory.register_struct(1); fory.register_struct(2); + fory.register_struct(3); return std::forward(fn)(fory); } @@ -79,6 +99,24 @@ size_t nested_empty_budget(size_t count) { return sizeof(Outer) + count * sizeof(Inner) + count * sizeof(Inner); } +template +void expect_budget_boundary(const T &value, size_t required) { + ASSERT_GT(required, 0u); + auto bytes = serialize_value(value); + + auto small_result = + with_fory(static_cast(required - 1), + [&](Fory &fory) { return fory.deserialize(bytes); }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto exact_result = + with_fory(static_cast(required), + [&](Fory &fory) { return fory.deserialize(bytes); }); + ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + EXPECT_EQ(exact_result.value(), value); +} + TEST(ContainerMemoryBudgetTest, KnownLengthAutoBudget) { constexpr size_t count = 3000; std::vector> value(count); @@ -197,24 +235,78 @@ TEST(ContainerMemoryBudgetTest, SiblingCumulativeBudget) { TEST(ContainerMemoryBudgetTest, MapBudget) { std::map value{{"a", 1}, {"b", 2}, {"c", 3}}; - auto bytes = serialize_value(value); - const size_t entry_bytes = - sizeof(std::string) + sizeof(int32_t) + 16 + sizeof(void *) * 3; + const size_t entry_bytes = sizeof(std::string) + sizeof(int32_t); const size_t required = sizeof(value) + value.size() * entry_bytes; - auto small_result = - with_fory(static_cast(required - 1), [&](Fory &fory) { - return fory.deserialize>(bytes); - }); - ASSERT_FALSE(small_result.ok()); - EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + expect_budget_boundary(value, required); +} - auto exact_result = - with_fory(static_cast(required), [&](Fory &fory) { - return fory.deserialize>(bytes); - }); - ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); - EXPECT_EQ(exact_result.value(), value); +TEST(ContainerMemoryBudgetTest, CollectionLowerBounds) { + std::deque deque_value(4); + expect_budget_boundary(deque_value, + sizeof(deque_value) + + deque_value.size() * sizeof(BudgetItem)); + + std::list list_value(4); + expect_budget_boundary( + list_value, sizeof(list_value) + list_value.size() * sizeof(BudgetItem)); + + std::forward_list forward_value(4); + expect_budget_boundary(forward_value, sizeof(forward_value) + + size_t{4} * sizeof(BudgetItem)); +} + +TEST(ContainerMemoryBudgetTest, VectorBoolUsesPackedStorage) { + std::vector value(33); + value[0] = true; + value[32] = true; + const size_t packed_bytes = (value.size() + CHAR_BIT - 1) / CHAR_BIT; + const size_t required = sizeof(value) + packed_bytes; + ASSERT_LT(required, sizeof(value) + value.size()); + + expect_budget_boundary(value, required); +} + +TEST(ContainerMemoryBudgetTest, OrderedSetAndMapLowerBounds) { + std::set set_value{1, 2, 3, 4}; + expect_budget_boundary(set_value, sizeof(set_value) + + set_value.size() * sizeof(int32_t)); + + std::map map_value{{"a", 1}, {"b", 2}}; + expect_budget_boundary( + map_value, sizeof(map_value) + map_value.size() * (sizeof(std::string) + + sizeof(int32_t))); +} + +TEST(ContainerMemoryBudgetTest, UnorderedContainersLowerBounds) { + std::unordered_set set_value{1, 2, 3, 4}; + expect_budget_boundary(set_value, sizeof(set_value) + + set_value.size() * sizeof(int32_t)); + + std::unordered_map map_value{{"a", 1}, {"b", 2}}; + expect_budget_boundary( + map_value, sizeof(map_value) + map_value.size() * (sizeof(std::string) + + sizeof(int32_t))); +} + +TEST(ContainerMemoryBudgetTest, ArrayHasNoStandaloneReservation) { + std::array value{{1, 2, 3, 4}}; + auto bytes = serialize_value(value); + auto result = with_fory(1, [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_TRUE(result.ok()) << result.error().to_string(); + EXPECT_EQ(result.value(), value); +} + +TEST(ContainerMemoryBudgetTest, FixedInlineOwnerChargesNestedVector) { + BudgetFixedArrayOwner value; + value.prefix = {{1, 2, 3, 4}}; + value.items.resize(3); + const size_t required = + sizeof(value.items) + value.items.size() * sizeof(BudgetItem); + + expect_budget_boundary(value, required); } TEST(ContainerMemoryBudgetTest, DensePathsSkipped) { diff --git a/cpp/fory/serialization/map_serializer.h b/cpp/fory/serialization/map_serializer.h index a7a3bc615d..a654dae83b 100644 --- a/cpp/fory/serialization/map_serializer.h +++ b/cpp/fory/serialization/map_serializer.h @@ -82,9 +82,6 @@ struct MapReserver inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { // Lazy error propagation may continue into later readers; do not let that @@ -94,14 +91,13 @@ inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { } using Key = typename MapType::key_type; using Value = typename MapType::mapped_type; - static_assert(sizeof(Key) <= std::numeric_limits::max() - - sizeof(Value) - kMapEntryBudgetBytes - - kMapReferenceBudgetBytes * 3, + // Portable lower-bound estimate only: ordered and unordered map node layouts + // vary across STL implementations, allocators, and debug modes. + static_assert(sizeof(Key) <= + std::numeric_limits::max() - sizeof(Value), "map entry memory estimate overflows"); constexpr size_t fixed_bytes = sizeof(MapType); - constexpr size_t elem_bytes = sizeof(Key) + sizeof(Value) + - kMapEntryBudgetBytes + - kMapReferenceBudgetBytes * 3; + constexpr size_t elem_bytes = sizeof(Key) + sizeof(Value); if (FORY_PREDICT_FALSE((!ctx.template reserve_counted_container_memory< fixed_bytes, elem_bytes>(length)))) { return false; diff --git a/docs/guide/cpp/configuration.md b/docs/guide/cpp/configuration.md index aee6e633d5..199d7cab72 100644 --- a/docs/guide/cpp/configuration.md +++ b/docs/guide/cpp/configuration.md @@ -112,10 +112,13 @@ automatic limit is the root input size multiplied by `8`, plus `64 KiB`. For stream roots, the automatic limit is `128 MiB` because the full root size is not known up front. Positive values always override the automatic limit. -This budget is an estimate for container-owned memory such as collection -objects, backing storage, map entries, and object/reference arrays. It is not an -exact process heap limit. Dedicated string, binary, and primitive dense-array -payloads continue to rely on their byte-availability checks instead. +This budget is a portable lower-bound estimate for container-owned memory such +as collection objects, backing storage, map value storage, and object/reference +arrays. It is not an exact process heap limit and does not include STL +implementation details such as debug nodes or allocator headers. Dedicated +string, binary, and primitive dense-array payloads continue to rely on their +byte-availability checks instead. `std::vector` is counted as packed +standard-container storage. **Default:** `-1` diff --git a/docs/security/deserialization.md b/docs/security/deserialization.md index 109fb225c2..39003627f6 100644 --- a/docs/security/deserialization.md +++ b/docs/security/deserialization.md @@ -225,8 +225,8 @@ Container budget accounting should: deserialization `finally`; - reject arithmetic overflow before comparing budget or allocating; - charge fixed container object cost, backing capacity, map table and entry - overhead, reference arrays, and inline or value storage where a runtime stores - elements inline; + overhead where the runtime has cheap reliable signals, reference arrays, and + inline or value storage where a runtime stores elements inline; - charge fixed cost even for zero-size containers; - preserve existing byte-availability checks before backing allocation or capacity reservation; @@ -241,6 +241,10 @@ the inline element storage instead of treating those elements as references. General inline-value containers must not be skipped just because dedicated primitive dense arrays are skipped. +Native runtimes may use conservative lower-bound estimates when exact container +layout is not portable. For example, C++ STL node, allocator, and debug-mode +overheads should not be guessed when only value storage is reliably known. + ## Skip Semantics Skipping unknown or incompatible data is classified by concrete impact, not by diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index 53f41887fd..2fd57ffa5e 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -407,15 +407,17 @@ fixed `128 MiB` for true stream or unknown-length root input. Do not add dynamic stream bytes-read accounting for this budget. The budget estimates container-owned memory, not exact heap bytes. Charge fixed -container object cost, backing capacity, map table and entry overhead, -reference arrays, and inline/value element storage where the runtime stores -container elements inline. Charge zero-size containers for their fixed cost. -Skip dedicated string, binary, primitive array, and primitive dense-array owners, -but do not skip general inline-value containers such as vectors or lists of -value objects. If reference slot size is not cheap or reliable to query, use a -4-byte reference slot. Reject arithmetic overflow before budget comparison or -allocation, and keep the existing `checkReadableBytes` proof before backing -allocation or capacity reservation. +container object cost, backing capacity, map table and entry overhead where the +runtime has cheap reliable signals, reference arrays, and inline/value element +storage where the runtime stores container elements inline. Charge zero-size +containers for their fixed cost. Skip dedicated string, binary, primitive +array, and primitive dense-array owners, but do not skip general inline-value +containers such as vectors or lists of value objects. If reference slot size is +not cheap or reliable to query, use a 4-byte reference slot. Native runtimes may +use conservative lower-bound estimates instead of guessing non-portable +container, allocator, or debug-layout details. Reject arithmetic overflow before +budget comparison or allocation, and keep the existing `checkReadableBytes` +proof before backing allocation or capacity reservation. For TypeDef or TypeMeta bodies, first prove that the encoded metadata body bytes are readable through the byte owner. Field-list allocation should happen after From 45106e4a78d91039630cbd97d02bf5c97f455505 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 30 Jun 2026 12:53:12 +0800 Subject: [PATCH 04/54] docs: pair benchmark cases for perf comparisons --- .agents/skills/fory-performance-optimization/SKILL.md | 4 ++++ AGENTS.md | 1 + 2 files changed, 5 insertions(+) diff --git a/.agents/skills/fory-performance-optimization/SKILL.md b/.agents/skills/fory-performance-optimization/SKILL.md index 2bbe17f801..a66955da8f 100644 --- a/.agents/skills/fory-performance-optimization/SKILL.md +++ b/.agents/skills/fory-performance-optimization/SKILL.md @@ -15,6 +15,8 @@ Deliver measurable performance improvements in Apache Fory without protocol drif - Profile before changing hot code. - Change one bottleneck at a time. - Benchmark sequentially on the same machine state (one benchmark process at a time). +- Compare old/new benchmark results case-by-case in adjacent pairs: run one case on `apache/main`, + then immediately run that same case on the current branch before moving to the next case. - Keep only measured wins or explicitly requested architecture cleanups. - Revert speculative changes that do not pay off. - Align with reference runtimes (usually C++ first, then Rust/Java) when behavior and ownership models differ. @@ -73,6 +75,8 @@ Deliver measurable performance improvements in Apache Fory without protocol drif 7. Benchmark and compare. - Run targeted benchmark at least twice sequentially. +- Pair each baseline case with the matching current-branch case before starting another case, so + both measurements see closer machine load conditions. - Use longer duration when signal is noisy. - Run one short full-suite sanity benchmark to catch collateral regressions. diff --git a/AGENTS.md b/AGENTS.md index bee96b4fe4..be3962a6fa 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -176,6 +176,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - When comparing benchmark results against `apache/main`, use a separate sibling worktree named `fory-benchmark-baseline` by default. Before creating a new worktree, check whether `../fory-benchmark-baseline` already exists and reuse it to avoid repeated benchmark dependency rebuilds. Always fetch and sync that baseline worktree to the latest `apache/main` before measuring it, and store benchmark result files under that worktree so older runs remain available as reference data. Treat stored benchmark results as historical references, not truth, because machine load and benchmark variance change over time. Create a different baseline worktree only when explicitly requested. - Before benchmarking a checked-out version, install or build the required Fory packages for that version, such as the Java artifacts, Python package, and the target runtime package needed by the benchmark. - Run and close old/new benchmark comparisons for exactly one language at a time. If that language has a slowdown greater than 1%, keep working only on that language until the slowdown is within 1% before moving to the next language. +- Within one language, compare benchmarks case-by-case in adjacent old/new pairs: run one case on fresh `apache/main`, then immediately run the same case on the current branch before moving to the next case. Do not batch all baseline cases and then all current cases, because machine load drift makes that comparison noisier. - Treat a same-benchmark slowdown greater than 1% as unresolved until the retained median is within 1% of the baseline. Faster results are acceptable only after verifying that generated code, benchmark shape, safety checks, and protocol semantics did not skip required work. Do not add artificial slowdowns or benchmark-shape changes to force a match. - Do not change protocol behavior, benchmark payloads, or public APIs solely to manufacture performance wins. - For performance work, run the relevant benchmark immediately after each change and report the command plus before/after numbers. From f34e7c3a09c61f262d87d046dab69893726eaf59 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 30 Jun 2026 13:52:54 +0800 Subject: [PATCH 05/54] feat: add lower-bound container memory budget --- .agents/languages/cpp.md | 12 +- .agents/languages/csharp.md | 9 +- .agents/languages/dart.md | 2 +- .agents/languages/go.md | 14 +- .agents/languages/java.md | 14 +- .agents/languages/javascript.md | 11 +- .agents/languages/python.md | 11 +- .agents/languages/rust.md | 13 +- .agents/languages/swift.md | 9 +- AGENTS.md | 2 +- .../serialization/collection_serializer.h | 19 +-- .../container_memory_budget_test.cc | 81 ++++------- cpp/fory/serialization/context.cc | 12 +- cpp/fory/serialization/context.h | 28 +--- cpp/fory/serialization/map_serializer.h | 8 +- .../src/Fory.Generator/ForyModelGenerator.cs | 6 +- csharp/src/Fory/CollectionSerializers.cs | 30 ++-- csharp/src/Fory/DictionarySerializers.cs | 17 ++- csharp/src/Fory/NullableKeyDictionary.cs | 17 ++- .../Fory/PrimitiveDictionarySerializers.cs | 18 ++- csharp/src/Fory/ReadContext.cs | 105 ++------------ .../Fory.Tests/ContainerMemoryBudgetTests.cs | 42 +++--- .../fory/lib/src/context/read_context.dart | 40 ------ .../serializer/collection_serializers.dart | 114 ++++++++-------- .../lib/src/serializer/map_serializers.dart | 120 ++++++++-------- .../test/container_memory_budget_test.dart | 40 +++--- docs/guide/cpp/configuration.md | 15 +- docs/guide/csharp/configuration.md | 7 +- docs/guide/dart/configuration.md | 9 +- docs/guide/go/configuration.md | 6 +- docs/guide/java/configuration.md | 7 +- docs/guide/javascript/configuration.md | 11 +- docs/guide/python/configuration.md | 5 +- docs/guide/rust/configuration.md | 7 +- docs/guide/swift/configuration.md | 13 +- docs/security/deserialization.md | 24 +++- .../xlang_implementation_guide.md | 34 +++-- go/fory/array.go | 2 +- go/fory/codegen/decoder.go | 18 +-- go/fory/container_memory_budget_test.go | 26 ++-- go/fory/field_serializer.go | 2 +- go/fory/map.go | 9 +- go/fory/map_primitive.go | 38 +++--- go/fory/reader.go | 113 ++------------- go/fory/set.go | 18 ++- go/fory/slice.go | 8 +- go/fory/slice_dyn.go | 10 +- go/fory/slice_primitive.go | 2 +- go/fory/slice_primitive_list.go | 8 +- go/fory/tests/structs_fory_gen.go | 32 ++--- .../org/apache/fory/context/ReadContext.java | 25 ---- .../fory/serializer/ArraySerializers.java | 4 +- .../CompatibleCollectionArrayReader.java | 6 +- .../collection/CollectionLikeSerializer.java | 4 +- .../collection/CollectionSerializers.java | 4 +- .../collection/MapLikeSerializer.java | 3 +- .../serializer/ContainerMemoryBudgetTest.java | 77 ++--------- javascript/packages/core/lib/context.ts | 40 ------ .../packages/core/lib/gen/collection.ts | 8 +- javascript/packages/core/lib/gen/map.ts | 6 +- javascript/test/containerMemoryBudget.test.ts | 25 ++-- python/pyfory/collection.pxi | 72 +++++----- python/pyfory/collection.py | 7 +- python/pyfory/context.pxi | 34 ++--- python/pyfory/context.py | 24 +--- python/pyfory/serializer.py | 4 +- .../tests/test_container_memory_budget.py | 17 +-- rust/fory-core/src/context.rs | 59 ++------ rust/fory-core/src/serializer/codec.rs | 14 +- rust/fory-core/src/serializer/collection.rs | 6 +- rust/fory-core/src/serializer/map.rs | 10 +- .../tests/test_container_memory_budget.rs | 41 +++--- swift/Sources/Fory/AnySerializer.swift | 26 +++- .../Sources/Fory/CollectionSerializers.swift | 49 ++++++- swift/Sources/Fory/FieldCodecs.swift | 64 +++++++-- swift/Sources/Fory/ReadContext.swift | 129 ++---------------- .../ContainerMemoryBudgetTests.swift | 58 +++----- 77 files changed, 860 insertions(+), 1174 deletions(-) diff --git a/.agents/languages/cpp.md b/.agents/languages/cpp.md index 6a77070570..30f780e266 100644 --- a/.agents/languages/cpp.md +++ b/.agents/languages/cpp.md @@ -21,17 +21,19 @@ Load this file when changing `cpp/`, Cython build plumbing, or C++ xlang behavio `Fory::deserialize` overload. Keep `max_container_memory_bytes` as `-1 / auto` or a positive explicit limit; known byte roots use `inputBytes * 8 + 64 KiB`, while stream roots use fixed `128 MiB`. Reserve estimated container-owned memory before allocation but preserve existing - byte-availability checks and their non-empty metadata ordering. Skip only dedicated string, + byte-availability checks and their non-empty metadata ordering. `ReadContext` may expose only raw + byte reservation and generic counted-byte arithmetic; collection/map formulas belong in serializer + owners. Empty containers with no dynamic + backing storage normally charge zero. Skip only dedicated string, binary, primitive vector, and primitive dense-array owners; `std::vector` is the C++ standard-container exception and should charge rounded packed-bit storage. General `std::vector` for non-primitive `T` is inline container storage and must be charged. - C++ container budget formulas must be portable lower-bound estimates, not STL heap-layout accounting. Generic collection-like containers charge `count_or_capacity * sizeof(value_type)`, map-like containers charge `count * (sizeof(key_type) + sizeof(mapped_type))`, and set-like - containers charge `count * sizeof(key_type)`. Do not add guessed node/header/debug-STL overhead, - red-black-tree fields, allocator probing, object-layout inspection, or generic per-entry pointer - overhead. Charge unordered bucket tables only if the owner path has a cheap pre-allocation bucket - count; otherwise leave them uncharged. + containers charge `count * sizeof(key_type)`. Do not charge standalone `sizeof(Container)` and do + not add guessed node/header/debug-STL overhead, red-black-tree fields, allocator probing, + object-layout inspection, generic per-entry pointer overhead, or unordered bucket-table guesses. ## Key Paths diff --git a/.agents/languages/csharp.md b/.agents/languages/csharp.md index 8785b440e1..c29acd822b 100644 --- a/.agents/languages/csharp.md +++ b/.agents/languages/csharp.md @@ -12,8 +12,13 @@ Load this file when changing `csharp/` or C# xlang behavior. - Generated C# gRPC service companions are compiler-owned files that depend on application-provided gRPC packages, not `csharp/src/Fory`. Keep gRPC package references out of the Fory runtime package. - C# generated schema modules are source-file owners. Service companions must use that module's `ThreadSafeFory` and must not introduce namespace-owned aliases or duplicate serializer registration paths. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. -- Root deserialization container memory budget state belongs to `ReadContext`. C# public roots are memory-backed today, so auto uses known input length; generated serializers may call `ReadContext`'s generated-code reservation helpers, but should not expose or depend on serializer helper classes such as `CollectionCodec`. -- For C# container budget formulas, distinguish inline value storage from reference storage: use cheap value-type size for `List`/`T[]` value paths and the 4-byte reference fallback for reference paths. Dedicated string, binary, and primitive dense-array serializers stay skipped and rely on byte availability checks. +- Root deserialization container memory budget state belongs to `ReadContext`. C# public roots are memory-backed today, so auto uses known input length. `ReadContext` may expose only raw byte reservation and generic counted-byte arithmetic; concrete serializers and generated serializers must compute list/array/map byte formulas before calling it. +- For C# container budget formulas, distinguish inline value storage from reference storage: use + cheap value-type size for `List`/`T[]` value paths and the 4-byte reference fallback for + reference paths. Maps charge key plus value storage, linked/hash/tree conversions must not add + guessed node or entry overhead, and empty containers with no backing storage normally charge zero. + Dedicated string, binary, and primitive dense-array serializers stay skipped and rely on byte + availability checks. - When extending C# tests from Java references, prioritize xlang spec behavior and the public C# contract before adding complex Java-specific parity cases. ## Commands diff --git a/.agents/languages/dart.md b/.agents/languages/dart.md index 9db2504693..085cede2a1 100644 --- a/.agents/languages/dart.md +++ b/.agents/languages/dart.md @@ -14,7 +14,7 @@ Load this file when changing `dart/`. - Keep root numeric wrapper defaults separate from generated field metadata. Root wrapper resolution belongs in the builtin resolver, while annotations and generated metadata choose fixed, tagged, or declared-field encodings. - Dart 64-bit carriers are optimized for each platform. Do not replace native extension-type wrappers with allocation-heavy classes or route web/native hot paths through `BigInt` unless the user approves a representation change. - In `Buffer`, cursor, serializer, and generated-code hot paths, prefer direct byte/local integer operations and conditional import/export files over callbacks, records, holder objects, wrapper round-trips, or runtime platform branches. -- Root deserialization container memory budgets are owned by `ReadContext`; `maxContainerMemoryBytes` defaults to `-1 / auto`, positive explicit values win, and Dart auto uses `buffer.readableBytes * 8 + 64 KiB` because roots are memory-backed. Charge Dart lists, sets, maps, object/reference arrays, compatible list-to-array inline storage, and compatible array-to-list materialization before allocation. Skip only dedicated string, binary, `BoolList`, and typed-array dense owner paths with byte checks. Do not add stream bytes-read accounting, per-element accounting, or extra hot-path allocations for this budget. +- Root deserialization container memory budgets are owned by `ReadContext`; `maxContainerMemoryBytes` defaults to `-1 / auto`, positive explicit values win, and Dart auto uses `buffer.readableBytes * 8 + 64 KiB` because roots are memory-backed. `ReadContext` may expose only raw byte reservation and generic counted-byte arithmetic; list/set/map/object-array formulas belong in serializer owners. Charge Dart list/set/object-array reference slots, map key/value slots, compatible list-to-array inline storage, and compatible array-to-list materialization before allocation. Empty containers with no backing storage normally charge zero. Skip only dedicated string, binary, `BoolList`, and typed-array dense owner paths with byte checks. Do not add stream bytes-read accounting, per-element accounting, or extra hot-path allocations for this budget. - Do not add parallel header-low/header-high slot caches or multi-slot recent caches in TypeMeta hot paths to chase benchmark gaps. Header-cache hits must use the concrete checked cache owner directly; if a hit hint is needed, cache one TypeInfo/TypeMeta object and compare the validated header identity on that object, not separate low/high header fields or benchmark-pattern state. - If Dart TypeMeta cache ownership changes, keep the invariant in a source comment near the hit path: a checked metadata-cache hit skips the body and must not grow low-bit sentinels, accepted-header fields, parallel header slots, or benchmark-pattern state. - Dart expected-type TypeDef reads should compare the expected `TypeInfo` object's cached local TypeDef header before consulting the parsed-metadata map. A match is a direct local-schema hit: skip the remote body, add the expected type to the per-read shared type table, and do not publish to `ParsedTypeMetaCache`, record a remote schema version, or parse/hash the body. diff --git a/.agents/languages/go.md b/.agents/languages/go.md index 949dd5030e..e0aa050d14 100644 --- a/.agents/languages/go.md +++ b/.agents/languages/go.md @@ -10,11 +10,15 @@ Load this file when changing `go/fory/` or Go xlang behavior. - Root deserialization container memory budgets are owned by `ReadContext`. `WithMaxContainerMemoryBytes` defaults to `-1 / auto`; byte-slice roots use `inputBytes * 8 + 64 KiB`, and `DeserializeFromReader`/`DeserializeFromStream` - use fixed `128 MiB`. Charge Go slices, maps, map-backed sets, LIST-encoded - inline/value slices, and generated container reads before allocation. Fixed - arrays are caller-owned and normally not charged; `arrayDynSerializer` charges - its temporary slice. Skip only dedicated string, binary, BufferObject, - primitive ARRAY slice, and primitive array owners with byte checks. + use fixed `128 MiB`. `ReadContext` may expose only raw byte reservation and + generic counted-byte arithmetic; slice/map formulas belong in handwritten or + generated serializer owners. Charge Go slices as `len * elemBytes`, maps as + `len * (keyBytes + valueBytes)`, map-backed sets, LIST-encoded inline/value + slices, and generated container reads before allocation. Empty containers with + no backing storage normally charge zero. Fixed arrays are caller-owned and + normally not charged; `arrayDynSerializer` charges its temporary slice. Skip + only dedicated string, binary, BufferObject, primitive ARRAY slice, and + primitive array owners with byte checks. - Set `FORY_PANIC_ON_ERROR=1` when debugging a failing Go test so you get the full call stack. - Do not set `FORY_PANIC_ON_ERROR=1` when running the full Go test suite, because some tests assert on error contents. diff --git a/.agents/languages/java.md b/.agents/languages/java.md index f60cfb249b..0e84f65453 100644 --- a/.agents/languages/java.md +++ b/.agents/languages/java.md @@ -18,11 +18,15 @@ Load this file when changing anything under `java/` or when Java drives a cross- and is initialized by `Fory` root APIs. Public config is `maxContainerMemoryBytes` with `-1` auto, positive explicit override, known-length auto `inputBytes * 8 + 64 KiB`, and stream/unknown auto - `128 MiB`. Collection/map/object-array serializers should charge estimated - container-owned memory before allocation while preserving existing - `checkReadableBytes` guards before backing allocation or capacity - reservation. Do not add nested serializer-path `try/finally`, per-element - work, or dynamic stream bytes-read accounting for this budget. + `128 MiB`. `ReadContext` may expose only raw byte reservation and generic + counted-byte arithmetic; collection/map/object-array formulas belong in the + concrete serializer owner. Java collection/object-array paths charge reference slots only, and + maps charge two reference slots per entry. Fixed/header, map table, and map + entry overhead are not charged unless a future owner documents a conservative + independent lower-bound signal. Preserve existing `checkReadableBytes` guards + before backing allocation or capacity reservation. Do not add nested + serializer-path `try/finally`, per-element work, or dynamic stream bytes-read + accounting for this budget. - Generated serializers must not retain runtime context fields. `Fory` should stay a root-operation facade rather than accumulating serializer or convenience state. - When the serializer class and constructor shape are known at the call site, prefer direct constructor lambdas or direct instantiation over reflective `Serializers.newSerializer(...)`. - For GraalVM, use `fory codegen` to generate serializers when building native images. Do not add reflection configuration except for JDK `proxy`. diff --git a/.agents/languages/javascript.md b/.agents/languages/javascript.md index 4781b5ece2..afdc8519c6 100644 --- a/.agents/languages/javascript.md +++ b/.agents/languages/javascript.md @@ -16,11 +16,14 @@ Load this file when changing `javascript/`. - Normalize optional boolean config values at config construction; do not carry `null` through runtime paths when it means `false`. - JavaScript root deserialization container memory budgeting belongs to `ReadContext`. `maxContainerMemoryBytes` uses `-1` auto, positive explicit limits, and known - `Uint8Array` root length as `inputBytes * 8 + 64 KiB`. Generated and dynamic + `Uint8Array` root length as `inputBytes * 8 + 64 KiB`. `ReadContext` may expose only raw + byte reservation and generic counted-byte arithmetic; generated and dynamic list/set/map readers must reserve before allocation while preserving existing - byte checks. Keep dedicated string, binary, and dense typed-array owners out of - this budget; compatible list-to-typed-array reads must charge typed inline - storage. + byte checks. Lists/sets/object arrays charge 4-byte reference slots, maps charge + two 4-byte references per entry, and empty containers with no backing storage + normally charge zero. Keep dedicated string, binary, and dense typed-array + owners out of this budget; compatible list-to-typed-array reads must charge + typed inline storage. - Regenerated compatible read serializers are remote-schema-specific. After classification marks a field as direct, compatible scalar, or skip, generated JavaScript should emit straight-line remote-field-order code. Do not add an outer matched-id switch unless the current regenerated shape cannot preserve those semantics. - Compatible scalar codegen must decide the exact remote/local scalar pair before emitting source. Generate the concrete `reader.readXxx()` call plus inline trivial conversions such as boolean-to-string or numeric widening, and keep helpers only for semantic validation such as range checks, exactness checks, decimal parsing/formatting, and string-to-bool. Do not call a generic hot-path converter that redispatches on `remoteTypeId`, `localTypeId`, field descriptors, or field names. - Compatible scalar conversion is immediate-field-only. Recursive schema comparison for collection elements, array elements, map keys, and map values must reject scalar mismatches instead of applying the top-level scalar conversion matrix. diff --git a/.agents/languages/python.md b/.agents/languages/python.md index 3ed69c6eb7..0468193825 100644 --- a/.agents/languages/python.md +++ b/.agents/languages/python.md @@ -15,10 +15,13 @@ Load this file when changing `python/`, Cython serialization, or Python xlang be - Keep Python and Cython context/ref-tracking branch conditions and stack mutations semantically aligned unless a documented intentional difference exists. - Root deserialization container memory budgets are owned by pure-Python and Cython `ReadContext`. Keep `max_container_memory_bytes` public on `pyfory.Fory`/`Config`; `-1` uses known-length - `inputBytes * 8 + 64 KiB` or fixed `128 MiB` for stream roots. Reserve fixed container cost plus - reference slots for list/tuple/set, map object/table/entry estimates for dict, and object-dtype - ndarray item storage. Keep string, bytes, `array.array`, primitive dense array, and primitive - ndarray owners skipped, and preserve byte-availability checks after budget reservation. + `inputBytes * 8 + 64 KiB` or fixed `128 MiB` for stream roots. `ReadContext` may expose only raw + byte reservation and generic counted-byte arithmetic; collection and dict formulas belong in the + pure-Python or Cython serializer owner. Lists, tuples, sets, and + object-dtype ndarray item storage charge `count * PyObject*`; dicts charge + `entryCount * 2 * PyObject*`. Fixed/header cost defaults to zero unless a path documents an + independent lower-bound owner. Keep string, bytes, `array.array`, primitive dense array, and + primitive ndarray owners skipped, and preserve byte-availability checks after budget reservation. - Public value constructors should accept normal Python values. Raw-bit, raw-buffer, and memoryview entry points should be explicit low-level APIs, and packed carriers should expose the buffer protocol from the actual storage owner when appropriate. - When debugging runtime or benchmark behavior, install the local package into the exact interpreter under test instead of relying on mixed `PYTHONPATH` state. - For wheel or extension pipeline changes, derive extension-module paths from current build targets, packaging config, or wheel payload discovery rather than historical module names. diff --git a/.agents/languages/rust.md b/.agents/languages/rust.md index a31126586c..1f19e27d4f 100644 --- a/.agents/languages/rust.md +++ b/.agents/languages/rust.md @@ -21,13 +21,16 @@ Load this file when changing `rust/` or Rust xlang behavior. - Root deserialization container memory budget state belongs to `ReadContext` and is initialized by the root `Fory` read methods before the header is consumed. Rust roots are byte-slice/`Reader` backed, so auto budget uses `inputBytes * 8 + 64 KiB`; do not add dynamic bytes-read accounting. -- Rust `Vec` stores inline element storage, so general LIST paths charge fixed `Vec` cost plus - `len * size_of::()`, including `Vec` and `Vec`. Dedicated primitive dense - ARRAY `Vec` readers, strings, binary, and primitive fixed-array owners stay skipped and keep - their byte checks. + `ReadContext` may expose only raw byte reservation and generic counted-byte arithmetic; `Vec`, + collection, map, and derive codec formulas belong in their serializer owners. +- Rust `Vec` stores inline element storage, so general LIST paths charge + `len * size_of::()`, including `Vec` and `Vec`. Maps charge + `len * (size_of::() + size_of::())`. Dedicated primitive dense ARRAY `Vec` readers, + strings, binary, and primitive fixed-array owners stay skipped and keep their byte checks. - Direct `Serializer` collection/map paths and derive `Codec` collection/map paths are separate allocation owners. Keep reservations in both before `Vec::with_capacity`, - `HashMap::with_capacity`, or collection materialization; charge zero-size containers. + `HashMap::with_capacity`, or collection materialization. Empty containers with no dynamic backing + normally charge zero. ## Key Paths diff --git a/.agents/languages/swift.md b/.agents/languages/swift.md index 0e2607ac59..440f89384d 100644 --- a/.agents/languages/swift.md +++ b/.agents/languages/swift.md @@ -15,8 +15,13 @@ Load this file when changing `swift/` or Swift xlang behavior. - Preserve distinct temporal semantics. Timestamp values and day-only local dates should have protocol-accurate helper names and no stale aliases after a refactor. - When temporal or public-type refactors touch generated Swift code, sweep message fields, union payloads, macros, xlang harnesses, and integration fixtures together. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. -- Root deserialization container memory budget state belongs to `ReadContext`. Swift public roots are `Data` and `ByteBuffer`, so auto uses known root bytes; do not add stream bytes-read accounting or serializer-local budget state. -- For Swift container budget formulas, distinguish inline/value storage from reference storage: use `MemoryLayout.stride` for value arrays/lists/maps and the 4-byte reference fallback for `Serializer.isRefType` / `FieldCodec.isRefType` paths. Dedicated `String`, `Data`/binary, and primitive packed-array owners stay skipped, except compatible packed-array-to-list reads must charge the target list materialization before allocation. +- Root deserialization container memory budget state belongs to `ReadContext`. Swift public roots are `Data` and `ByteBuffer`, so auto uses known root bytes; do not add stream bytes-read accounting or serializer-local budget state. `ReadContext` may expose only raw byte reservation and generic counted-byte arithmetic; array/set/map formulas belong in serializer and field-codec owners. +- For Swift container budget formulas, distinguish inline/value storage from reference storage: use + `MemoryLayout.stride` for value arrays/lists/sets/maps and the 4-byte reference fallback for + `Serializer.isRefType` / `FieldCodec.isRefType` paths. Maps charge key plus value storage, and + empty containers with no backing storage normally charge zero. Dedicated `String`, `Data`/binary, + and primitive packed-array owners stay skipped, except compatible packed-array-to-list reads must + charge the target list materialization before allocation. ## Commands diff --git a/AGENTS.md b/AGENTS.md index be3962a6fa..51663f9f77 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -33,7 +33,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Check the spec before implementation. For wire behavior and xlang mapping, use the specs as the source of truth and never copy one runtime's bug into another runtime just to make tests pass. - Do not make assumptions about runtime behavior, ownership, registration, metadata construction, protocol semantics, or test coverage. Read the current code, owning docs/specs, and relevant tests before making a design judgment or implementation decision. If the evidence is incomplete, inspect more or state the uncertainty explicitly instead of filling gaps from memory or analogy with another runtime. - For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve backing/output capacity from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. Container memory-budget reservation is accounting only and may happen before that byte check, but it must not replace the byte check. -- Root deserialization container memory budgets are estimated container-owned memory, not exact heap accounting and not raw slots. Positive `maxContainerMemoryBytes` wins; `-1` auto uses `inputBytes * 8 + 64 KiB` for known-length roots and fixed `128 MiB` for true stream or unknown-length roots. Charge fixed container cost, backing/reference/inline storage, map table and entry overhead where the runtime has cheap reliable signals, and zero-size containers; skip only dedicated string, binary, primitive array, and primitive dense-array owners. Native runtimes such as C++ must prefer portable lower-bound formulas over non-portable STL or allocator layout guesses. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. +- Root deserialization container memory budgets estimate lower-bound container-owned storage, not exact heap accounting and not raw slots. Positive `maxContainerMemoryBytes` wins; `-1` auto uses `inputBytes * 8 + 64 KiB` for known-length roots and fixed `128 MiB` for true stream or unknown-length roots. Read context/read state owns only raw byte accounting plus generic counted-byte arithmetic such as `reserveContainerMemory(bytes)` or `reserveCountedContainerMemory(count, elementBytes)`; it must not expose collection/map/array semantic reservation APIs. Concrete serializers and generated serializers own the formulas: reference-backed containers/object arrays charge reference slots, inline/value containers charge element storage, reference-backed maps charge two references per entry, and inline/value maps charge key plus value storage. Fixed/header cost defaults to zero and is charged only for documented independent lower-bound storage not already covered by parent inline/value storage; empty containers without dynamic backing normally charge zero. Skip only dedicated string, binary, primitive array, and primitive dense-array owners. Do not guess table, bucket, node, entry, object-header, array-header, allocator, or debug-layout overhead. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. - For remote TypeDef/TypeMeta reads, the checked metadata cache is the only owner of remote "already validated" state. Cache hit means the header was previously parsed, body/hash-validated, policy-checked, and published by that cache, so the hot path must skip the body and use cached metadata without extra validation, hashing, limit checks, exact-local checks, allocation, or policy work. A known expected local TypeDef/TypeMeta header/hash match is a local-schema hit, not a remote cache miss: it may skip the body and use the local TypeInfo/TypeMeta without schema-version counting or cache publish. Cache miss is the only path that parses and validates non-local metadata, enforces limits, performs exact-local byte comparison when needed, and publishes remote metadata to the cache. Do not add nullable accepted-header fields, sentinel headers, per-TypeInfo markers, pending metadata state, parallel header-low/header-high slots, or parallel acceptance state for this decision. If a runtime needs a metadata hit hint, cache the concrete checked metadata owner object, such as the TypeInfo, TypeDef, or TypeMeta used by that runtime, and compare its validated header identity directly. - When a user corrects a non-obvious invariant, encode it in the nearest source comment before continuing, and also update `AGENTS.md`, `.agents/**`, docs, or specs when the rule is reusable beyond one file. Do not rely only on chat history, task notes, commit messages, or benchmark logs for corrections that protect security, protocol behavior, ownership, naming, or hot-path performance. - Reject semantic hacks. Do not bypass broken semantics by deleting cases, simplifying callers, adding coercion hooks, or using workaround fallbacks; fix the underlying bug and prove it with focused tests. diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index 5aca15f728..3449e1d6fe 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -398,10 +398,10 @@ inline bool reserve_collection(Container &result, ReadContext &ctx, if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } - constexpr size_t fixed_bytes = sizeof(Container); constexpr size_t elem_bytes = collection_element_memory_bytes(); - if (FORY_PREDICT_FALSE((!ctx.template reserve_counted_container_memory< - fixed_bytes, elem_bytes>(length)))) { + if (FORY_PREDICT_FALSE( + (!ctx.template reserve_counted_container_memory( + length)))) { return false; } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { @@ -419,18 +419,9 @@ inline bool reserve_collection(std::vector &result, if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } - constexpr size_t fixed_bytes = sizeof(std::vector); - constexpr size_t max_packed_bytes = - (static_cast(std::numeric_limits::max()) + CHAR_BIT - - 1) / - CHAR_BIT; - static_assert(fixed_bytes <= - std::numeric_limits::max() - max_packed_bytes, - "vector memory estimate overflows"); const size_t packed_bytes = (static_cast(length) + CHAR_BIT - 1) / CHAR_BIT; - if (FORY_PREDICT_FALSE( - !ctx.reserve_container_memory(fixed_bytes + packed_bytes))) { + if (FORY_PREDICT_FALSE(!ctx.reserve_container_memory(packed_bytes))) { return false; } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { @@ -445,7 +436,7 @@ inline bool reserve_empty_collection(ReadContext &ctx) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } - return ctx.reserve_container_memory(sizeof(Container)); + return ctx.reserve_container_memory(0); } // Helper to insert element into container (vector or set) diff --git a/cpp/fory/serialization/container_memory_budget_test.cc b/cpp/fory/serialization/container_memory_budget_test.cc index 296f74027a..e255a98a26 100644 --- a/cpp/fory/serialization/container_memory_budget_test.cc +++ b/cpp/fory/serialization/container_memory_budget_test.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -95,8 +96,7 @@ template std::vector serialize_value(const T &value) { size_t nested_empty_budget(size_t count) { using Inner = std::vector; - using Outer = std::vector; - return sizeof(Outer) + count * sizeof(Inner) + count * sizeof(Inner); + return count * sizeof(Inner); } template @@ -118,32 +118,16 @@ void expect_budget_boundary(const T &value, size_t required) { } TEST(ContainerMemoryBudgetTest, KnownLengthAutoBudget) { - constexpr size_t count = 3000; - std::vector> value(count); - auto bytes = serialize_value(value); - const size_t auto_limit = bytes.size() * 8 + kKnownBudgetSlack; - const size_t required = nested_empty_budget(count); - ASSERT_GT(required, auto_limit); - - auto default_result = with_fory(-1, [&](Fory &fory) { - return fory.deserialize>>(bytes); - }); - ASSERT_FALSE(default_result.ok()); - EXPECT_EQ(default_result.error().code(), ErrorCode::InvalidData); - - auto explicit_auto_result = - with_fory(static_cast(auto_limit), [&](Fory &fory) { - return fory.deserialize>>(bytes); - }); - ASSERT_FALSE(explicit_auto_result.ok()); - EXPECT_EQ(explicit_auto_result.error().code(), ErrorCode::InvalidData); - - auto explicit_result = - with_fory(static_cast(required), [&](Fory &fory) { - return fory.deserialize>>(bytes); - }); - ASSERT_TRUE(explicit_result.ok()) << explicit_result.error().to_string(); - EXPECT_EQ(explicit_result.value(), value); + Config config; + config.max_container_memory_bytes = -1; + ReadContext context(config, std::make_unique()); + constexpr size_t root_bytes = 17; + const size_t expected = root_bytes * 8 + kKnownBudgetSlack; + + ASSERT_TRUE(context.init_container_budget_known(root_bytes)); + ASSERT_TRUE(context.reserve_container_memory(expected)); + ASSERT_FALSE(context.reserve_container_memory(1)); + EXPECT_EQ(context.take_error().code(), ErrorCode::InvalidData); } TEST(ContainerMemoryBudgetTest, StreamAutoBudget) { @@ -172,8 +156,7 @@ TEST(ContainerMemoryBudgetTest, StreamAutoBudget) { TEST(ContainerMemoryBudgetTest, ExplicitOverride) { std::vector value(8); auto bytes = serialize_value(value); - const size_t required = - sizeof(std::vector) + value.size() * sizeof(BudgetItem); + const size_t required = value.size() * sizeof(BudgetItem); auto small_result = with_fory(static_cast(required - 1), [&](Fory &fory) { @@ -190,10 +173,10 @@ TEST(ContainerMemoryBudgetTest, ExplicitOverride) { EXPECT_EQ(exact_result.value(), value); } -TEST(ContainerMemoryBudgetTest, EmptyContainersChargeFixedCost) { +TEST(ContainerMemoryBudgetTest, NestedEmptyContainersUseParentStorage) { std::vector> value(1); auto bytes = serialize_value(value); - const size_t required = nested_empty_budget(1); + const size_t required = sizeof(std::vector); auto small_result = with_fory(static_cast(required - 1), [&](Fory &fory) { @@ -215,8 +198,7 @@ TEST(ContainerMemoryBudgetTest, SiblingCumulativeBudget) { value.left.resize(16); value.right.resize(16); auto bytes = serialize_value(value); - const size_t one_vector = - sizeof(std::vector) + value.left.size() * sizeof(BudgetItem); + const size_t one_vector = value.left.size() * sizeof(BudgetItem); auto small_result = with_fory(static_cast(one_vector), [&](Fory &fory) { @@ -236,24 +218,20 @@ TEST(ContainerMemoryBudgetTest, SiblingCumulativeBudget) { TEST(ContainerMemoryBudgetTest, MapBudget) { std::map value{{"a", 1}, {"b", 2}, {"c", 3}}; const size_t entry_bytes = sizeof(std::string) + sizeof(int32_t); - const size_t required = sizeof(value) + value.size() * entry_bytes; + const size_t required = value.size() * entry_bytes; expect_budget_boundary(value, required); } TEST(ContainerMemoryBudgetTest, CollectionLowerBounds) { std::deque deque_value(4); - expect_budget_boundary(deque_value, - sizeof(deque_value) + - deque_value.size() * sizeof(BudgetItem)); + expect_budget_boundary(deque_value, deque_value.size() * sizeof(BudgetItem)); std::list list_value(4); - expect_budget_boundary( - list_value, sizeof(list_value) + list_value.size() * sizeof(BudgetItem)); + expect_budget_boundary(list_value, list_value.size() * sizeof(BudgetItem)); std::forward_list forward_value(4); - expect_budget_boundary(forward_value, sizeof(forward_value) + - size_t{4} * sizeof(BudgetItem)); + expect_budget_boundary(forward_value, size_t{4} * sizeof(BudgetItem)); } TEST(ContainerMemoryBudgetTest, VectorBoolUsesPackedStorage) { @@ -261,32 +239,28 @@ TEST(ContainerMemoryBudgetTest, VectorBoolUsesPackedStorage) { value[0] = true; value[32] = true; const size_t packed_bytes = (value.size() + CHAR_BIT - 1) / CHAR_BIT; - const size_t required = sizeof(value) + packed_bytes; - ASSERT_LT(required, sizeof(value) + value.size()); + const size_t required = packed_bytes; + ASSERT_LT(required, value.size()); expect_budget_boundary(value, required); } TEST(ContainerMemoryBudgetTest, OrderedSetAndMapLowerBounds) { std::set set_value{1, 2, 3, 4}; - expect_budget_boundary(set_value, sizeof(set_value) + - set_value.size() * sizeof(int32_t)); + expect_budget_boundary(set_value, set_value.size() * sizeof(int32_t)); std::map map_value{{"a", 1}, {"b", 2}}; expect_budget_boundary( - map_value, sizeof(map_value) + map_value.size() * (sizeof(std::string) + - sizeof(int32_t))); + map_value, map_value.size() * (sizeof(std::string) + sizeof(int32_t))); } TEST(ContainerMemoryBudgetTest, UnorderedContainersLowerBounds) { std::unordered_set set_value{1, 2, 3, 4}; - expect_budget_boundary(set_value, sizeof(set_value) + - set_value.size() * sizeof(int32_t)); + expect_budget_boundary(set_value, set_value.size() * sizeof(int32_t)); std::unordered_map map_value{{"a", 1}, {"b", 2}}; expect_budget_boundary( - map_value, sizeof(map_value) + map_value.size() * (sizeof(std::string) + - sizeof(int32_t))); + map_value, map_value.size() * (sizeof(std::string) + sizeof(int32_t))); } TEST(ContainerMemoryBudgetTest, ArrayHasNoStandaloneReservation) { @@ -303,8 +277,7 @@ TEST(ContainerMemoryBudgetTest, FixedInlineOwnerChargesNestedVector) { BudgetFixedArrayOwner value; value.prefix = {{1, 2, 3, 4}}; value.items.resize(3); - const size_t required = - sizeof(value.items) + value.items.size() * sizeof(BudgetItem); + const size_t required = value.items.size() * sizeof(BudgetItem); expect_budget_boundary(value, required); } diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index 686db558ad..132847c99b 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -740,17 +740,13 @@ const TypeInfo *ReadContext::read_any_type_info(Error &error) { } bool ReadContext::reserve_counted_container_checked(uint32_t length, - size_t fixed_bytes, size_t elem_bytes) { - if (FORY_PREDICT_FALSE( - elem_bytes != 0 && - static_cast(length) > - (std::numeric_limits::max() - fixed_bytes) / - elem_bytes)) { + if (FORY_PREDICT_FALSE(elem_bytes != 0 && + static_cast(length) > + std::numeric_limits::max() / elem_bytes)) { return set_container_memory_overflow(length, elem_bytes); } - return reserve_container_memory(static_cast(length) * elem_bytes + - fixed_bytes); + return reserve_container_memory(static_cast(length) * elem_bytes); } bool ReadContext::set_container_memory_error(const std::string &message) { diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index 5d2bbc3c60..09d321710d 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -566,34 +566,15 @@ class ReadContext { return true; } - FORY_ALWAYS_INLINE bool reserve_counted_container_memory(uint32_t length, - size_t fixed_bytes, - size_t elem_bytes) { - if (length == 0) { - return reserve_container_memory(fixed_bytes); - } - constexpr size_t kMaxLength = - static_cast(std::numeric_limits::max()); - if (FORY_PREDICT_TRUE(elem_bytes <= - (std::numeric_limits::max() - fixed_bytes) / - kMaxLength)) { - return reserve_container_memory(static_cast(length) * elem_bytes + - fixed_bytes); - } - return reserve_counted_container_checked(length, fixed_bytes, elem_bytes); - } - - template + template FORY_ALWAYS_INLINE bool reserve_counted_container_memory(uint32_t length) { constexpr size_t kMaxLength = static_cast(std::numeric_limits::max()); if constexpr (elem_bytes <= - (std::numeric_limits::max() - fixed_bytes) / - kMaxLength) { - return reserve_container_memory(static_cast(length) * elem_bytes + - fixed_bytes); + std::numeric_limits::max() / kMaxLength) { + return reserve_container_memory(static_cast(length) * elem_bytes); } else { - return reserve_counted_container_checked(length, fixed_bytes, elem_bytes); + return reserve_counted_container_checked(length, elem_bytes); } } @@ -761,7 +742,6 @@ class ReadContext { check_remote_type_meta_limit(const TypeMeta &type_meta); void record_remote_type_meta(const std::string &type_key); FORY_NOINLINE bool reserve_counted_container_checked(uint32_t length, - size_t fixed_bytes, size_t elem_bytes); FORY_NOINLINE bool set_container_memory_error(const std::string &message); FORY_NOINLINE bool set_container_memory_overflow(uint32_t length, diff --git a/cpp/fory/serialization/map_serializer.h b/cpp/fory/serialization/map_serializer.h index a654dae83b..c07a5f0a44 100644 --- a/cpp/fory/serialization/map_serializer.h +++ b/cpp/fory/serialization/map_serializer.h @@ -96,10 +96,10 @@ inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { static_assert(sizeof(Key) <= std::numeric_limits::max() - sizeof(Value), "map entry memory estimate overflows"); - constexpr size_t fixed_bytes = sizeof(MapType); constexpr size_t elem_bytes = sizeof(Key) + sizeof(Value); - if (FORY_PREDICT_FALSE((!ctx.template reserve_counted_container_memory< - fixed_bytes, elem_bytes>(length)))) { + if (FORY_PREDICT_FALSE( + (!ctx.template reserve_counted_container_memory( + length)))) { return false; } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { @@ -113,7 +113,7 @@ template inline bool reserve_empty_map(ReadContext &ctx) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } - return ctx.reserve_container_memory(sizeof(MapType)); + return ctx.reserve_container_memory(0); } /// write chunk size at header offset diff --git a/csharp/src/Fory.Generator/ForyModelGenerator.cs b/csharp/src/Fory.Generator/ForyModelGenerator.cs index 50c4682515..de535233c2 100644 --- a/csharp/src/Fory.Generator/ForyModelGenerator.cs +++ b/csharp/src/Fory.Generator/ForyModelGenerator.cs @@ -1161,14 +1161,16 @@ private static void EmitReadCompatibleListArrayPayload( sb.AppendLine($"{indent}}}"); string elementTypeName = codec.CarrierKind == CarrierKind.Array ? ElementTypeName(codec.TypeName) : PackedArrayElementTypeName(codec.TypeId); uint elementTypeId = PackedArrayElementTypeId(codec.TypeId); + string elementBytesExpr = + $"(typeof({elementTypeName}).IsValueType ? global::System.Runtime.CompilerServices.Unsafe.SizeOf<{elementTypeName}>() : 4)"; if (codec.CarrierKind == CarrierKind.Array) { - sb.AppendLine($"{indent}context.ReserveArrayMemory<{elementTypeName}>({lengthVar});"); + sb.AppendLine($"{indent}context.ReserveCountedContainerMemory({lengthVar}, {elementBytesExpr});"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new {ElementTypeName(codec.TypeName)}[{lengthVar}];"); } else { - sb.AppendLine($"{indent}context.ReserveListMemory<{elementTypeName}>({lengthVar});"); + sb.AppendLine($"{indent}context.ReserveCountedContainerMemory({lengthVar}, {elementBytesExpr});"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new({lengthVar});"); } diff --git a/csharp/src/Fory/CollectionSerializers.cs b/csharp/src/Fory/CollectionSerializers.cs index 2e5a610d00..d9baf94203 100644 --- a/csharp/src/Fory/CollectionSerializers.cs +++ b/csharp/src/Fory/CollectionSerializers.cs @@ -17,6 +17,7 @@ using System.Collections; using System.Collections.Immutable; +using System.Runtime.CompilerServices; namespace Apache.Fory; @@ -31,6 +32,17 @@ internal static class CollectionBits internal static class CollectionCodec { + private const int ReferenceBytes = 4; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static int ElementBytes() => typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static void ReserveElementStorage(ReadContext context, int count) + { + context.ReserveCountedContainerMemory(count, ElementBytes()); + } + private static bool NeedsCompatibleElementTypeMeta(TypeInfo typeInfo, WriteContext context) { return context.Compatible && @@ -201,7 +213,7 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea int length = checked((int)context.Reader.ReadVarUInt32()); if (length == 0) { - context.ReserveListMemory(length); + ReserveElementStorage(context, length); return []; } @@ -214,7 +226,7 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea bool hasNull = (header & CollectionBits.HasNull) != 0; bool declared = (header & CollectionBits.DeclaredElementType) != 0; bool sameType = (header & CollectionBits.SameType) != 0; - context.ReserveNonEmptyListMemory(length); + ReserveElementStorage(context, length); context.Reader.CheckBound(length); List values = new(length); if (!sameType) @@ -524,7 +536,7 @@ public override void WriteData(WriteContext context, in T[] value, bool hasGener public override T[] ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); - context.ReserveArrayMemory(values.Count); + CollectionCodec.ReserveElementStorage(context, values.Count); return values.ToArray(); } } @@ -558,7 +570,7 @@ public override void WriteData(WriteContext context, in HashSet value, bool h public override HashSet ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); - context.ReserveLinkedCollectionMemory(values.Count); + CollectionCodec.ReserveElementStorage(context, values.Count); return [.. values]; } } @@ -576,7 +588,7 @@ public override void WriteData(WriteContext context, in SortedSet value, bool public override SortedSet ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); - context.ReserveLinkedCollectionMemory(values.Count); + CollectionCodec.ReserveElementStorage(context, values.Count); return [.. values]; } } @@ -594,7 +606,7 @@ public override void WriteData(WriteContext context, in ImmutableHashSet valu public override ImmutableHashSet ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); - context.ReserveLinkedCollectionMemory(values.Count); + CollectionCodec.ReserveElementStorage(context, values.Count); return ImmutableHashSet.CreateRange(values); } } @@ -612,7 +624,7 @@ public override void WriteData(WriteContext context, in LinkedList value, boo public override LinkedList ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); - context.ReserveLinkedCollectionMemory(values.Count); + CollectionCodec.ReserveElementStorage(context, values.Count); return new LinkedList(values); } } @@ -630,7 +642,7 @@ public override void WriteData(WriteContext context, in Queue value, bool has public override Queue ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); - context.ReserveLinkedCollectionMemory(values.Count); + CollectionCodec.ReserveElementStorage(context, values.Count); Queue queue = new(values.Count); for (int i = 0; i < values.Count; i++) { @@ -667,7 +679,7 @@ public override void WriteData(WriteContext context, in Stack value, bool has public override Stack ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); - context.ReserveLinkedCollectionMemory(values.Count); + CollectionCodec.ReserveElementStorage(context, values.Count); Stack stack = new(values.Count); for (int i = 0; i < values.Count; i++) { diff --git a/csharp/src/Fory/DictionarySerializers.cs b/csharp/src/Fory/DictionarySerializers.cs index bdcec3222a..444c2c6fbd 100644 --- a/csharp/src/Fory/DictionarySerializers.cs +++ b/csharp/src/Fory/DictionarySerializers.cs @@ -16,6 +16,7 @@ // under the License. using System.Collections.Concurrent; +using System.Runtime.CompilerServices; namespace Apache.Fory; @@ -33,6 +34,18 @@ public abstract class DictionaryLikeSerializer : Seri where TDictionary : class, IDictionary where TKey : notnull { + private const int ReferenceBytes = 4; + private static readonly long MapElementBytes = (long)ElementBytes() + ElementBytes(); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int ElementBytes() => typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void ReserveMapStorage(ReadContext context, int count) + { + context.ReserveCountedContainerMemory(count, MapElementBytes); + } + public override TDictionary DefaultValue => null!; protected abstract TDictionary CreateMap(int capacity); @@ -214,11 +227,11 @@ public override TDictionary ReadData(ReadContext context) int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { - context.ReserveMapMemory(totalLength); + ReserveMapStorage(context, totalLength); return CreateMap(0); } - context.ReserveNonEmptyMapMemory(totalLength); + ReserveMapStorage(context, totalLength); context.Reader.CheckBound(totalLength); TDictionary map = CreateMap(totalLength); bool keyDynamicType = keyTypeInfo.IsDynamicType; diff --git a/csharp/src/Fory/NullableKeyDictionary.cs b/csharp/src/Fory/NullableKeyDictionary.cs index fe573cae02..ae08f775a0 100644 --- a/csharp/src/Fory/NullableKeyDictionary.cs +++ b/csharp/src/Fory/NullableKeyDictionary.cs @@ -16,6 +16,7 @@ // under the License. using System.Collections; +using System.Runtime.CompilerServices; namespace Apache.Fory; @@ -390,6 +391,18 @@ IEnumerator IEnumerable.GetEnumerator() public sealed class NullableKeyDictionarySerializer : Serializer> { + private const int ReferenceBytes = 4; + private static readonly long MapElementBytes = (long)ElementBytes() + ElementBytes(); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int ElementBytes() => typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void ReserveMapStorage(ReadContext context, int count) + { + context.ReserveCountedContainerMemory(count, MapElementBytes); + } + public override NullableKeyDictionary DefaultValue => null!; public override void WriteData(WriteContext context, in NullableKeyDictionary value, bool hasGenerics) @@ -537,11 +550,11 @@ public override NullableKeyDictionary ReadData(ReadContext context int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { - context.ReserveMapMemory(totalLength); + ReserveMapStorage(context, totalLength); return new NullableKeyDictionary(); } - context.ReserveNonEmptyMapMemory(totalLength); + ReserveMapStorage(context, totalLength); context.Reader.CheckBound(totalLength); NullableKeyDictionary map = new(totalLength); bool keyDynamicType = keyTypeInfo.IsDynamicType; diff --git a/csharp/src/Fory/PrimitiveDictionarySerializers.cs b/csharp/src/Fory/PrimitiveDictionarySerializers.cs index e753280f72..594abd6833 100644 --- a/csharp/src/Fory/PrimitiveDictionarySerializers.cs +++ b/csharp/src/Fory/PrimitiveDictionarySerializers.cs @@ -17,6 +17,7 @@ using System.Collections; using System.Collections.Concurrent; +using System.Runtime.CompilerServices; namespace Apache.Fory; @@ -663,6 +664,19 @@ public static void WriteMap() => typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void ReserveMapStorage(ReadContext context, int count) + { + context.ReserveCountedContainerMemory( + count, + (long)ElementBytes() + ElementBytes()); + } + public static TMap ReadMap(ReadContext context) where TKey : notnull where TKeyCodec : struct, IPrimitiveDictionaryCodec @@ -672,11 +686,11 @@ public static TMap ReadMap( int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { - context.ReserveMapMemory(totalLength); + ReserveMapStorage(context, totalLength); return TMapOps.Create(0); } - context.ReserveNonEmptyMapMemory(totalLength); + ReserveMapStorage(context, totalLength); context.Reader.CheckBound(totalLength); TMap map = TMapOps.Create(totalLength); TypeId keyTypeId = TKeyCodec.WireTypeId; diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index 31cc878714..865e2046dc 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -using System.ComponentModel; using System.Runtime.CompilerServices; namespace Apache.Fory; @@ -25,11 +24,6 @@ public sealed class ReadContext private const int MinRemoteTypeMetaLimit = 8192; internal const long KnownContainerBudgetSlackBytes = 64 * 1024; internal const long UnknownContainerBudgetBytes = 128L * 1024 * 1024; - internal const int ContainerFixedBytes = 32; - internal const int ArrayHeaderBytes = 24; - internal const int ReferenceBytes = 4; - internal const int CollectionEntryOverheadBytes = 16; - internal const int MapEntryOverheadBytes = 24; private readonly ReusableArray _typeMetaRefs = new(); private readonly UInt64Map _typeMetasByHeader = new(); @@ -82,90 +76,6 @@ public ReadContext( internal RefReader RefReader { get; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static int ElementBytes() => ContainerElementBytes.Value; - - private static class ContainerElementBytes - { - internal static readonly int Value = typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; - } - - private static class MapElementBytes - { - internal static readonly int Value = - ElementBytes() + ElementBytes() + MapEntryOverheadBytes + ReferenceBytes; - } - - /// - /// Reserves estimated list-owned memory for generated serializer code. - /// Configure instead of calling this directly. - /// - [EditorBrowsable(EditorBrowsableState.Never)] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void ReserveListMemory(int length) - { - if (length == 0) - { - ReserveContainerMemory(ContainerFixedBytes); - return; - } - - ReserveNonEmptyListMemory(length); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void ReserveNonEmptyListMemory(int length) - { - ReserveContainerMemory((long)(uint)length * ElementBytes() + ContainerFixedBytes + ArrayHeaderBytes); - } - - /// - /// Reserves estimated array-owned memory for generated serializer code. - /// Configure instead of calling this directly. - /// - [EditorBrowsable(EditorBrowsableState.Never)] - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void ReserveArrayMemory(int length) - { - ReserveCountedContainerMemory( - length, - ArrayHeaderBytes, - ElementBytes()); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void ReserveLinkedCollectionMemory(int length) - { - if (length == 0) - { - ReserveContainerMemory(ContainerFixedBytes); - return; - } - - ReserveContainerMemory( - (long)(uint)length * (ElementBytes() + CollectionEntryOverheadBytes + ReferenceBytes * 2) + - ContainerFixedBytes); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void ReserveMapMemory(int length) - { - if (length == 0) - { - ReserveContainerMemory(ContainerFixedBytes); - return; - } - - ReserveNonEmptyMapMemory(length); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void ReserveNonEmptyMapMemory(int length) - { - ReserveContainerMemory( - (long)(uint)length * MapElementBytes.Value + ContainerFixedBytes + ArrayHeaderBytes * 2); - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] internal void InitContainerBudgetKnown(int rootBytes) { @@ -192,9 +102,20 @@ internal void ReserveContainerMemory(long bytes) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void ReserveCountedContainerMemory(int count, int fixedBytes, int elementBytes) + internal void ReserveCountedContainerMemory(int count, long elementBytes) { - ReserveContainerMemory((long)(uint)count * elementBytes + fixedBytes); + if (count < 0 || elementBytes < 0) + { + ThrowContainerBudgetOverflow(); + } + + uint length = (uint)count; + if (elementBytes != 0 && length > long.MaxValue / elementBytes) + { + ThrowContainerBudgetOverflow(); + } + + ReserveContainerMemory((long)length * elementBytes); } [MethodImpl(MethodImplOptions.NoInlining)] diff --git a/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs index dabaa03b28..40894a379e 100644 --- a/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs +++ b/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs @@ -16,6 +16,7 @@ // under the License. using System.Buffers; +using System.Runtime.CompilerServices; using Apache.Fory; using ForyRuntime = Apache.Fory.Fory; @@ -43,6 +44,10 @@ public sealed class BudgetArrayHolder public sealed class ContainerMemoryBudgetTests { + private const int ReferenceBytes = 4; + + private static int ElementBytes() => typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + private static ForyRuntime NewFory(long maxContainerMemoryBytes = -1) { return ForyRuntime.Builder() @@ -62,51 +67,40 @@ private static byte[] Serialize(T value) private static long ListBudget(int count) { - return count == 0 - ? ReadContext.ContainerFixedBytes - : ReadContext.ContainerFixedBytes + ReadContext.ArrayHeaderBytes + - (long)count * ReadContext.ElementBytes(); + return (long)count * ElementBytes(); } private static long ArrayBudget(int count) { - return ReadContext.ArrayHeaderBytes + (long)count * ReadContext.ElementBytes(); + return (long)count * ElementBytes(); } private static long MapBudget(int count) { - return count == 0 - ? ReadContext.ContainerFixedBytes - : ReadContext.ContainerFixedBytes + ReadContext.ArrayHeaderBytes * 2 + - (long)count * (ReadContext.ElementBytes() + ReadContext.ElementBytes() + - ReadContext.MapEntryOverheadBytes + ReadContext.ReferenceBytes); + return (long)count * (ElementBytes() + ElementBytes()); } [Fact] - public void KnownLengthAutoBudgetRejectsLargeNestedEmpties() + public void KnownLengthAutoBudgetUsesInputBytes() { - const int count = 3000; - List> value = Enumerable.Range(0, count).Select(_ => new List()).ToList(); - byte[] bytes = Serialize(value); - long autoLimit = bytes.LongLength * 8 + ReadContext.KnownContainerBudgetSlackBytes; - long required = ListBudget>(count) + count * ListBudget(0); - Assert.True(required > autoLimit); - - Assert.Throws(() => NewFory().Deserialize>>(bytes)); + const int rootBytes = 17; + long expected = rootBytes * 8 + ReadContext.KnownContainerBudgetSlackBytes; + ReadContext context = new(new ByteReader([]), new TypeResolver(), NewFory().Config); - List> result = NewFory(required).Deserialize>>(bytes); - Assert.Equal(count, result.Count); + context.InitContainerBudgetKnown(rootBytes); + context.ReserveContainerMemory(expected); + Assert.Throws(() => context.ReserveContainerMemory(ReferenceBytes)); } [Fact] - public void ReadOnlySequenceUsesKnownLengthAutoBudget() + public void ReadOnlySequenceUsesKnownLengthRoot() { - const int count = 3000; + const int count = 6; List> value = Enumerable.Range(0, count).Select(_ => new List()).ToList(); byte[] bytes = Serialize(value); ReadOnlySequence sequence = new(bytes); - Assert.Throws(() => NewFory().Deserialize>>(ref sequence)); + Assert.Equal(count, NewFory().Deserialize>>(ref sequence).Count); } [Fact] diff --git a/dart/packages/fory/lib/src/context/read_context.dart b/dart/packages/fory/lib/src/context/read_context.dart index 1acf28c0d6..ba186c4f80 100644 --- a/dart/packages/fory/lib/src/context/read_context.dart +++ b/dart/packages/fory/lib/src/context/read_context.dart @@ -47,11 +47,6 @@ import 'package:fory/src/types/uint64.dart'; final class ReadContext { static const int _knownRootBudgetMultiplier = 8; static const int _knownRootBudgetSlackBytes = 64 * 1024; - static const int _collectionObjectBytes = 24; - static const int _mapObjectBytes = 48; - static const int _arrayHeaderBytes = 16; - static const int _mapEntryBytes = 32; - static const int _referenceBytes = 4; static const int _maxSafeBudgetBytes = 9007199254740991; /// Effective runtime configuration for the active operation. @@ -118,41 +113,6 @@ final class ReadContext { @internal @pragma('vm:prefer-inline') - void reserveCollectionMemory(int numElements) { - final bytes = _collectionObjectBytes + numElements * _referenceBytes; - final remaining = _remainingContainerMemoryBytes - bytes; - if (remaining < 0) { - _throwContainerMemoryExceeded(bytes); - } - _remainingContainerMemoryBytes = remaining; - } - - @internal - @pragma('vm:prefer-inline') - void reserveMapMemory(int numElements) { - final bytes = - _mapObjectBytes + - numElements * - (_referenceBytes * 2 + _mapEntryBytes + _referenceBytes * 3); - final remaining = _remainingContainerMemoryBytes - bytes; - if (remaining < 0) { - _throwContainerMemoryExceeded(bytes); - } - _remainingContainerMemoryBytes = remaining; - } - - @internal - @pragma('vm:prefer-inline') - void reserveTypedArrayMemory(int numElements, int elementBytes) { - final bytes = _arrayHeaderBytes + numElements * elementBytes; - final remaining = _remainingContainerMemoryBytes - bytes; - if (remaining < 0) { - _throwContainerMemoryExceeded(bytes); - } - _remainingContainerMemoryBytes = remaining; - } - - @internal void reserveContainerMemory(int bytes) { if (bytes < 0 || bytes > _maxSafeBudgetBytes) { _throwContainerMemoryOverflow(bytes); diff --git a/dart/packages/fory/lib/src/serializer/collection_serializers.dart b/dart/packages/fory/lib/src/serializer/collection_serializers.dart index b80839f2d6..5d8f7234ac 100644 --- a/dart/packages/fory/lib/src/serializer/collection_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/collection_serializers.dart @@ -38,6 +38,8 @@ import 'package:fory/src/types/float16.dart'; import 'package:fory/src/types/int64.dart'; import 'package:fory/src/types/uint64.dart'; +const int _referenceBytes = 4; + @pragma('vm:prefer-inline') void _writeDirectTypeInfoValue( WriteContext context, @@ -270,10 +272,10 @@ final class ListSerializer extends Serializer { } final declaredTypeInfo = elementFieldType == null || - elementFieldType.isDynamic || - elementFieldType.typeId == TypeIds.unknown - ? null - : context.typeResolver.resolveFieldType(elementFieldType); + elementFieldType.isDynamic || + elementFieldType.typeId == TypeIds.unknown + ? null + : context.typeResolver.resolveFieldType(elementFieldType); final usesDeclaredType = declaredTypeInfo != null && usesDeclaredTypeInfo( @@ -296,9 +298,8 @@ final class ListSerializer extends Serializer { sameType: analysis.sameType, ); context.buffer.writeUint8(header); - final sameTypeInfo = !usesDeclaredType && analysis.sameType - ? analysis.sameTypeInfo - : null; + final sameTypeInfo = + !usesDeclaredType && analysis.sameType ? analysis.sameTypeInfo : null; if (!usesDeclaredType && sameTypeInfo != null && analysis.firstNonNull != null) { @@ -384,7 +385,7 @@ final class SetSerializer extends Serializer { elementFieldType, hasPreservedRef: hasPreservedRef, ); - context.reserveCollectionMemory(values.length); + context.reserveContainerMemory(values.length * _referenceBytes); return Set.of(values); } } @@ -402,9 +403,8 @@ Object? readCompatibleMatchedCollectionArrayField( final remoteType = remoteField.fieldType; if (isCompatibleArrayType(localType.typeId) && remoteType.typeId == TypeIds.list) { - final elementType = remoteType.arguments.isEmpty - ? null - : remoteType.arguments.single; + final elementType = + remoteType.arguments.isEmpty ? null : remoteType.arguments.single; if (elementType == null || _arrayElementTypeId(localType.typeId) != _compatibleArrayElementTypeId(elementType.typeId)) { @@ -421,9 +421,8 @@ Object? readCompatibleMatchedCollectionArrayField( } if (localType.typeId == TypeIds.list && isCompatibleArrayType(remoteType.typeId)) { - final localElementType = localType.arguments.isEmpty - ? null - : localType.arguments.single; + final localElementType = + localType.arguments.isEmpty ? null : localType.arguments.single; if (localElementType == null || _arrayElementTypeId(remoteType.typeId) != _compatibleArrayElementTypeId(localElementType.typeId)) { @@ -493,9 +492,8 @@ bool _listElementMatchesArray( int arrayTypeId, { required bool requireUnframedElement, }) { - final elementType = listType.arguments.isEmpty - ? null - : listType.arguments.single; + final elementType = + listType.arguments.isEmpty ? null : listType.arguments.single; // Nullable element schema is allowed for list -> array; actual // null payload elements fail in the dense-array reader. Ref-tracked // element framing is rejected here because this path stays primitive-only. @@ -512,7 +510,7 @@ Object _readCompatibleListAsArrayField( String fieldName, ) { final size = context.buffer.readVarUint32(); - context.reserveTypedArrayMemory(size, _arrayElementBytes(arrayTypeId)); + context.reserveContainerMemory(size * _arrayElementBytes(arrayTypeId)); if (size == 0) { return _newArrayValue(arrayTypeId, 0); } @@ -584,9 +582,8 @@ int _arrayElementBytes(int arrayTypeId) { TypeIds.bfloat16Array => 2, TypeIds.int32Array || TypeIds.uint32Array || TypeIds.float32Array => 4, TypeIds.int64Array || TypeIds.uint64Array || TypeIds.float64Array => 8, - _ => throw StateError( - 'Unsupported compatible array field type $arrayTypeId.', - ), + _ => + throw StateError('Unsupported compatible array field type $arrayTypeId.'), }; } @@ -605,9 +602,8 @@ Object _newArrayValue(int arrayTypeId, int length) { TypeIds.bfloat16Array => Bfloat16List(length), TypeIds.float32Array => Float32List(length), TypeIds.float64Array => Float64List(length), - _ => throw StateError( - 'Unsupported compatible array field type $arrayTypeId.', - ), + _ => + throw StateError('Unsupported compatible array field type $arrayTypeId.'), }; } @@ -622,9 +618,8 @@ void _setArrayValue(Object target, int arrayTypeId, int index, Object? value) { case TypeIds.int32Array: (target as Int32List)[index] = value as int; case TypeIds.int64Array: - (target as Int64List)[index] = value is int - ? Int64(value) - : value as Int64; + (target as Int64List)[index] = + value is int ? Int64(value) : value as Int64; case TypeIds.uint8Array: (target as Uint8List)[index] = value as int; case TypeIds.uint16Array: @@ -632,9 +627,8 @@ void _setArrayValue(Object target, int arrayTypeId, int index, Object? value) { case TypeIds.uint32Array: (target as Uint32List)[index] = value as int; case TypeIds.uint64Array: - (target as Uint64List)[index] = value is int - ? Uint64(value) - : value as Uint64; + (target as Uint64List)[index] = + value is int ? Uint64(value) : value as Uint64; case TypeIds.float16Array: (target as Float16List)[index] = value as double; case TypeIds.bfloat16Array: @@ -650,11 +644,11 @@ void _setArrayValue(Object target, int arrayTypeId, int index, Object? value) { Object _arrayToListValue(ReadContext context, Object? raw) { if (raw is BoolList) { - context.reserveCollectionMemory(raw.length); + context.reserveContainerMemory(raw.length * _referenceBytes); return raw.toList(); } if (raw is Iterable) { - context.reserveCollectionMemory(raw.length); + context.reserveContainerMemory(raw.length * _referenceBytes); return raw.toList(); } throw StateError('Expected compatible array payload.'); @@ -675,29 +669,29 @@ List readTypedListPayload( } final directTypeInfo = state.declaredTypeInfo ?? state.sameTypeInfo; if (directTypeInfo != null && !state.trackRef && !state.hasNull) { - final directFieldType = state.declaredTypeInfo != null - ? state.elementFieldType - : null; + final directFieldType = + state.declaredTypeInfo != null ? state.elementFieldType : null; if (directTypeInfo.type == T && directTypeInfo.kind == RegistrationKind.struct) { final structSerializer = directTypeInfo.structSerializer!; context.buffer.checkReadableBytes(state.size); - final result = directTypeInfo.remoteTypeDef == null - ? List.generate( - state.size, - (_) => structSerializer.readValue(context, directTypeInfo) as T, - growable: false, - ) - : List.generate( - state.size, - (_) => - structSerializer.readGeneratedCompatibleValue( - context, - directTypeInfo, - ) - as T, - growable: false, - ); + final result = + directTypeInfo.remoteTypeDef == null + ? List.generate( + state.size, + (_) => structSerializer.readValue(context, directTypeInfo) as T, + growable: false, + ) + : List.generate( + state.size, + (_) => + structSerializer.readGeneratedCompatibleValue( + context, + directTypeInfo, + ) + as T, + growable: false, + ); if (state.tracksDepth) { context.decreaseDepth(); } @@ -745,7 +739,7 @@ Set readTypedSetPayload( T Function(Object? value) convert, ) { final values = readTypedListPayload(context, elementFieldType, convert); - context.reserveCollectionMemory(values.length); + context.reserveContainerMemory(values.length * _referenceBytes); return Set.of(values); } @@ -937,7 +931,7 @@ _PreparedListRead _prepareListRead( FieldType? elementFieldType, ) { final size = context.buffer.readVarUint32(); - context.reserveCollectionMemory(size); + context.reserveContainerMemory(size * _referenceBytes); if (size == 0) { return _PreparedListRead( size: 0, @@ -964,13 +958,15 @@ _PreparedListRead _prepareListRead( elementFieldType != null && (usesDeclaredType || (sameType && TypeIds.isUserType(elementFieldType.typeId))); - final expectedElementTypeInfo = needsExpectedElementType - ? context.typeResolver.tryResolveFieldType(elementFieldType) - : null; + final expectedElementTypeInfo = + needsExpectedElementType + ? context.typeResolver.tryResolveFieldType(elementFieldType) + : null; final declaredTypeInfo = usesDeclaredType ? expectedElementTypeInfo : null; - final sameTypeInfo = (!usesDeclaredType && sameType) - ? context.readTypeMetaValue(expectedElementTypeInfo) - : null; + final sameTypeInfo = + (!usesDeclaredType && sameType) + ? context.readTypeMetaValue(expectedElementTypeInfo) + : null; final tracksDepth = (declaredTypeInfo != null && tracksNestedPayloadDepth(declaredTypeInfo)) || diff --git a/dart/packages/fory/lib/src/serializer/map_serializers.dart b/dart/packages/fory/lib/src/serializer/map_serializers.dart index 0391699b23..a872837bab 100644 --- a/dart/packages/fory/lib/src/serializer/map_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/map_serializers.dart @@ -26,6 +26,8 @@ import 'package:fory/src/resolver/type_resolver.dart'; import 'package:fory/src/serializer/collection_serializers.dart'; import 'package:fory/src/serializer/serializer.dart'; +const int _referenceBytes = 4; + abstract final class MapFlags { static const int trackingKeyRef = 0x01; static const int keyHasNull = 0x02; @@ -56,13 +58,14 @@ final class MapSerializer extends Serializer { required bool trackRef, }) { context.buffer.writeVarUint32(values.length); - final declaredKeyTypeInfo = keyFieldType == null || keyFieldType.isDynamic - ? null - : context.typeResolver.resolveFieldType(keyFieldType); + final declaredKeyTypeInfo = + keyFieldType == null || keyFieldType.isDynamic + ? null + : context.typeResolver.resolveFieldType(keyFieldType); final declaredValueTypeInfo = valueFieldType == null || valueFieldType.isDynamic - ? null - : context.typeResolver.resolveFieldType(valueFieldType); + ? null + : context.typeResolver.resolveFieldType(valueFieldType); final keyDeclared = declaredKeyTypeInfo != null && usesDeclaredTypeInfo( @@ -105,17 +108,17 @@ final class MapSerializer extends Serializer { (keyDeclared ? declaredKeyTypeInfo.supportsRef : (key == null || - context.typeResolver - .resolveValue(key as Object) - .supportsRef)); + context.typeResolver + .resolveValue(key as Object) + .supportsRef)); final valueTrackRef = valueRequestedRef && (valueDeclared ? declaredValueTypeInfo.supportsRef : (value == null || - context.typeResolver - .resolveValue(value as Object) - .supportsRef)); + context.typeResolver + .resolveValue(value as Object) + .supportsRef)); _writeNullChunk( context, key, @@ -131,12 +134,14 @@ final class MapSerializer extends Serializer { ); continue; } - final chunkKeyTypeInfo = keyDeclared - ? declaredKeyTypeInfo - : context.typeResolver.resolveValue(key as Object); - final chunkValueTypeInfo = valueDeclared - ? declaredValueTypeInfo - : context.typeResolver.resolveValue(value as Object); + final chunkKeyTypeInfo = + keyDeclared + ? declaredKeyTypeInfo + : context.typeResolver.resolveValue(key as Object); + final chunkValueTypeInfo = + valueDeclared + ? declaredValueTypeInfo + : context.typeResolver.resolveValue(value as Object); final chunkKeyTrackRef = keyRequestedRef && chunkKeyTypeInfo.supportsRef; final chunkValueTrackRef = valueRequestedRef && chunkValueTypeInfo.supportsRef; @@ -186,12 +191,14 @@ final class MapSerializer extends Serializer { pendingEntry = nextEntry; break; } - final nextKeyTypeInfo = keyDeclared - ? declaredKeyTypeInfo - : context.typeResolver.resolveValue(nextKey as Object); - final nextValueTypeInfo = valueDeclared - ? declaredValueTypeInfo - : context.typeResolver.resolveValue(nextValue as Object); + final nextKeyTypeInfo = + keyDeclared + ? declaredKeyTypeInfo + : context.typeResolver.resolveValue(nextKey as Object); + final nextValueTypeInfo = + valueDeclared + ? declaredValueTypeInfo + : context.typeResolver.resolveValue(nextValue as Object); final nextKeyTrackRef = keyRequestedRef && nextKeyTypeInfo.supportsRef; final nextValueTrackRef = valueRequestedRef && nextValueTypeInfo.supportsRef; @@ -252,15 +259,16 @@ Map readTypedMapPayload( bool hasPreservedRef = false, }) { var remaining = context.buffer.readVarUint32(); - context.reserveMapMemory(remaining); + context.reserveContainerMemory(remaining * 2 * _referenceBytes); context.buffer.checkReadableBytes(remaining); - final declaredKeyTypeInfo = keyFieldType == null || keyFieldType.isDynamic - ? null - : context.typeResolver.resolveFieldType(keyFieldType); + final declaredKeyTypeInfo = + keyFieldType == null || keyFieldType.isDynamic + ? null + : context.typeResolver.resolveFieldType(keyFieldType); final declaredValueTypeInfo = valueFieldType == null || valueFieldType.isDynamic - ? null - : context.typeResolver.resolveFieldType(valueFieldType); + ? null + : context.typeResolver.resolveFieldType(valueFieldType); final result = {}; if (hasPreservedRef) { context.reference(result); @@ -308,32 +316,34 @@ Map readTypedMapPayload( context.increaseDepth(); } for (var index = 0; index < chunkSize; index += 1) { - final key = keyDeclared - ? _readDeclaredMapValue( - context, - keyFieldType!, - declaredKeyTypeInfo!, - trackRef: keyTrackRef, - ) - : _readResolvedMapValue( - context, - keyTypeInfo!, - null, - trackRef: keyTrackRef, - ); - final value = valueDeclared - ? _readDeclaredMapValue( - context, - valueFieldType!, - declaredValueTypeInfo!, - trackRef: valueTrackRef, - ) - : _readResolvedMapValue( - context, - valueTypeInfo!, - null, - trackRef: valueTrackRef, - ); + final key = + keyDeclared + ? _readDeclaredMapValue( + context, + keyFieldType!, + declaredKeyTypeInfo!, + trackRef: keyTrackRef, + ) + : _readResolvedMapValue( + context, + keyTypeInfo!, + null, + trackRef: keyTrackRef, + ); + final value = + valueDeclared + ? _readDeclaredMapValue( + context, + valueFieldType!, + declaredValueTypeInfo!, + trackRef: valueTrackRef, + ) + : _readResolvedMapValue( + context, + valueTypeInfo!, + null, + trackRef: valueTrackRef, + ); result[convertKey(key)] = convertValue(value); } if (tracksDepth) { diff --git a/dart/packages/fory/test/container_memory_budget_test.dart b/dart/packages/fory/test/container_memory_budget_test.dart index 61d6970300..e8dc46bc71 100644 --- a/dart/packages/fory/test/container_memory_budget_test.dart +++ b/dart/packages/fory/test/container_memory_budget_test.dart @@ -41,10 +41,7 @@ class BudgetGeneratedEnvelope { @SetField(element: StringType()) Set tags = {}; - @MapField( - key: StringType(), - value: Int32Type(encoding: Encoding.fixed), - ) + @MapField(key: StringType(), value: Int32Type(encoding: Encoding.fixed)) Map counts = {}; } @@ -150,25 +147,25 @@ void main() { expect(() => Fory(maxContainerMemoryBytes: -2), throwsArgumentError); }); - test('charges nested empty containers', () { + test('uses parent storage for nested empty containers', () { final value = [[]]; - expect(() => _readWithBudget(value, 51), _throwsContainerBudget); - expect(_readWithBudget(value, 52), equals(value)); + expect(() => _readWithBudget(value, 3), _throwsContainerBudget); + expect(_readWithBudget(value, 4), equals(value)); }); test('charges sibling containers cumulatively', () { final value = [[], [], []]; - expect(() => _readWithBudget(value, 107), _throwsContainerBudget); - expect(_readWithBudget(value, 108), equals(value)); + expect(() => _readWithBudget(value, 11), _throwsContainerBudget); + expect(_readWithBudget(value, 12), equals(value)); }); - test('charges map table and entries', () { + test('charges map entries', () { final value = {'a': 1}; - expect(() => _readWithBudget(value, 99), _throwsContainerBudget); - expect(_readWithBudget(value, 100), equals(value)); + expect(() => _readWithBudget(value, 7), _throwsContainerBudget); + expect(_readWithBudget(value, 8), equals(value)); }); test('charges generated list set and map reads', () { @@ -181,14 +178,14 @@ void main() { ..counts = {'one': 1}, ); - final failingReader = Fory(maxContainerMemoryBytes: 183); + final failingReader = Fory(maxContainerMemoryBytes: 19); _registerGenerated(failingReader); expect( () => failingReader.deserialize(bytes), _throwsContainerBudget, ); - final passingReader = Fory(maxContainerMemoryBytes: 184); + final passingReader = Fory(maxContainerMemoryBytes: 20); _registerGenerated(passingReader); final roundTrip = passingReader.deserialize( bytes, @@ -205,14 +202,14 @@ void main() { BudgetCompatibleListEnvelope()..values = [1, 2, 3], ); - final arrayFail = Fory(maxContainerMemoryBytes: 27); + final arrayFail = Fory(maxContainerMemoryBytes: 11); _registerCompatibleArray(arrayFail); expect( () => arrayFail.deserialize(listBytes), _throwsContainerBudget, ); - final arrayPass = Fory(maxContainerMemoryBytes: 28); + final arrayPass = Fory(maxContainerMemoryBytes: 12); _registerCompatibleArray(arrayPass); expect( arrayPass @@ -229,14 +226,14 @@ void main() { ..values = Int32List.fromList([1, 2, 3]), ); - final listFail = Fory(maxContainerMemoryBytes: 35); + final listFail = Fory(maxContainerMemoryBytes: 11); _registerCompatibleList(listFail); expect( () => listFail.deserialize(arrayBytes), _throwsContainerBudget, ); - final listPass = Fory(maxContainerMemoryBytes: 36); + final listPass = Fory(maxContainerMemoryBytes: 12); _registerCompatibleList(listPass); expect( listPass.deserialize(arrayBytes).values, @@ -260,9 +257,10 @@ void main() { }); test('keeps byte availability checks before allocation', () { - final listBuffer = Buffer() - ..writeVarUint32(64) - ..writeUint8(0); + final listBuffer = + Buffer() + ..writeVarUint32(64) + ..writeUint8(0); final listContext = _readContext(listBuffer); expect( () => ListSerializer.readPayload(listContext, null), diff --git a/docs/guide/cpp/configuration.md b/docs/guide/cpp/configuration.md index 199d7cab72..3041fe04df 100644 --- a/docs/guide/cpp/configuration.md +++ b/docs/guide/cpp/configuration.md @@ -112,13 +112,14 @@ automatic limit is the root input size multiplied by `8`, plus `64 KiB`. For stream roots, the automatic limit is `128 MiB` because the full root size is not known up front. Positive values always override the automatic limit. -This budget is a portable lower-bound estimate for container-owned memory such -as collection objects, backing storage, map value storage, and object/reference -arrays. It is not an exact process heap limit and does not include STL -implementation details such as debug nodes or allocator headers. Dedicated -string, binary, and primitive dense-array payloads continue to rely on their -byte-availability checks instead. `std::vector` is counted as packed -standard-container storage. +This budget is a portable lower-bound estimate for container-owned storage such +as dynamic collection backing storage, map key/value storage, and +object/reference array slots. It is not an exact process heap limit and does +not include STL implementation details such as debug nodes, table buckets, or +allocator headers. Empty containers with no dynamic backing normally do not +consume the budget. Dedicated string, binary, and primitive dense-array payloads +continue to rely on their byte-availability checks instead. `std::vector` +is counted as packed standard-container storage. **Default:** `-1` diff --git a/docs/guide/csharp/configuration.md b/docs/guide/csharp/configuration.md index c9e8e80cf6..b463e3226b 100644 --- a/docs/guide/csharp/configuration.md +++ b/docs/guide/csharp/configuration.md @@ -99,7 +99,8 @@ Fory fory = Fory.Builder() ### `MaxContainerMemoryBytes(long value)` -Sets the maximum estimated container-owned memory accepted during one root deserialization. +Sets the maximum estimated lower-bound container-owned storage accepted during one root +deserialization. ```csharp Fory fory = Fory.Builder() @@ -188,8 +189,8 @@ Security-related configuration: - Register only the expected types before deserializing untrusted payloads. - Use `CheckStructVersion(true)` with `Compatible(false)` for intentional same-schema payloads. - Set `MaxDepth(...)` to reject unexpectedly deep dynamic object graphs. -- Set `MaxContainerMemoryBytes(...)` to cap estimated list, array, set, and map memory during one - root deserialization. +- Set `MaxContainerMemoryBytes(...)` to cap estimated lower-bound list, array, set, and map storage + during one root deserialization. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer generated or registered concrete models over broad dynamic fields for untrusted input. diff --git a/docs/guide/dart/configuration.md b/docs/guide/dart/configuration.md index c7f851e253..d1466d83af 100644 --- a/docs/guide/dart/configuration.md +++ b/docs/guide/dart/configuration.md @@ -110,10 +110,11 @@ final fory = Fory( ### `maxContainerMemoryBytes` -Limits estimated container-owned memory for one root deserialization. The budget covers Dart lists, -sets, maps, object/reference arrays, and compatible list/array materialization. It does not count -strings, binary values, or dense typed-array payloads, which are protected by byte-availability -checks. +Limits estimated lower-bound container-owned storage for one root deserialization. The budget +covers Dart list/set/object-reference slots, map key/value slots, and compatible list/array +materialization. Empty containers without backing storage normally do not consume the budget. It +does not count strings, binary values, or dense typed-array payloads, which are protected by +byte-availability checks. The default is `-1`, which means auto. Dart root inputs are memory-backed, so auto derives from the root input size: diff --git a/docs/guide/go/configuration.md b/docs/guide/go/configuration.md index d1cb294c9c..636fb4a22c 100644 --- a/docs/guide/go/configuration.md +++ b/docs/guide/go/configuration.md @@ -144,8 +144,10 @@ inputBytes * 8 + 64 KiB ``` `DeserializeFromReader` and `DeserializeFromStream` use `128 MiB` because the -full root length is unknown. The budget covers Go slices, maps, sets, and -generated container reads. Strings, binary blobs, and primitive dense array +full root length is unknown. The budget covers lower-bound slice backing +storage, map key/value storage, sets, and generated container reads. Empty +containers without backing storage normally do not consume the budget. Strings, +binary blobs, and primitive dense array owners keep their byte-availability checks and are not charged to this budget. Set a positive value when a service needs a stricter or larger limit for trusted data. diff --git a/docs/guide/java/configuration.md b/docs/guide/java/configuration.md index 4e46d512a7..ec4d7b1eba 100644 --- a/docs/guide/java/configuration.md +++ b/docs/guide/java/configuration.md @@ -99,9 +99,10 @@ Security-related options: - `requireClassRegistration(true)` restricts deserialization to registered classes. - `withMaxDepth(...)` rejects unexpectedly deep object graphs. -- `withMaxContainerMemoryBytes(...)` bounds estimated container-owned memory during one root - deserialization. Keep `-1` for the automatic input-shaped default, or set a positive byte limit - when trusted payloads need a larger or smaller limit. +- `withMaxContainerMemoryBytes(...)` bounds estimated lower-bound container-owned storage during + one root deserialization. Empty containers without backing storage normally do not consume the + budget. Keep `-1` for the automatic input-shaped default, or set a positive byte limit when + trusted payloads need a larger or smaller limit. - `withMaxTypeFields(...)` and `withMaxTypeMetaBytes(...)` bound the field count and encoded body size of one received remote metadata body. - `withMaxSchemaVersionsPerType(...)` and diff --git a/docs/guide/javascript/configuration.md b/docs/guide/javascript/configuration.md index 71175c301e..4f732c040c 100644 --- a/docs/guide/javascript/configuration.md +++ b/docs/guide/javascript/configuration.md @@ -96,10 +96,13 @@ generated from Fory schema IDL. See [Schema Evolution](schema-evolution.md). ## Container Memory Budget -`maxContainerMemoryBytes` limits estimated memory committed by arrays, sets, -maps, and container backing storage during one root deserialization. The default -`-1` derives an automatic limit from the input bytes. JavaScript deserializes -from `Uint8Array` roots, so the automatic limit is `inputBytes * 8 + 64 KiB`. +`maxContainerMemoryBytes` limits estimated lower-bound container-owned storage +accepted during one root deserialization. The budget covers array, set, object +array, and map reference slots; it is not an exact JavaScript heap limit. Empty +containers without backing storage normally do not consume the budget. The +default `-1` derives an automatic limit from the input bytes. JavaScript +deserializes from `Uint8Array` roots, so the automatic limit is +`inputBytes * 8 + 64 KiB`. Use a positive byte value to set an explicit lower or higher limit: diff --git a/docs/guide/python/configuration.md b/docs/guide/python/configuration.md index 26cb42e50f..5922a204b9 100644 --- a/docs/guide/python/configuration.md +++ b/docs/guide/python/configuration.md @@ -225,8 +225,9 @@ Received remote metadata is also limited: - `max_type_meta_bytes` limits the encoded body bytes accepted for one received TypeDef body. - `max_schema_versions_per_type` limits accepted remote metadata versions for one logical type. - `max_average_schema_versions_per_type` limits the average across accepted remote types. -- `max_container_memory_bytes` limits estimated list, tuple, set, dict, and object-array storage - created during one root deserialization. The default `-1` uses `input_bytes * 8 + 64 KiB` for +- `max_container_memory_bytes` limits estimated lower-bound list, tuple, set, dict, and + object-array storage created during one root deserialization. Empty containers without backing + storage normally do not consume the budget. The default `-1` uses `input_bytes * 8 + 64 KiB` for known-length inputs and `128 MiB` for stream inputs. Set a positive byte value for trusted payloads that legitimately contain larger container graphs. diff --git a/docs/guide/rust/configuration.md b/docs/guide/rust/configuration.md index 6e04126bf0..f47295ce70 100644 --- a/docs/guide/rust/configuration.md +++ b/docs/guide/rust/configuration.md @@ -112,9 +112,10 @@ let fory = Fory::builder() ### Container Memory Budget -`max_container_memory_bytes(...)` limits the estimated memory that deserialization may allocate for -containers such as lists, sets, and maps during one root read. The default is `-1`, which selects an -automatic limit based on the input size: +`max_container_memory_bytes(...)` limits the estimated lower-bound container-owned storage accepted +during one root read. The budget covers `Vec`/collection element storage and map key/value storage; +it is not an exact process heap limit. Empty containers without backing storage normally do not +consume the budget. The default is `-1`, which selects an automatic limit based on the input size: ```rust let fory = Fory::builder().max_container_memory_bytes(-1).build(); diff --git a/docs/guide/swift/configuration.md b/docs/guide/swift/configuration.md index 40b80744fb..1301050edf 100644 --- a/docs/guide/swift/configuration.md +++ b/docs/guide/swift/configuration.md @@ -93,10 +93,11 @@ let fory = Fory(compatible: false, checkClassVersion: true) `maxDepth` bounds decoded payload nesting depth. -`maxContainerMemoryBytes` bounds the estimated container-owned memory accepted during one root -deserialization. Use `-1` for the default automatic limit. Swift roots are currently `Data` or -`ByteBuffer`, so auto uses the root input byte length times `8`, plus `64 KiB`. A positive value -overrides the automatic limit. `0` and negative values other than `-1` are rejected. +`maxContainerMemoryBytes` bounds the estimated lower-bound container-owned storage accepted during +one root deserialization. Swift roots are currently `Data` or `ByteBuffer`, so auto uses the root +input byte length times `8`, plus `64 KiB`. Empty containers without backing storage normally do +not consume the budget. Use `-1` for the default automatic limit; a positive value overrides it. +`0` and negative values other than `-1` are rejected. Compatible-mode remote metadata is also limited: @@ -148,7 +149,7 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkClassVersion` with `compatible: false` for intentional same-schema payloads. - Set `maxDepth` for the largest nesting depth your service accepts. -- Set `maxContainerMemoryBytes` to cap estimated list, set, array, and map memory during one root - deserialization. +- Set `maxContainerMemoryBytes` to cap estimated lower-bound list, set, array, and map storage + during one root deserialization. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. diff --git a/docs/security/deserialization.md b/docs/security/deserialization.md index 39003627f6..4bbf4a8cc2 100644 --- a/docs/security/deserialization.md +++ b/docs/security/deserialization.md @@ -223,11 +223,18 @@ Container budget accounting should: - happen in root-operation read state, with cleanup owned by the root deserialization `finally`; +- keep read context/read state limited to raw byte reservation and generic + counted-byte arithmetic; collection/map/array storage formulas belong in the + concrete serializer or generated serializer owner; - reject arithmetic overflow before comparing budget or allocating; -- charge fixed container object cost, backing capacity, map table and entry - overhead where the runtime has cheap reliable signals, reference arrays, and - inline or value storage where a runtime stores elements inline; -- charge fixed cost even for zero-size containers; +- estimate lower-bound owner storage: reference-backed containers and + object/reference arrays charge reference slots, inline/value containers charge + element storage, reference-backed maps charge two references per entry, and + inline/value maps charge key plus value storage; +- treat fixed/header cost as zero by default, charging it only when the owner + path creates an independently allocated container/control entity that is not + already covered by parent inline/value storage and the charged size is a + documented conservative lower bound; - preserve existing byte-availability checks before backing allocation or capacity reservation; - skip dedicated string, binary, primitive array, and primitive dense-array @@ -241,9 +248,12 @@ the inline element storage instead of treating those elements as references. General inline-value containers must not be skipped just because dedicated primitive dense arrays are skipped. -Native runtimes may use conservative lower-bound estimates when exact container -layout is not portable. For example, C++ STL node, allocator, and debug-mode -overheads should not be guessed when only value storage is reliably known. +Runtimes should not guess object headers, array headers, allocator headers, +debug-mode fields, hash buckets, tree links, hash-chain links, node headers, +map-entry objects, spare blocks, or runtime table layouts unless the owner path +has a cheap, stable, explicit lower-bound storage signal and documents the +formula. C++ STL node, allocator, and debug-mode overheads should not be guessed +when only value storage is reliably known. ## Skip Semantics diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index 2fd57ffa5e..f960461708 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -406,18 +406,28 @@ configuration uses `inputBytes * 8 + 64 KiB` for known-length root input and fixed `128 MiB` for true stream or unknown-length root input. Do not add dynamic stream bytes-read accounting for this budget. -The budget estimates container-owned memory, not exact heap bytes. Charge fixed -container object cost, backing capacity, map table and entry overhead where the -runtime has cheap reliable signals, reference arrays, and inline/value element -storage where the runtime stores container elements inline. Charge zero-size -containers for their fixed cost. Skip dedicated string, binary, primitive -array, and primitive dense-array owners, but do not skip general inline-value -containers such as vectors or lists of value objects. If reference slot size is -not cheap or reliable to query, use a 4-byte reference slot. Native runtimes may -use conservative lower-bound estimates instead of guessing non-portable -container, allocator, or debug-layout details. Reject arithmetic overflow before -budget comparison or allocation, and keep the existing `checkReadableBytes` -proof before backing allocation or capacity reservation. +Read context or equivalent read state owns only raw byte accounting and generic +counted-byte arithmetic, such as reserving `bytes` or `count * elementBytes` +with overflow checks. It must not expose collection/map/array semantic +reservation APIs. Concrete serializers and generated serializer owners compute +the storage constants and formulas for the container path they allocate. + +The budget estimates lower-bound container-owned storage, not exact heap bytes. +Reference-backed containers and object/reference arrays charge reference slots; +inline/value containers charge element storage; reference-backed maps charge two +references per entry; and inline/value maps charge key plus value storage. +Fixed/header cost defaults to zero and is charged only when the owner path +creates an independently allocated container/control entity, that entity is not +already covered by parent inline/value storage, and the charged size is a +documented conservative lower bound. Empty containers with no dynamic backing +normally charge zero. Skip dedicated string, binary, primitive array, and +primitive dense-array owners, but do not skip general inline-value containers +such as vectors or lists of value objects. If reference slot size is not cheap +or reliable to query, use a 4-byte reference slot. Native runtimes may use +conservative lower-bound estimates instead of guessing non-portable container, +allocator, table, node, entry, or debug-layout details. Reject arithmetic +overflow before budget comparison or allocation, and keep the existing +`checkReadableBytes` proof before backing allocation or capacity reservation. For TypeDef or TypeMeta bodies, first prove that the encoded metadata body bytes are readable through the byte owner. Field-list allocation should happen after diff --git a/go/fory/array.go b/go/fory/array.go index 93b81a85c2..c0d58034aa 100644 --- a/go/fory/array.go +++ b/go/fory/array.go @@ -318,7 +318,7 @@ func (s arrayDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType func (s arrayDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { // Create a temp slice to read into, then copy back to array sliceType := reflect.SliceOf(value.Type().Elem()) - if !ctx.reserveSliceTypeMemory(value.Len(), value.Type().Elem()) { + if !ctx.ReserveCountedContainerMemory(value.Len(), int64(value.Type().Elem().Size())) { return } tempSlice := reflect.MakeSlice(sliceType, value.Len(), value.Len()) diff --git a/go/fory/codegen/decoder.go b/go/fory/codegen/decoder.go index 3a25284909..f2fe674df8 100644 --- a/go/fory/codegen/decoder.go +++ b/go/fory/codegen/decoder.go @@ -172,7 +172,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveCountedContainerMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") @@ -203,7 +203,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveCountedContainerMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") @@ -510,7 +510,7 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveCountedContainerMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") @@ -531,7 +531,7 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveCountedContainerMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") @@ -560,7 +560,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) - fmt.Fprintf(buf, "%sif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "%sif !ctx.ReserveCountedContainerMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) @@ -586,7 +586,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) - fmt.Fprintf(buf, "%sif !ctx.ReserveSliceMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "%sif !ctx.ReserveCountedContainerMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) @@ -855,7 +855,7 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveMapMemory(mapLen, %s, %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveCountedContainerMemory(mapLen, %s + %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif mapLen == 0 {\n") @@ -876,7 +876,7 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveMapMemory(mapLen, %s, %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) + fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveCountedContainerMemory(mapLen, %s + %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif mapLen == 0 {\n") @@ -914,7 +914,7 @@ func generateMapReadInlineNoNull(buf *bytes.Buffer, mapType *types.Map, fieldAcc fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) - fmt.Fprintf(buf, "%sif !ctx.ReserveMapMemory(mapLen, %s, %s) {\n", indent, unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) + fmt.Fprintf(buf, "%sif !ctx.ReserveCountedContainerMemory(mapLen, %s + %s) {\n", indent, unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif mapLen == 0 {\n", indent) diff --git a/go/fory/container_memory_budget_test.go b/go/fory/container_memory_budget_test.go index e83e8d7a03..00281f1e13 100644 --- a/go/fory/container_memory_budget_test.go +++ b/go/fory/container_memory_budget_test.go @@ -122,37 +122,39 @@ func TestContainerMemoryBudgetEmptyAndCumulative(t *testing.T) { data, err := New(WithCompatible(false)).Serialize([]any{}) require.NoError(t, err) var empty []any - err = New(WithCompatible(false), WithMaxContainerMemoryBytes(sliceObjectBytes-1)).Deserialize(data, &empty) - require.Error(t, err) - require.Contains(t, err.Error(), "maxContainerMemoryBytes") + err = New(WithCompatible(false), WithMaxContainerMemoryBytes(1)).Deserialize(data, &empty) + require.NoError(t, err) + require.Empty(t, empty) writer := New(WithCompatible(false)) require.NoError(t, writer.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) - data, err = writer.Serialize(&budgetSiblings{A: []string{}, B: []string{}}) + data, err = writer.Serialize(&budgetSiblings{A: []string{"a"}, B: []string{"b"}}) require.NoError(t, err) - reader := New(WithCompatible(false), WithMaxContainerMemoryBytes(sliceObjectBytes)) + reader := New(WithCompatible(false), WithMaxContainerMemoryBytes(stringElementBytes)) require.NoError(t, reader.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) var out budgetSiblings err = reader.Deserialize(data, &out) require.Error(t, err) require.Contains(t, err.Error(), "maxContainerMemoryBytes") + reader = New(WithCompatible(false), WithMaxContainerMemoryBytes(2*stringElementBytes)) + require.NoError(t, reader.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) + require.NoError(t, reader.Deserialize(data, &out)) + require.Equal(t, []string{"a"}, out.A) + require.Equal(t, []string{"b"}, out.B) } func TestContainerMemoryBudgetMapAndOverflow(t *testing.T) { data, err := New().Serialize(map[string]string{"k": "v"}) require.NoError(t, err) var out map[string]string - oneEntryBudget := mapObjectBytes + - 2*referenceSlotBytes + - mapEntryOverheadBytes + referenceSlotBytes + - containerSizeOf[string]() + containerSizeOf[string]() + oneEntryBudget := containerSizeOf[string]() + containerSizeOf[string]() err = New(WithMaxContainerMemoryBytes(oneEntryBudget-1)).Deserialize(data, &out) require.Error(t, err) require.Contains(t, err.Error(), "maxContainerMemoryBytes") ctx := NewReadContext(false) ctx.initContainerMemoryBudget(0, true) - require.False(t, ctx.ReserveMapMemory(MaxInt, MaxInt64, 1)) + require.False(t, ctx.ReserveCountedContainerMemory(MaxInt, MaxInt64)) require.Contains(t, ctx.CheckError().Error(), "overflows") } @@ -160,7 +162,7 @@ func TestContainerMemoryBudgetSlicesAndInlineValues(t *testing.T) { data, err := New().Serialize([]string{"a"}) require.NoError(t, err) var stringsOut []string - err = New(WithMaxContainerMemoryBytes(sliceObjectBytes+containerSizeOf[string]()-1)).Deserialize(data, &stringsOut) + err = New(WithMaxContainerMemoryBytes(containerSizeOf[string]()-1)).Deserialize(data, &stringsOut) require.Error(t, err) require.Contains(t, err.Error(), "maxContainerMemoryBytes") @@ -168,7 +170,7 @@ func TestContainerMemoryBudgetSlicesAndInlineValues(t *testing.T) { require.NoError(t, writer.RegisterStructByName(budgetItem{}, "test.BudgetItem")) data, err = writer.Serialize([]budgetItem{{A: 1}}) require.NoError(t, err) - reader := New(WithMaxContainerMemoryBytes(sliceObjectBytes + containerSizeOf[budgetItem]() - 1)) + reader := New(WithMaxContainerMemoryBytes(containerSizeOf[budgetItem]() - 1)) require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) var items []budgetItem err = reader.Deserialize(data, &items) diff --git a/go/fory/field_serializer.go b/go/fory/field_serializer.go index d3ef3a787b..b4cc586c75 100644 --- a/go/fory/field_serializer.go +++ b/go/fory/field_serializer.go @@ -74,7 +74,7 @@ func newDeclaredSliceSerializer(type_ reflect.Type, elemSerializer Serializer, r elemSerializer: elemSerializer, referencable: referencable, elemBytes: elemBytes, - maxLength: maxSliceLength(elemBytes), + maxLength: maxContainerCount(elemBytes), }, nil } diff --git a/go/fory/map.go b/go/fory/map.go index fdb8ebbc53..d3a661edeb 100644 --- a/go/fory/map.go +++ b/go/fory/map.go @@ -303,7 +303,14 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { iface := reflect.TypeOf((*any)(nil)).Elem() mapType = reflect.MapOf(iface, iface) } - if !ctx.reserveMapTypeMemory(size, mapType.Key(), mapType.Elem()) { + keyBytes := int64(mapType.Key().Size()) + valueBytes := int64(mapType.Elem().Size()) + elemBytes := keyBytes + valueBytes + if elemBytes < keyBytes { + ctx.setContainerMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + return + } + if !ctx.ReserveCountedContainerMemory(size, elemBytes) { return } if size == 0 { diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go index 5a4e925ade..e52b88b5b1 100644 --- a/go/fory/map_primitive.go +++ b/go/fory/map_primitive.go @@ -26,24 +26,24 @@ import ( // ============================================================================ var ( - stringStringMapElemBytes = mapElementMemory(stringElementBytes, stringElementBytes) - stringStringMapMaxLength = maxMapLength(stringStringMapElemBytes) - stringInt64MapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[int64]()) - stringInt64MapMaxLength = maxMapLength(stringInt64MapElemBytes) - stringInt32MapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[int32]()) - stringInt32MapMaxLength = maxMapLength(stringInt32MapElemBytes) - stringIntMapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[int]()) - stringIntMapMaxLength = maxMapLength(stringIntMapElemBytes) - stringFloat64MapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[float64]()) - stringFloat64MapMaxLength = maxMapLength(stringFloat64MapElemBytes) - stringBoolMapElemBytes = mapElementMemory(stringElementBytes, containerSizeOf[bool]()) - stringBoolMapMaxLength = maxMapLength(stringBoolMapElemBytes) - int32Int32MapElemBytes = mapElementMemory(containerSizeOf[int32](), containerSizeOf[int32]()) - int32Int32MapMaxLength = maxMapLength(int32Int32MapElemBytes) - int64Int64MapElemBytes = mapElementMemory(containerSizeOf[int64](), containerSizeOf[int64]()) - int64Int64MapMaxLength = maxMapLength(int64Int64MapElemBytes) - intIntMapElemBytes = mapElementMemory(containerSizeOf[int](), containerSizeOf[int]()) - intIntMapMaxLength = maxMapLength(intIntMapElemBytes) + stringStringMapElemBytes = stringElementBytes + stringElementBytes + stringStringMapMaxLength = maxContainerCount(stringStringMapElemBytes) + stringInt64MapElemBytes = stringElementBytes + containerSizeOf[int64]() + stringInt64MapMaxLength = maxContainerCount(stringInt64MapElemBytes) + stringInt32MapElemBytes = stringElementBytes + containerSizeOf[int32]() + stringInt32MapMaxLength = maxContainerCount(stringInt32MapElemBytes) + stringIntMapElemBytes = stringElementBytes + containerSizeOf[int]() + stringIntMapMaxLength = maxContainerCount(stringIntMapElemBytes) + stringFloat64MapElemBytes = stringElementBytes + containerSizeOf[float64]() + stringFloat64MapMaxLength = maxContainerCount(stringFloat64MapElemBytes) + stringBoolMapElemBytes = stringElementBytes + containerSizeOf[bool]() + stringBoolMapMaxLength = maxContainerCount(stringBoolMapElemBytes) + int32Int32MapElemBytes = containerSizeOf[int32]() + containerSizeOf[int32]() + int32Int32MapMaxLength = maxContainerCount(int32Int32MapElemBytes) + int64Int64MapElemBytes = containerSizeOf[int64]() + containerSizeOf[int64]() + int64Int64MapMaxLength = maxContainerCount(int64Int64MapElemBytes) + intIntMapElemBytes = containerSizeOf[int]() + containerSizeOf[int]() + intIntMapMaxLength = maxContainerCount(intIntMapElemBytes) ) // writeMapStringString writes map[string]string using chunk protocol @@ -94,7 +94,7 @@ func readTypedMapSize(ctx *ReadContext, elemBytes int64, maxLength int64) (int, if ctx.HasError() { return 0, false } - if !ctx.reserveMapMemory(size, elemBytes, maxLength) { + if !ctx.reserveCountedContainerMemory(size, elemBytes, maxLength) { return 0, false } if size == 0 { diff --git a/go/fory/reader.go b/go/fory/reader.go index b3d6301d65..2222d98690 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -53,36 +53,22 @@ const ( knownRootBudgetMultiplier = int64(8) knownRootBudgetSlackBytes = int64(64 * 1024) streamRootBudgetBytes = int64(128 * 1024 * 1024) - sliceObjectBytes = int64(unsafe.Sizeof([]byte(nil))) - mapObjectBytes = int64(48) - mapEntryOverheadBytes = int64(16) ) var referenceSlotBytes = int64(unsafe.Sizeof(uintptr(0))) var stringElementBytes = containerSizeOf[string]() -var stringSliceMaxLength = maxSliceLength(stringElementBytes) +var stringMaxLength = maxContainerCount(stringElementBytes) func containerSizeOf[T any]() int64 { var v T return int64(unsafe.Sizeof(v)) } -func maxSliceLength(elemBytes int64) int64 { +func maxContainerCount(elemBytes int64) int64 { if elemBytes == 0 { return MaxInt64 } - return (MaxInt64 - sliceObjectBytes) / elemBytes -} - -func mapElementMemory(keyBytes int64, valueBytes int64) int64 { - return keyBytes + valueBytes + mapEntryOverheadBytes + referenceSlotBytes + 2*referenceSlotBytes -} - -func maxMapLength(elemBytes int64) int64 { - if elemBytes == 0 { - return MaxInt64 - } - return (MaxInt64 - mapObjectBytes) / elemBytes + return MaxInt64 / elemBytes } // IsXlang returns whether cross-language serialization mode is enabled @@ -138,84 +124,31 @@ func (c *ReadContext) initContainerMemoryBudget(rootInputBytes int, unknownLengt c.remainingContainerMemoryBytes = limit } -// ReserveSliceMemory reserves estimated memory for a Go slice backing array before allocation. -func (c *ReadContext) ReserveSliceMemory(length int, elemBytes int64) bool { - if elemBytes < 0 { - c.setContainerMemoryError("negative container element size: %d", elemBytes) - return false - } - return c.reserveSliceMemory(length, elemBytes, maxSliceLength(elemBytes)) -} - -func (c *ReadContext) reserveSliceMemory(length int, elemBytes int64, maxLength int64) bool { +// ReserveCountedContainerMemory reserves length * elementBytes estimated container bytes. +func (c *ReadContext) ReserveCountedContainerMemory(length int, elemBytes int64) bool { if length < 0 { c.setContainerMemoryError("negative container length: %d", length) return false } - if int64(length) > maxLength { - c.setContainerMemoryError("container memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) - return false - } - bytes := sliceObjectBytes + int64(length)*elemBytes - remaining := c.remainingContainerMemoryBytes - if bytes > remaining { - c.setContainerMemoryExceeded(bytes, remaining) - return false - } - c.remainingContainerMemoryBytes = remaining - bytes - return true -} - -func (c *ReadContext) reserveSliceTypeMemory(length int, elemType reflect.Type) bool { - elemBytes := referenceSlotBytes - if elemType != nil { - elemBytes = int64(elemType.Size()) - } - return c.ReserveSliceMemory(length, elemBytes) -} - -// ReserveMapMemory reserves estimated memory for a Go map before allocation or size hinting. -func (c *ReadContext) ReserveMapMemory(length int, keyBytes int64, valueBytes int64) bool { - if keyBytes < 0 || valueBytes < 0 { - c.setContainerMemoryError("negative map element size: key=%d value=%d", keyBytes, valueBytes) - return false - } - perEntry := keyBytes + valueBytes - if perEntry < keyBytes || perEntry > MaxInt64-mapEntryOverheadBytes-referenceSlotBytes { - c.setContainerMemoryError("map element size overflows: key=%d value=%d", keyBytes, valueBytes) - return false - } - perEntry += mapEntryOverheadBytes + referenceSlotBytes - if perEntry > MaxInt64-2*referenceSlotBytes { - c.setContainerMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + if elemBytes < 0 { + c.setContainerMemoryError("negative container element size: %d", elemBytes) return false } - elemBytes := perEntry + 2*referenceSlotBytes - return c.reserveMapMemory(length, elemBytes, maxMapLength(elemBytes)) -} - -func (c *ReadContext) reserveMapTypeMemory(length int, keyType reflect.Type, valueType reflect.Type) bool { - keyBytes := referenceSlotBytes - valueBytes := referenceSlotBytes - if keyType != nil { - keyBytes = int64(keyType.Size()) - } - if valueType != nil { - valueBytes = int64(valueType.Size()) + if length == 0 { + return true } - return c.ReserveMapMemory(length, keyBytes, valueBytes) + return c.reserveCountedContainerMemory(length, elemBytes, maxContainerCount(elemBytes)) } -func (c *ReadContext) reserveMapMemory(length int, elemBytes int64, maxLength int64) bool { - if length < 0 { - c.setContainerMemoryError("negative container length: %d", length) - return false +func (c *ReadContext) reserveCountedContainerMemory(length int, elemBytes int64, maxLength int64) bool { + if length == 0 { + return true } if int64(length) > maxLength { c.setContainerMemoryError("container memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) return false } - bytes := mapObjectBytes + int64(length)*elemBytes + bytes := int64(length) * elemBytes remaining := c.remainingContainerMemoryBytes if bytes > remaining { c.setContainerMemoryExceeded(bytes, remaining) @@ -225,22 +158,6 @@ func (c *ReadContext) reserveMapMemory(length int, elemBytes int64, maxLength in return true } -func (c *ReadContext) reserveCountedMemory(length int, fixedBytes int64, elemBytes int64) bool { - if length < 0 { - c.setContainerMemoryError("negative container length: %d", length) - return false - } - if fixedBytes < 0 || elemBytes < 0 { - c.setContainerMemoryError("negative container memory estimate: fixed=%d elem=%d", fixedBytes, elemBytes) - return false - } - if elemBytes != 0 && int64(length) > (MaxInt64-fixedBytes)/elemBytes { - c.setContainerMemoryError("container memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) - return false - } - return c.ReserveContainerMemory(fixedBytes + int64(length)*elemBytes) -} - // ReserveContainerMemory reserves raw estimated container-owned bytes. func (c *ReadContext) ReserveContainerMemory(bytes int64) bool { if bytes < 0 { @@ -739,7 +656,7 @@ func (c *ReadContext) readStringSliceData() []string { if c.HasError() { return nil } - if !c.reserveSliceMemory(length, containerSizeOf[string](), stringSliceMaxLength) { + if !c.reserveCountedContainerMemory(length, stringElementBytes, stringMaxLength) { return nil } if length == 0 { diff --git a/go/fory/set.go b/go/fory/set.go index 652f8ddca9..b3899c3ef8 100644 --- a/go/fory/set.go +++ b/go/fory/set.go @@ -318,7 +318,14 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { return } if length == 0 { - if !ctx.reserveMapTypeMemory(length, type_.Key(), type_.Elem()) { + keyBytes := int64(type_.Key().Size()) + valueBytes := int64(type_.Elem().Size()) + elemBytes := keyBytes + valueBytes + if elemBytes < keyBytes { + ctx.setContainerMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + return + } + if !ctx.ReserveCountedContainerMemory(length, elemBytes) { return } // Initialize empty set if length is 0 @@ -359,7 +366,14 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { if !buf.CheckReadable(length, err) { return } - if !ctx.reserveMapTypeMemory(length, type_.Key(), type_.Elem()) { + keyBytes := int64(type_.Key().Size()) + valueBytes := int64(type_.Elem().Size()) + elemBytes := keyBytes + valueBytes + if elemBytes < keyBytes { + ctx.setContainerMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + return + } + if !ctx.ReserveCountedContainerMemory(length, elemBytes) { return } diff --git a/go/fory/slice.go b/go/fory/slice.go index 56d4d4845f..ad1bc38fc6 100644 --- a/go/fory/slice.go +++ b/go/fory/slice.go @@ -152,7 +152,7 @@ func newSliceSerializer(type_ reflect.Type, elemSerializer Serializer, xlang boo elemSerializer: elemSerializer, referencable: isRefType(elem, xlang), elemBytes: elemBytes, - maxLength: maxSliceLength(elemBytes), + maxLength: maxContainerCount(elemBytes), }, nil } @@ -313,15 +313,15 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } isArrayType := value.Type().Kind() == reflect.Array - if !isArrayType && !ctx.reserveSliceMemory(length, s.elemBytes, s.maxLength) { - return - } if length == 0 { if !isArrayType { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) } return } + if !isArrayType && !ctx.reserveCountedContainerMemory(length, s.elemBytes, s.maxLength) { + return + } // ReadData collection flags collectFlag := buf.ReadInt8(ctxErr) diff --git a/go/fory/slice_dyn.go b/go/fory/slice_dyn.go index d341e9455b..11e3eb05c4 100644 --- a/go/fory/slice_dyn.go +++ b/go/fory/slice_dyn.go @@ -45,7 +45,7 @@ func newSliceDynSerializer(elemType reflect.Type) (*sliceDynSerializer, error) { return &sliceDynSerializer{ isInterfaceElem: true, elemBytes: elemBytes, - maxLength: maxSliceLength(elemBytes), + maxLength: maxContainerCount(elemBytes), }, nil } // Validate element type is interface or pointer to interface @@ -61,7 +61,7 @@ func newSliceDynSerializer(elemType reflect.Type) (*sliceDynSerializer, error) { isInterfaceElem: isInterface, isPointerElem: isPointerToInterface, elemBytes: elemBytes, - maxLength: maxSliceLength(elemBytes), + maxLength: maxContainerCount(elemBytes), }, nil } @@ -283,13 +283,13 @@ func (s *sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, exp return } allocatedByCaller := expectedLength >= 0 - if !allocatedByCaller && !ctx.reserveSliceMemory(length, s.elemBytes, s.maxLength) { - return - } if length == 0 { value.Set(reflect.MakeSlice(sliceType, 0, 0)) return } + if !allocatedByCaller && !ctx.reserveCountedContainerMemory(length, s.elemBytes, s.maxLength) { + return + } collectFlag := buf.ReadInt8(ctxErr) if ctx.HasError() { diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index 88e5d50b08..b809467c01 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -652,7 +652,7 @@ func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { return } ptr := (*[]string)(value.Addr().UnsafePointer()) - if !ctx.reserveSliceMemory(length, stringElementBytes, stringSliceMaxLength) { + if !ctx.reserveCountedContainerMemory(length, stringElementBytes, stringMaxLength) { return } if length == 0 { diff --git a/go/fory/slice_primitive_list.go b/go/fory/slice_primitive_list.go index e033fb4409..f76f0b66b4 100644 --- a/go/fory/slice_primitive_list.go +++ b/go/fory/slice_primitive_list.go @@ -35,7 +35,7 @@ func newPrimitiveList(type_ reflect.Type, elemTypeID TypeId, elemType reflect.Ty type_: type_, elemTypeID: elemTypeID, elemBytes: elemBytes, - maxLength: maxSliceLength(elemBytes), + maxLength: maxContainerCount(elemBytes), } } @@ -179,7 +179,7 @@ func (s primitiveListSerializer) ReadData(ctx *ReadContext, value reflect.Value) if ctx.HasError() { return } - if !ctx.reserveSliceMemory(length, s.elemBytes, s.maxLength) { + if !ctx.reserveCountedContainerMemory(length, s.elemBytes, s.maxLength) { return } if length == 0 { @@ -243,7 +243,7 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val } if length == 0 { if value.Kind() == reflect.Slice { - if !ctx.reserveSliceMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { + if !ctx.reserveCountedContainerMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { return } value.Set(reflect.MakeSlice(value.Type(), 0, 0)) @@ -284,7 +284,7 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val return } if value.Kind() == reflect.Slice { - if !ctx.reserveSliceMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { + if !ctx.reserveCountedContainerMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { return } temp := reflect.New(value.Type()).Elem() diff --git a/go/fory/tests/structs_fory_gen.go b/go/fory/tests/structs_fory_gen.go index ca639979a5..a3dae17b33 100644 --- a/go/fory/tests/structs_fory_gen.go +++ b/go/fory/tests/structs_fory_gen.go @@ -190,7 +190,7 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { + if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -221,7 +221,7 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { + if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -672,7 +672,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(int))), int64(unsafe.Sizeof(*new(int)))) { + if !ctx.ReserveCountedContainerMemory(mapLen, int64(unsafe.Sizeof(*new(int)))+int64(unsafe.Sizeof(*new(int)))) { return ctx.TakeError() } if mapLen == 0 { @@ -722,7 +722,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(int))), int64(unsafe.Sizeof(*new(int)))) { + if !ctx.ReserveCountedContainerMemory(mapLen, int64(unsafe.Sizeof(*new(int)))+int64(unsafe.Sizeof(*new(int)))) { return ctx.TakeError() } if mapLen == 0 { @@ -774,7 +774,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(string))), int64(unsafe.Sizeof(*new(int)))) { + if !ctx.ReserveCountedContainerMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(int)))) { return ctx.TakeError() } if mapLen == 0 { @@ -824,7 +824,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(string))), int64(unsafe.Sizeof(*new(int)))) { + if !ctx.ReserveCountedContainerMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(int)))) { return ctx.TakeError() } if mapLen == 0 { @@ -876,7 +876,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(string))), int64(unsafe.Sizeof(*new(string)))) { + if !ctx.ReserveCountedContainerMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(string)))) { return ctx.TakeError() } if mapLen == 0 { @@ -926,7 +926,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveMapMemory(mapLen, int64(unsafe.Sizeof(*new(string))), int64(unsafe.Sizeof(*new(string)))) { + if !ctx.ReserveCountedContainerMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(string)))) { return ctx.TakeError() } if mapLen == 0 { @@ -1287,7 +1287,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { + if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1329,7 +1329,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { + if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1373,7 +1373,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { + if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1415,7 +1415,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { + if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1459,7 +1459,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { + if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1501,7 +1501,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { + if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1545,7 +1545,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { + if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1595,7 +1595,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveSliceMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { + if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { return ctx.TakeError() } if sliceLen == 0 { diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index 162ccdd9f5..5c61cab2ff 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -54,12 +54,6 @@ public final class ReadContext { private static final long KNOWN_ROOT_BUDGET_MULTIPLIER = 8L; private static final long KNOWN_ROOT_BUDGET_SLACK_BYTES = 64L * 1024; private static final long STREAM_ROOT_BUDGET_BYTES = 128L * 1024 * 1024; - private static final long COLLECTION_OBJECT_BYTES = 24L; - private static final long MAP_OBJECT_BYTES = 48L; - private static final long ARRAY_HEADER_BYTES = 16L; - private static final long MAP_ENTRY_BYTES = 32L; - private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); - private final Config config; private final Generics generics; private final TypeResolver typeResolver; @@ -351,25 +345,6 @@ public Config getConfig() { return config; } - public void reserveCollectionMemory(int numElements) { - reserveContainerMemory(COLLECTION_OBJECT_BYTES + (long) numElements * REFERENCE_BYTES); - } - - public void reserveCollectionCapacity(int numElements, int capacity) { - reserveContainerMemory((long) (capacity - numElements) * REFERENCE_BYTES); - } - - public void reserveMapMemory(int numElements) { - long entries = (long) numElements; - long tableBytes = entries * 2 * REFERENCE_BYTES; - long entryBytes = entries * (MAP_ENTRY_BYTES + 3L * REFERENCE_BYTES); - reserveContainerMemory(MAP_OBJECT_BYTES + tableBytes + entryBytes); - } - - public void reserveObjectArrayMemory(int numElements) { - reserveContainerMemory(ARRAY_HEADER_BYTES + (long) numElements * REFERENCE_BYTES); - } - public void reserveContainerMemory(long bytes) { if (bytes < 0) { throwNegativeContainerMemory(bytes); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java index 52237e7082..148930ee3b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java @@ -45,6 +45,8 @@ * object-array paths avoid adapter allocation. */ public final class ArraySerializers { + private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + private ArraySerializers() {} private static void throwInvalidObjectArraySize(int size) { @@ -59,7 +61,7 @@ private static int readObjectArraySize(ReadContext readContext) { if (numElements < 0) { throwInvalidObjectArraySize(numElements); } - readContext.reserveObjectArrayMemory(numElements); + readContext.reserveContainerMemory((long) numElements * REFERENCE_BYTES); buffer.checkReadableBytes(numElements); return numElements; } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java index b5853d433b..38ed3a99d8 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java @@ -59,6 +59,8 @@ import org.apache.fory.type.Types; final class CompatibleCollectionArrayReader { + private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + static final int READ_LIST_TO_ARRAY = 1; static final int READ_ARRAY_TO_LIST = 2; static final int READ_LIST_TO_LIST = 3; @@ -979,7 +981,7 @@ private static List readNullableListBoxedElements( ReadContext readContext, int numElements, int arrayTypeId, int elementTypeId) { MemoryBuffer buffer = readContext.getBuffer(); int bodyBytes = minReadablePrimitiveListBytes(numElements, elementTypeId, true); - readContext.reserveCollectionMemory(numElements); + readContext.reserveContainerMemory((long) numElements * REFERENCE_BYTES); buffer.checkReadableBytes(bodyBytes); ArrayList values = new ArrayList<>(numElements); for (int i = 0; i < numElements; i++) { @@ -1179,7 +1181,7 @@ private static boolean canMaterializePrimitiveListTarget(Class targetType, in private static List materializeBoxedList( ReadContext readContext, Object array, int arrayTypeId) { int size = java.lang.reflect.Array.getLength(array); - readContext.reserveCollectionMemory(size); + readContext.reserveContainerMemory((long) size * REFERENCE_BYTES); ArrayList list = new ArrayList<>(size); switch (arrayTypeId) { case Types.BOOL_ARRAY: diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java index b6151f45a9..bc71691597 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java @@ -46,6 +46,8 @@ */ @SuppressWarnings({"unchecked", "rawtypes"}) public abstract class CollectionLikeSerializer extends Serializer { + private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + private MethodHandle constructor; private int numElements; protected final Config config; @@ -564,7 +566,7 @@ protected final int readCollectionSize(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); int numElements = buffer.readVarUInt32Small7(); checkCollectionSize(numElements); - readContext.reserveCollectionMemory(numElements); + readContext.reserveContainerMemory((long) numElements * REFERENCE_BYTES); buffer.checkReadableBytes(numElements); return numElements; } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java index 2456377485..8c3850e88b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java @@ -74,6 +74,8 @@ */ @SuppressWarnings({"unchecked", "rawtypes"}) public class CollectionSerializers { + private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + private static final Comparator NATURAL_ORDER_COMPARATOR = Comparator.naturalOrder(); private static void requireXlangNaturalOrdering(Class type, Comparator comparator) { @@ -927,7 +929,7 @@ public ArrayBlockingQueue newCollection(ReadContext readContext) { setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); checkBoundedQueueCapacity(numElements, capacity); - readContext.reserveCollectionCapacity(numElements, capacity); + readContext.reserveContainerMemory((long) (capacity - numElements) * REFERENCE_BYTES); buffer.checkReadableBytes(capacity); ArrayBlockingQueue queue = new ArrayBlockingQueue<>(capacity); readContext.reference(queue); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java index 9ac93e96a2..adf772c257 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java @@ -58,6 +58,7 @@ @SuppressWarnings({"unchecked", "rawtypes"}) public abstract class MapLikeSerializer extends Serializer { public static final int MAX_CHUNK_SIZE = 255; + private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); static final class MapTypeCache { final TypeInfoHolder keyTypeInfoWriteCache; @@ -971,7 +972,7 @@ protected final int readMapSize(ReadContext readContext) { if (numElements > Integer.MAX_VALUE / 2) { throwInvalidMapBodySize(numElements); } - readContext.reserveMapMemory(numElements); + readContext.reserveContainerMemory((long) numElements * 2 * REFERENCE_BYTES); buffer.checkReadableBytes(numElements << 1); return numElements; } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java index 09b73c25d8..2ed3fca56d 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java @@ -23,7 +23,6 @@ import static org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; -import java.io.ByteArrayInputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -34,7 +33,6 @@ import org.apache.fory.context.ReadContext; import org.apache.fory.exception.DeserializationException; import org.apache.fory.exception.InsecureException; -import org.apache.fory.io.ForyInputStream; import org.apache.fory.memory.MemoryBuffer; import org.testng.annotations.Test; @@ -42,10 +40,6 @@ public class ContainerMemoryBudgetTest extends ForyTestBase { private static final long KNOWN_ROOT_MULTIPLIER = 8L; private static final long KNOWN_ROOT_SLACK_BYTES = 64L * 1024; private static final long STREAM_ROOT_BYTES = 128L * 1024 * 1024; - private static final long COLLECTION_OBJECT_BYTES = 24L; - private static final long MAP_OBJECT_BYTES = 48L; - private static final long ARRAY_HEADER_BYTES = 16L; - private static final long MAP_ENTRY_BYTES = 32L; private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); @Test @@ -79,12 +73,6 @@ public void testStreamAutoBudget() { } finally { readContext.reset(); } - - StreamPayload payload = findStreamPayload(); - assertThrows(InsecureException.class, () -> newFory(-1).deserialize(payload.bytes)); - Object copy = - newFory(-1).deserialize(new ForyInputStream(new ByteArrayInputStream(payload.bytes), 1)); - assertEquals(copy, payload.value); } @Test @@ -100,12 +88,13 @@ public void testExplicitBudgetWins() { } @Test - public void testNestedEmptyFixedCost() { + public void testNestedEmptyContainersUseParentStorage() { List value = emptyLists(1); byte[] bytes = newFory(-1).serialize(value); + long required = collectionBytes(1); - assertThrows(InsecureException.class, () -> newFory(collectionBytes(1)).deserialize(bytes)); - assertEquals(newFory(collectionBytes(1) + collectionBytes(0)).deserialize(bytes), value); + assertThrows(InsecureException.class, () -> newFory(required - 1).deserialize(bytes)); + assertEquals(newFory(required).deserialize(bytes), value); } @Test @@ -123,7 +112,7 @@ public void testMapBudgetAndOverflow() { Fory fory = newFory(mapBytes(1) - 1); ReadContext readContext = prepareContext(fory, 8, false); try { - assertThrows(InsecureException.class, () -> readContext.reserveMapMemory(1)); + assertThrows(InsecureException.class, () -> readContext.reserveContainerMemory(mapBytes(1))); } finally { readContext.reset(); } @@ -131,7 +120,7 @@ public void testMapBudgetAndOverflow() { Fory exactFory = newFory(mapBytes(1)); ReadContext exactContext = prepareContext(exactFory, 8, false); try { - exactContext.reserveMapMemory(1); + exactContext.reserveContainerMemory(mapBytes(1)); assertThrows(InsecureException.class, () -> exactContext.reserveContainerMemory(1)); } finally { exactContext.reset(); @@ -154,18 +143,7 @@ public void testMapBudgetAndOverflow() { @Test public void testObjectArrayBudget() { - Fory lowFory = newFory(objectArrayBytes(0) - 1); - ReadContext lowContext = lowFory.getReadContext(); - MemoryBuffer lowBuffer = objectArraySizeBuffer(0); - lowContext.prepare(lowBuffer, null, false, lowBuffer.remaining(), false); - try { - assertThrows( - InsecureException.class, () -> lowFory.getSerializer(Object[].class).read(lowContext)); - } finally { - lowContext.reset(); - } - - Fory exactFory = newFory(objectArrayBytes(0)); + Fory exactFory = newFory(1); ReadContext exactContext = exactFory.getReadContext(); MemoryBuffer exactBuffer = objectArraySizeBuffer(0); exactContext.prepare(exactBuffer, null, false, exactBuffer.remaining(), false); @@ -235,18 +213,15 @@ private static ReadContext prepareContext( } private static long collectionBytes(int numElements) { - return COLLECTION_OBJECT_BYTES + (long) numElements * REFERENCE_BYTES; + return (long) numElements * REFERENCE_BYTES; } private static long mapBytes(int numElements) { - long entries = numElements; - return MAP_OBJECT_BYTES - + entries * 2 * REFERENCE_BYTES - + entries * (MAP_ENTRY_BYTES + 3L * REFERENCE_BYTES); + return (long) numElements * 2 * REFERENCE_BYTES; } private static long objectArrayBytes(int numElements) { - return ARRAY_HEADER_BYTES + (long) numElements * REFERENCE_BYTES; + return (long) numElements * REFERENCE_BYTES; } private static long knownAutoBytes(int inputBytes) { @@ -273,14 +248,6 @@ private static List nullLists(int siblings, int childElements) { return root; } - private static List emptyMaps(int numElements) { - List root = new ArrayList<>(numElements); - for (int i = 0; i < numElements; i++) { - root.add(new HashMap<>()); - } - return root; - } - private static MemoryBuffer objectArraySizeBuffer(int numElements) { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(8); buffer.writeVarUInt32Small7(numElements); @@ -291,28 +258,4 @@ private static MemoryBuffer trimBuffer(MemoryBuffer buffer) { return MemoryBuffer.fromByteArray(buffer.getBytes(0, buffer.writerIndex())); } - private static StreamPayload findStreamPayload() { - Fory writer = newFory(-1); - int numElements = 128; - while (numElements <= 1 << 20) { - List value = emptyMaps(numElements); - byte[] bytes = writer.serialize(value); - long estimatedMemory = collectionBytes(numElements) + (long) numElements * mapBytes(0); - if (estimatedMemory > knownAutoBytes(bytes.length) && estimatedMemory < STREAM_ROOT_BYTES) { - return new StreamPayload(value, bytes); - } - numElements <<= 1; - } - throw new AssertionError("Unable to build compact stream-budget payload"); - } - - private static final class StreamPayload { - final List value; - final byte[] bytes; - - StreamPayload(List value, byte[] bytes) { - this.value = value; - this.bytes = bytes; - } - } } diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index f95c38d72f..dab5a3766e 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -533,11 +533,6 @@ export class ReadContext { private static readonly MIN_REMOTE_TYPE_META_LIMIT = 8192; private static readonly KNOWN_ROOT_BUDGET_MULTIPLIER = 8; private static readonly KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024; - private static readonly COLLECTION_OBJECT_BYTES = 24; - private static readonly MAP_OBJECT_BYTES = 48; - private static readonly ARRAY_HEADER_BYTES = 16; - private static readonly MAP_ENTRY_BYTES = 32; - private static readonly REFERENCE_BYTES = 4; readonly reader: BinaryReader; readonly refReader: RefReader; @@ -585,41 +580,6 @@ export class ReadContext { this.remainingContainerMemoryBytes = this.effectiveContainerMemoryBytes; } - reserveCollectionMemory(numElements: number) { - const bytes - = ReadContext.COLLECTION_OBJECT_BYTES - + numElements * ReadContext.REFERENCE_BYTES; - const remaining = this.remainingContainerMemoryBytes - bytes; - if (remaining < 0) { - this.throwContainerBudgetExceeded(bytes); - } - this.remainingContainerMemoryBytes = remaining; - } - - reserveMapMemory(numElements: number) { - const bytes = ReadContext.MAP_OBJECT_BYTES - + numElements - * ( - ReadContext.REFERENCE_BYTES * 2 - + ReadContext.MAP_ENTRY_BYTES - + ReadContext.REFERENCE_BYTES * 3 - ); - const remaining = this.remainingContainerMemoryBytes - bytes; - if (remaining < 0) { - this.throwContainerBudgetExceeded(bytes); - } - this.remainingContainerMemoryBytes = remaining; - } - - reserveTypedArrayMemory(numElements: number, elementBytes: number) { - const bytes = ReadContext.ARRAY_HEADER_BYTES + numElements * elementBytes; - const remaining = this.remainingContainerMemoryBytes - bytes; - if (remaining < 0) { - this.throwContainerBudgetExceeded(bytes); - } - this.remainingContainerMemoryBytes = remaining; - } - reserveContainerMemory(bytes: number) { if (!Number.isSafeInteger(bytes) || bytes < 0) { this.throwContainerMemoryOverflow(bytes); diff --git a/javascript/packages/core/lib/gen/collection.ts b/javascript/packages/core/lib/gen/collection.ts index c551dffef4..f2c26be573 100644 --- a/javascript/packages/core/lib/gen/collection.ts +++ b/javascript/packages/core/lib/gen/collection.ts @@ -26,6 +26,8 @@ import { Scope } from "./scope"; import { AnyHelper } from "./any"; import type { ReadContext, WriteContext } from "../context"; +const REFERENCE_BYTES = 4; + export type CompatibleCollectionArrayReadAction = { target: "array" | "list"; elementTypeId: number; @@ -258,7 +260,7 @@ class CollectionAnySerializer { ): any { void fromRef; const len = this.readContext.reader.readVarUint32Small7(); - this.readContext.reserveCollectionMemory(len); + this.readContext.reserveContainerMemory(len * REFERENCE_BYTES); if (len === 0) { return createCollection(len); } @@ -445,8 +447,8 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera ? compatibleArrayCollectionExpr(compatibleReadAction!.elementTypeId, len) : this.newCollection(len); const reserveMemory = compatibleListToArray - ? `${readContextName}.reserveTypedArrayMemory(${len}, ${compatibleArrayElementBytes(compatibleReadAction!.elementTypeId)});` - : `${readContextName}.reserveCollectionMemory(${len});`; + ? `${readContextName}.reserveContainerMemory(${len} * ${compatibleArrayElementBytes(compatibleReadAction!.elementTypeId)});` + : `${readContextName}.reserveContainerMemory(${len} * ${REFERENCE_BYTES});`; const putAccessor = (item: string, index: string) => compatibleListToArray ? compatibleArrayPutAccessor(compatibleReadAction!.elementTypeId, result, item, index) diff --git a/javascript/packages/core/lib/gen/map.ts b/javascript/packages/core/lib/gen/map.ts index ebb5f3b588..020b3d51c1 100644 --- a/javascript/packages/core/lib/gen/map.ts +++ b/javascript/packages/core/lib/gen/map.ts @@ -26,6 +26,8 @@ import { Scope } from "./scope"; import { AnyHelper } from "./any"; import { ReadContext, WriteContext } from "../context"; +const REFERENCE_BYTES = 4; + const MapFlags = { /** Whether track elements ref. */ TRACKING_REF: 0b1, @@ -272,7 +274,7 @@ class MapAnySerializer { read(fromRef: boolean): any { let count = this.readContext.reader.readVarUint32Small7(); - this.readContext.reserveMapMemory(count); + this.readContext.reserveContainerMemory(count * 2 * REFERENCE_BYTES); const result = new Map(); if (fromRef) { this.readContext.reference(result); @@ -492,7 +494,7 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { return ` let ${count} = ${this.builder.reader.readVarUint32Small7()}; - ${readContextName}.reserveMapMemory(${count}); + ${readContextName}.reserveContainerMemory(${count} * 2 * ${REFERENCE_BYTES}); const ${result} = new Map(); if (${refState}) { ${this.builder.referenceResolver.reference(result)} diff --git a/javascript/test/containerMemoryBudget.test.ts b/javascript/test/containerMemoryBudget.test.ts index 77907ea3e3..e79f8fc6f9 100644 --- a/javascript/test/containerMemoryBudget.test.ts +++ b/javascript/test/containerMemoryBudget.test.ts @@ -57,13 +57,14 @@ describe('container memory budget', () => { const fory = new Fory({ maxContainerMemoryBytes: 24 }); fory.readContext.reset(new Uint8Array(1)); - expect(() => fory.readContext.reserveCollectionMemory(0)).not.toThrow(); + expect(() => fory.readContext.reserveContainerMemory(0)).not.toThrow(); + expect(() => fory.readContext.reserveContainerMemory(24)).not.toThrow(); expect(() => fory.readContext.reserveContainerMemory(1)).toThrow( /maxContainerMemoryBytes/, ); }); - test('charges nested empty containers', () => { + test('uses parent storage for nested empty containers', () => { const typeInfo = Type.struct('budget.nested.empty', { values: Type.list(Type.list(Type.int32({ encoding: 'fixed' }))).setId(1), }); @@ -72,12 +73,12 @@ describe('container memory budget', () => { const passingReader = new Fory({ compatible: false, ref: true, - maxContainerMemoryBytes: 52, + maxContainerMemoryBytes: 4, }).register(typeInfo); const failingReader = new Fory({ compatible: false, ref: true, - maxContainerMemoryBytes: 51, + maxContainerMemoryBytes: 3, }).register(typeInfo); expect(() => failingReader.deserialize(bytes)).toThrow( @@ -97,12 +98,12 @@ describe('container memory budget', () => { const passingReader = new Fory({ compatible: false, ref: true, - maxContainerMemoryBytes: 108, + maxContainerMemoryBytes: 12, }).register(typeInfo); const failingReader = new Fory({ compatible: false, ref: true, - maxContainerMemoryBytes: 107, + maxContainerMemoryBytes: 11, }).register(typeInfo); expect(() => failingReader.deserialize(bytes)).toThrow( @@ -116,8 +117,8 @@ describe('container memory budget', () => { test('charges map entries', () => { const bytes = serializeAny(new Map([[1, 2]])); - expect(() => deserializeAny(bytes, 99)).toThrow(/maxContainerMemoryBytes/); - expect(deserializeAny(bytes, 100)).toEqual(new Map([[1, 2]])); + expect(() => deserializeAny(bytes, 7)).toThrow(/maxContainerMemoryBytes/); + expect(deserializeAny(bytes, 8)).toEqual(new Map([[1, 2]])); }); test('charges generated containers', () => { @@ -135,12 +136,12 @@ describe('container memory budget', () => { const passingReader = new Fory({ compatible: false, ref: true, - maxContainerMemoryBytes: 156, + maxContainerMemoryBytes: 16, }).register(typeInfo); const failingReader = new Fory({ compatible: false, ref: true, - maxContainerMemoryBytes: 155, + maxContainerMemoryBytes: 15, }).register(typeInfo); expect(() => failingReader.deserialize(bytes)).toThrow( @@ -164,11 +165,11 @@ describe('container memory budget', () => { const bytes = writer.register(writerType).serialize({ values: [1, 2, 3] }); const passingReader = new Fory({ compatible: true, - maxContainerMemoryBytes: 28, + maxContainerMemoryBytes: 12, }).register(readerType); const failingReader = new Fory({ compatible: true, - maxContainerMemoryBytes: 27, + maxContainerMemoryBytes: 11, }).register(readerType); expect(() => failingReader.deserialize(bytes)).toThrow( diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi index 6dd5c5c4dc..c4dd89a0b4 100644 --- a/python/pyfory/collection.pxi +++ b/python/pyfory/collection.pxi @@ -41,6 +41,7 @@ cdef int8_t NULL_KEY_VALUE_DECL_TYPE = KEY_HAS_NULL | VALUE_DECL_TYPE cdef int8_t NULL_KEY_VALUE_DECL_TYPE_TRACKING_REF = KEY_HAS_NULL | VALUE_DECL_TYPE | TRACKING_VALUE_REF cdef int8_t NULL_VALUE_KEY_DECL_TYPE = VALUE_HAS_NULL | KEY_DECL_TYPE cdef int8_t NULL_VALUE_KEY_DECL_TYPE_TRACKING_REF = VALUE_HAS_NULL | KEY_DECL_TYPE | TRACKING_KEY_REF +cdef int64_t _REFERENCE_BYTES = sizeof(PyObject*) ctypedef PyObject *PyObjectPtr cdef class ListSerializer @@ -467,23 +468,21 @@ cdef class ListSerializer(CollectionSerializer): cdef int32_t ref_id cdef int64_t i cdef int64_t container_bytes - + cdef int64_t remaining_container_memory_bytes if len_ == 0: - container_bytes = read_context.remaining_container_memory_bytes - _COLLECTION_OBJECT_BYTES - if container_bytes < 0: - read_context.reserve_container_memory_fast(_COLLECTION_OBJECT_BYTES) - else: - read_context.remaining_container_memory_bytes = container_bytes list_ = PyList_New(0) return list_ if len_ < 0: - read_context.reserve_collection_memory_c(len_) + raise ValueError("Container element count is negative") else: - container_bytes = _COLLECTION_OBJECT_BYTES + len_ * _REFERENCE_BYTES - if container_bytes > read_context.remaining_container_memory_bytes: + container_bytes = len_ * sizeof(PyObject*) + remaining_container_memory_bytes = read_context.remaining_container_memory_bytes + if container_bytes > remaining_container_memory_bytes: read_context.reserve_container_memory_fast(container_bytes) else: - read_context.remaining_container_memory_bytes -= container_bytes + read_context.remaining_container_memory_bytes = ( + remaining_container_memory_bytes - container_bytes + ) read_context.check_readable_bytes(len_) collect_flag = buffer.read_int8() @@ -598,23 +597,21 @@ cdef class TupleSerializer(CollectionSerializer): cdef int8_t head_flag cdef int64_t i cdef int64_t container_bytes - + cdef int64_t remaining_container_memory_bytes if len_ == 0: - container_bytes = read_context.remaining_container_memory_bytes - _COLLECTION_OBJECT_BYTES - if container_bytes < 0: - read_context.reserve_container_memory_fast(_COLLECTION_OBJECT_BYTES) - else: - read_context.remaining_container_memory_bytes = container_bytes tuple_ = PyTuple_New(0) return tuple_ if len_ < 0: - read_context.reserve_collection_memory_c(len_) + raise ValueError("Container element count is negative") else: - container_bytes = _COLLECTION_OBJECT_BYTES + len_ * _REFERENCE_BYTES - if container_bytes > read_context.remaining_container_memory_bytes: + container_bytes = len_ * sizeof(PyObject*) + remaining_container_memory_bytes = read_context.remaining_container_memory_bytes + if container_bytes > remaining_container_memory_bytes: read_context.reserve_container_memory_fast(container_bytes) else: - read_context.remaining_container_memory_bytes -= container_bytes + read_context.remaining_container_memory_bytes = ( + remaining_container_memory_bytes - container_bytes + ) read_context.check_readable_bytes(len_) collect_flag = buffer.read_int8() @@ -730,25 +727,24 @@ cdef class SetSerializer(CollectionSerializer): cdef int32_t ref_id cdef int64_t i cdef int64_t container_bytes + cdef int64_t remaining_container_memory_bytes len_ = buffer.read_var_uint32() if len_ == 0: - container_bytes = read_context.remaining_container_memory_bytes - _COLLECTION_OBJECT_BYTES - if container_bytes < 0: - read_context.reserve_container_memory_fast(_COLLECTION_OBJECT_BYTES) - else: - read_context.remaining_container_memory_bytes = container_bytes instance = set() read_context.reference(instance) return instance if len_ < 0: - read_context.reserve_collection_memory_c(len_) + raise ValueError("Container element count is negative") else: - container_bytes = _COLLECTION_OBJECT_BYTES + len_ * _REFERENCE_BYTES - if container_bytes > read_context.remaining_container_memory_bytes: + container_bytes = len_ * sizeof(PyObject*) + remaining_container_memory_bytes = read_context.remaining_container_memory_bytes + if container_bytes > remaining_container_memory_bytes: read_context.reserve_container_memory_fast(container_bytes) else: - read_context.remaining_container_memory_bytes -= container_bytes + read_context.remaining_container_memory_bytes = ( + remaining_container_memory_bytes - container_bytes + ) read_context.check_readable_bytes(len_) instance = set() read_context.reference(instance) @@ -1095,22 +1091,20 @@ cdef class MapSerializer(Serializer): cdef dict map_ cdef int8_t chunk_header = 0 cdef int64_t container_bytes + cdef int64_t remaining_container_memory_bytes if size == 0: - container_bytes = read_context.remaining_container_memory_bytes - _MAP_OBJECT_BYTES - if container_bytes < 0: - read_context.reserve_container_memory_fast(_MAP_OBJECT_BYTES) - else: - read_context.remaining_container_memory_bytes = container_bytes map_ = {} elif size < 0: - read_context.reserve_map_memory_c(size) - map_ = {} + raise ValueError("Map entry count is negative") else: - container_bytes = _MAP_OBJECT_BYTES + size * (_MAP_ENTRY_BYTES + 5 * _REFERENCE_BYTES) - if container_bytes > read_context.remaining_container_memory_bytes: + container_bytes = size * (2 * sizeof(PyObject*)) + remaining_container_memory_bytes = read_context.remaining_container_memory_bytes + if container_bytes > remaining_container_memory_bytes: read_context.reserve_container_memory_fast(container_bytes) else: - read_context.remaining_container_memory_bytes -= container_bytes + read_context.remaining_container_memory_bytes = ( + remaining_container_memory_bytes - container_bytes + ) read_context.check_readable_bytes(size) chunk_header = read_context.read_uint8() map_ = _PyDict_NewPresized(size) diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py index c2e2e2a058..938c3663e8 100644 --- a/python/pyfory/collection.py +++ b/python/pyfory/collection.py @@ -23,6 +23,8 @@ fallback only. """ +import struct + from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION from pyfory._serializer import Serializer, StringSerializer from pyfory.resolver import NOT_NULL_VALUE_FLAG, NULL_FLAG @@ -33,6 +35,7 @@ COLL_HAS_NULL = 0b10 COLL_IS_DECL_ELEMENT_TYPE = 0b100 COLL_IS_SAME_TYPE = 0b1000 +_REFERENCE_BYTES = struct.calcsize("P") def _needs_element_type_info(type_id): @@ -176,7 +179,7 @@ def _write_different_types(self, write_context, value, collect_flag=0): def read(self, read_context): length = read_context.read_var_uint32() - read_context.reserve_collection_memory(length) + read_context.reserve_container_memory(length * _REFERENCE_BYTES) if length != 0: read_context.check_readable_bytes(length) collection_ = self.new_instance(read_context, self.type_) @@ -458,7 +461,7 @@ def write(self, write_context, obj): def read(self, read_context): size = read_context.read_var_uint32() - read_context.reserve_map_memory(size) + read_context.reserve_container_memory(size * 2 * _REFERENCE_BYTES) if size != 0: read_context.check_readable_bytes(size) map_ = {} diff --git a/python/pyfory/context.pxi b/python/pyfory/context.pxi index a27084d466..ed45820ec1 100644 --- a/python/pyfory/context.pxi +++ b/python/pyfory/context.pxi @@ -33,10 +33,6 @@ cdef int32_t MAX_CACHED_META_STRING_LENGTH = 2048 cdef int64_t _KNOWN_ROOT_BUDGET_MULTIPLIER = 8 cdef int64_t _KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 cdef int64_t _STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 -cdef int64_t _COLLECTION_OBJECT_BYTES = 56 -cdef int64_t _MAP_OBJECT_BYTES = 64 -cdef int64_t _MAP_ENTRY_BYTES = 32 -cdef int64_t _REFERENCE_BYTES = sizeof(PyObject*) cdef int64_t _MAX_CONTAINER_MEMORY_BYTES = 9223372036854775807 @@ -864,27 +860,19 @@ cdef class ReadContext: cpdef inline reserve_container_memory(self, int64_t num_bytes): self.reserve_container_memory_c(num_bytes) - cdef inline void reserve_collection_memory_c(self, int64_t num_elements): - if num_elements < 0: - raise ValueError("Container element count is negative") - if num_elements > (_MAX_CONTAINER_MEMORY_BYTES - _COLLECTION_OBJECT_BYTES) // _REFERENCE_BYTES: - raise ValueError("Estimated container memory overflow") - self.reserve_container_memory_c(_COLLECTION_OBJECT_BYTES + num_elements * _REFERENCE_BYTES) - - cpdef inline reserve_collection_memory(self, int64_t num_elements): - self.reserve_collection_memory_c(num_elements) - - cdef inline void reserve_map_memory_c(self, int64_t num_elements): - cdef int64_t bytes_per_entry - if num_elements < 0: - raise ValueError("Map entry count is negative") - bytes_per_entry = _MAP_ENTRY_BYTES + 5 * _REFERENCE_BYTES - if num_elements > (_MAX_CONTAINER_MEMORY_BYTES - _MAP_OBJECT_BYTES) // bytes_per_entry: + cdef inline void reserve_counted_container_memory_c( + self, + int64_t count, + int64_t element_bytes, + ): + if count < 0 or element_bytes < 0: + raise ValueError("Estimated container memory is negative") + if element_bytes != 0 and count > _MAX_CONTAINER_MEMORY_BYTES // element_bytes: raise ValueError("Estimated container memory overflow") - self.reserve_container_memory_c(_MAP_OBJECT_BYTES + num_elements * bytes_per_entry) + self.reserve_container_memory_c(count * element_bytes) - cpdef inline reserve_map_memory(self, int64_t num_elements): - self.reserve_map_memory_c(num_elements) + cpdef inline reserve_counted_container_memory(self, int64_t count, int64_t element_bytes): + self.reserve_counted_container_memory_c(count, element_bytes) cpdef inline add_context_object(self, key, obj): self.context_objects[id(key)] = obj diff --git a/python/pyfory/context.py b/python/pyfory/context.py index 8b620629da..f53384e31c 100644 --- a/python/pyfory/context.py +++ b/python/pyfory/context.py @@ -17,8 +17,6 @@ from __future__ import annotations -import struct - from pyfory.serialization import Config from pyfory.lib import mmh3 from pyfory.meta.metastring import Encoding @@ -42,10 +40,6 @@ _KNOWN_ROOT_BUDGET_MULTIPLIER = 8 _KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 _STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 -_COLLECTION_OBJECT_BYTES = 56 -_MAP_OBJECT_BYTES = 64 -_MAP_ENTRY_BYTES = 32 -_REFERENCE_BYTES = struct.calcsize("P") _MAX_CONTAINER_MEMORY_BYTES = (1 << 63) - 1 @@ -587,20 +581,12 @@ def reserve_container_memory(self, num_bytes): ) self.remaining_container_memory_bytes = remaining - num_bytes - def reserve_collection_memory(self, num_elements): - if num_elements < 0: - raise ValueError("Container element count is negative") - if num_elements > (_MAX_CONTAINER_MEMORY_BYTES - _COLLECTION_OBJECT_BYTES) // _REFERENCE_BYTES: - raise ValueError("Estimated container memory overflow") - self.reserve_container_memory(_COLLECTION_OBJECT_BYTES + num_elements * _REFERENCE_BYTES) - - def reserve_map_memory(self, num_elements): - if num_elements < 0: - raise ValueError("Map entry count is negative") - bytes_per_entry = _MAP_ENTRY_BYTES + 5 * _REFERENCE_BYTES - if num_elements > (_MAX_CONTAINER_MEMORY_BYTES - _MAP_OBJECT_BYTES) // bytes_per_entry: + def reserve_counted_container_memory(self, count, element_bytes): + if count < 0 or element_bytes < 0: + raise ValueError("Estimated container memory is negative") + if element_bytes and count > _MAX_CONTAINER_MEMORY_BYTES // element_bytes: raise ValueError("Estimated container memory overflow") - self.reserve_container_memory(_MAP_OBJECT_BYTES + num_elements * bytes_per_entry) + self.reserve_container_memory(count * element_bytes) def add_context_object(self, key, obj): self.context_objects[id(key)] = obj diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index 8ed4aa2255..17ce24063f 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -24,6 +24,7 @@ import marshal import os import pickle +import struct import types from typing import Tuple @@ -42,6 +43,7 @@ ) _WINDOWS = os.name == "nt" +_REFERENCE_BYTES = struct.calcsize("P") from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION from pyfory.types import TypeId @@ -933,7 +935,7 @@ def read(self, read_context): if dtype.kind == "O": length = read_context.read_varint32() _check_non_negative_size(length, "ndarray object") - read_context.reserve_collection_memory(length) + read_context.reserve_container_memory(length * _REFERENCE_BYTES) read_context.check_readable_bytes(length) items = [read_context.read_ref() for _ in range(length)] return np.array(items, dtype=object) diff --git a/python/pyfory/tests/test_container_memory_budget.py b/python/pyfory/tests/test_container_memory_budget.py index 09069d412b..ae2bc15288 100644 --- a/python/pyfory/tests/test_container_memory_budget.py +++ b/python/pyfory/tests/test_container_memory_budget.py @@ -33,9 +33,6 @@ KNOWN_ROOT_BUDGET_MULTIPLIER = 8 KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 -COLLECTION_OBJECT_BYTES = 56 -MAP_OBJECT_BYTES = 64 -MAP_ENTRY_BYTES = 32 REFERENCE_BYTES = struct.calcsize("P") MAX_CONTAINER_MEMORY_BYTES = (1 << 63) - 1 @@ -85,11 +82,11 @@ def recv_into(self, buffer, size=-1): def collection_memory(num_elements): - return COLLECTION_OBJECT_BYTES + num_elements * REFERENCE_BYTES + return num_elements * REFERENCE_BYTES def map_memory(num_entries): - return MAP_OBJECT_BYTES + num_entries * (MAP_ENTRY_BYTES + 5 * REFERENCE_BYTES) + return num_entries * 2 * REFERENCE_BYTES def new_fory(limit=-1, *, xlang=True): @@ -146,15 +143,15 @@ def test_explicit_config_overrides_auto(): assert expect_budget(value, budget) == value -def test_nested_empty_containers_charge_fixed_cost(): +def test_nested_empty_containers_use_parent_storage(): value = [[]] - budget = collection_memory(1) + collection_memory(0) + budget = collection_memory(1) assert expect_budget(value, budget) == value def test_sibling_nested_containers_are_cumulative(): value = [[], [], []] - budget = collection_memory(3) + 3 * collection_memory(0) + budget = collection_memory(3) assert expect_budget(value, budget) == value @@ -165,9 +162,9 @@ def test_map_entry_budget_and_overflow(): fory = new_fory(xlang=False) try: fory.read_context.prepare(Buffer(b""), root_input_bytes=0) - max_map_entries = (MAX_CONTAINER_MEMORY_BYTES - MAP_OBJECT_BYTES) // (MAP_ENTRY_BYTES + 5 * REFERENCE_BYTES) + max_map_entries = MAX_CONTAINER_MEMORY_BYTES // (2 * REFERENCE_BYTES) with pytest.raises(ValueError, match="Estimated container memory overflow"): - fory.read_context.reserve_map_memory(max_map_entries + 1) + fory.read_context.reserve_counted_container_memory(max_map_entries + 1, 2 * REFERENCE_BYTES) finally: fory.reset_read() diff --git a/rust/fory-core/src/context.rs b/rust/fory-core/src/context.rs index f36e150d2f..aabfc6c8a3 100644 --- a/rust/fory-core/src/context.rs +++ b/rust/fory-core/src/context.rs @@ -33,9 +33,6 @@ use std::rc::Rc; const KNOWN_ROOT_BUDGET_MULTIPLIER: usize = 8; const KNOWN_ROOT_BUDGET_SLACK_BYTES: usize = 64 * 1024; -const VEC_OBJECT_BYTES: usize = mem::size_of::>(); -const MAP_ENTRY_OVERHEAD_BYTES: usize = 16; -const REFERENCE_SLOT_BYTES: usize = mem::size_of::(); const MAX_CONTAINER_LEN: usize = u32::MAX as usize; /// Thread-local context cache with fast path for single Fory instance. @@ -481,35 +478,13 @@ impl<'a> ReadContext<'a> { } #[inline(always)] - pub(crate) fn reserve_vec_memory(&mut self, len: u32) -> Result { - let len = len as usize; - self.reserve_counted_memory(len, VEC_OBJECT_BYTES, mem::size_of::())?; - Ok(len) - } - - #[inline(always)] - pub(crate) fn reserve_collection_memory(&mut self, len: u32) -> Result { - let len = len as usize; - let elem_size = mem::size_of::(); - if elem_size > usize::MAX - REFERENCE_SLOT_BYTES { - return Err(container_memory_overflow(len, elem_size)); - } - let elem_bytes = elem_size + REFERENCE_SLOT_BYTES; - self.reserve_counted_memory(len, mem::size_of::(), elem_bytes)?; - Ok(len) - } - - #[inline(always)] - pub(crate) fn reserve_map_memory(&mut self, len: u32) -> Result { + pub(crate) fn reserve_counted_container_memory( + &mut self, + len: u32, + elem_bytes: usize, + ) -> Result { let len = len as usize; - let key_size = mem::size_of::(); - let value_size = mem::size_of::(); - let overhead = MAP_ENTRY_OVERHEAD_BYTES + REFERENCE_SLOT_BYTES * 3; - if key_size > usize::MAX - value_size || key_size + value_size > usize::MAX - overhead { - return Err(container_memory_overflow(len, key_size)); - } - let elem_bytes = key_size + value_size + overhead; - self.reserve_counted_memory(len, mem::size_of::(), elem_bytes)?; + self.reserve_counted_memory(len, elem_bytes)?; Ok(len) } @@ -528,19 +503,14 @@ impl<'a> ReadContext<'a> { } #[inline(always)] - fn reserve_counted_memory( - &mut self, - len: usize, - fixed_bytes: usize, - elem_bytes: usize, - ) -> Result<(), Error> { + fn reserve_counted_memory(&mut self, len: usize, elem_bytes: usize) -> Result<(), Error> { if len == 0 { - return self.reserve_container_bytes(fixed_bytes); + return Ok(()); } - if elem_bytes <= (usize::MAX - fixed_bytes) / MAX_CONTAINER_LEN { - return self.reserve_container_bytes(len * elem_bytes + fixed_bytes); + if elem_bytes <= usize::MAX / MAX_CONTAINER_LEN { + return self.reserve_container_bytes(len * elem_bytes); } - self.reserve_counted_memory_checked(len, fixed_bytes, elem_bytes) + self.reserve_counted_memory_checked(len, elem_bytes) } #[cold] @@ -548,14 +518,9 @@ impl<'a> ReadContext<'a> { fn reserve_counted_memory_checked( &mut self, len: usize, - fixed_bytes: usize, elem_bytes: usize, ) -> Result<(), Error> { - let elem_total = match len.checked_mul(elem_bytes) { - Some(bytes) => bytes, - None => return Err(container_memory_overflow(len, elem_bytes)), - }; - let bytes = match elem_total.checked_add(fixed_bytes) { + let bytes = match len.checked_mul(elem_bytes) { Some(bytes) => bytes, None => return Err(container_memory_overflow(len, elem_bytes)), }; diff --git a/rust/fory-core/src/serializer/codec.rs b/rust/fory-core/src/serializer/codec.rs index 675fc133e9..f34c5b8fd9 100644 --- a/rust/fory-core/src/serializer/codec.rs +++ b/rust/fory-core/src/serializer/codec.rs @@ -1700,7 +1700,7 @@ where fn read_data(context: &mut ReadContext) -> Result, Error> { let len = context.reader.read_var_u32()?; - context.reserve_vec_memory::(len)?; + context.reserve_counted_container_memory(len, std::mem::size_of::())?; if len == 0 { return Ok(Vec::new()); } @@ -1729,7 +1729,7 @@ where remote_field_type: &FieldType, ) -> Result, Error> { let len = context.reader.read_var_u32()?; - context.reserve_vec_memory::(len)?; + context.reserve_counted_container_memory(len, std::mem::size_of::())?; if len == 0 { return Ok(Vec::new()); } @@ -2272,7 +2272,10 @@ where fn read_data(context: &mut ReadContext) -> Result, Error> { let len = context.reader.read_var_u32()?; - context.reserve_map_memory::, K, V>(len)?; + let elem_bytes = std::mem::size_of::() + .checked_add(std::mem::size_of::()) + .ok_or_else(|| Error::invalid_data("container memory estimate overflows"))?; + context.reserve_counted_container_memory(len, elem_bytes)?; if len == 0 { return Ok(HashMap::new()); } @@ -2292,7 +2295,10 @@ where remote_field_type: &FieldType, ) -> Result, Error> { let len = context.reader.read_var_u32()?; - let capacity = context.reserve_map_memory::, K, V>(len)?; + let elem_bytes = std::mem::size_of::() + .checked_add(std::mem::size_of::()) + .ok_or_else(|| Error::invalid_data("container memory estimate overflows"))?; + let capacity = context.reserve_counted_container_memory(len, elem_bytes)?; if len == 0 { return Ok(HashMap::new()); } diff --git a/rust/fory-core/src/serializer/collection.rs b/rust/fory-core/src/serializer/collection.rs index b2dd1950f9..cdda3506bc 100644 --- a/rust/fory-core/src/serializer/collection.rs +++ b/rust/fory-core/src/serializer/collection.rs @@ -239,7 +239,7 @@ where C: FromIterator, { let len = context.reader.read_var_u32()?; - let len_usize = context.reserve_collection_memory::(len)?; + let len_usize = context.reserve_counted_container_memory(len, std::mem::size_of::())?; if len == 0 { return Ok(C::from_iter(std::iter::empty())); } @@ -282,7 +282,7 @@ where T: Serializer + ForyDefault, { let len = context.reader.read_var_u32()?; - let len_usize = context.reserve_vec_memory::(len)?; + let len_usize = context.reserve_counted_container_memory(len, std::mem::size_of::())?; if len == 0 { return Ok(Vec::new()); } @@ -729,7 +729,7 @@ where { let element_type = generic_field_type(remote_field_type, 0, "list")?; let len = context.reader.read_var_u32()?; - let len_usize = context.reserve_vec_memory::(len)?; + let len_usize = context.reserve_counted_container_memory(len, std::mem::size_of::())?; if len == 0 { return Ok(Vec::new()); } diff --git a/rust/fory-core/src/serializer/map.rs b/rust/fory-core/src/serializer/map.rs index 158e020edc..b732f348df 100644 --- a/rust/fory-core/src/serializer/map.rs +++ b/rust/fory-core/src/serializer/map.rs @@ -553,7 +553,10 @@ impl Result { let len = context.reader.read_var_u32()?; - let capacity = context.reserve_map_memory::, K, V>(len)?; + let elem_bytes = std::mem::size_of::() + .checked_add(std::mem::size_of::()) + .ok_or_else(|| Error::invalid_data("container memory estimate overflows"))?; + let capacity = context.reserve_counted_container_memory(len, elem_bytes)?; if len == 0 { return Ok(HashMap::new()); } @@ -706,7 +709,10 @@ impl Result { let len = context.reader.read_var_u32()?; - let len_usize = context.reserve_map_memory::, K, V>(len)?; + let elem_bytes = std::mem::size_of::() + .checked_add(std::mem::size_of::()) + .ok_or_else(|| Error::invalid_data("container memory estimate overflows"))?; + let len_usize = context.reserve_counted_container_memory(len, elem_bytes)?; if len == 0 { return Ok(BTreeMap::new()); } diff --git a/rust/tests/tests/test_container_memory_budget.rs b/rust/tests/tests/test_container_memory_budget.rs index 29f70d10bf..8759ee1d82 100644 --- a/rust/tests/tests/test_container_memory_budget.rs +++ b/rust/tests/tests/test_container_memory_budget.rs @@ -93,7 +93,7 @@ fn config_validation() { #[test] fn known_auto_budget() { - let value = compact_empty_lists(3000); + let value = compact_empty_lists(12000); let writer = fory_with_budget(-1); let bytes = writer.serialize(&value).unwrap(); let auto_limit = bytes.len() * 8 + 64 * 1024; @@ -110,7 +110,7 @@ fn known_auto_budget() { #[test] fn reader_known_auto_budget() { - let value = compact_empty_lists(3000); + let value = compact_empty_lists(12000); let writer = fory_with_budget(-1); let bytes = writer.serialize(&value).unwrap(); let auto_limit = bytes.len() * 8 + 64 * 1024; @@ -124,52 +124,60 @@ fn reader_known_auto_budget() { #[test] fn explicit_override() { - let value = compact_empty_lists(3000); + let value = compact_empty_lists(12000); let writer = fory_with_budget(-1); let bytes = writer.serialize(&value).unwrap(); assert!(writer.deserialize::>>(&bytes).is_err()); let vec_bytes = std::mem::size_of::>(); - let estimate = std::mem::size_of::>>() + value.len() * vec_bytes * 2; + let estimate = value.len() * vec_bytes; let explicit = fory_with_budget(estimate as i64); let decoded: Vec> = explicit.deserialize(&bytes).unwrap(); assert_eq!(decoded, value); } #[test] -fn empty_container_cost() { +fn empty_container_has_no_dynamic_storage() { let value: Vec = Vec::new(); let writer = fory_with_budget(-1); let bytes = writer.serialize(&value).unwrap(); - let fixed = std::mem::size_of::>() as i64; - let limited = fory_with_budget(fixed - 1); - assert!(limited.deserialize::>(&bytes).is_err()); + let limited = fory_with_budget(1); + let decoded: Vec = limited.deserialize(&bytes).unwrap(); + assert!(decoded.is_empty()); } #[test] fn sibling_cumulative_budget() { let value = BudgetSiblings { - first: Vec::new(), - second: Vec::new(), + first: vec!["a".to_string()], + second: vec!["b".to_string()], }; let writer = fory_with_budget(-1); let bytes = writer.serialize(&value).unwrap(); - let one_vec = std::mem::size_of::>() as i64; + let one_vec = std::mem::size_of::() as i64; let limited = fory_with_budget(one_vec); assert!(limited.deserialize::(&bytes).is_err()); + let enough = fory_with_budget(one_vec * 2); + assert_eq!(enough.deserialize::(&bytes).unwrap(), value); } #[test] fn map_budget() { - let value: HashMap = HashMap::new(); + let value: HashMap = HashMap::from([("a".to_string(), 1)]); let writer = fory_with_budget(-1); let bytes = writer.serialize(&value).unwrap(); - let fixed = std::mem::size_of::>() as i64; + let required = (std::mem::size_of::() + std::mem::size_of::()) as i64; - let limited = fory_with_budget(fixed - 1); + let limited = fory_with_budget(required - 1); assert!(limited.deserialize::>(&bytes).is_err()); + assert_eq!( + fory_with_budget(required) + .deserialize::>(&bytes) + .unwrap(), + value + ); } #[test] @@ -182,8 +190,7 @@ fn inline_value_vec_budget() { .collect::>(); let writer = fory_with_budget(-1); let bytes = writer.serialize(&value).unwrap(); - let under_inline = - std::mem::size_of::>() + value.len() * std::mem::size_of::(); + let under_inline = value.len() * std::mem::size_of::(); let limited = fory_with_budget(under_inline as i64); assert!(limited.deserialize::>(&bytes).is_err()); @@ -197,7 +204,7 @@ fn compatible_list_array_budget() { let writer = compatible_fory::(-1); let bytes = writer.serialize(&value).unwrap(); - let limited = compatible_fory::(std::mem::size_of::>() as i64); + let limited = compatible_fory::((64 * std::mem::size_of::() - 1) as i64); assert!(limited.deserialize::(&bytes).is_err()); let enough = compatible_fory::(i64::MAX); diff --git a/swift/Sources/Fory/AnySerializer.swift b/swift/Sources/Fory/AnySerializer.swift index ca31cce2da..9a436ddce4 100644 --- a/swift/Sources/Fory/AnySerializer.swift +++ b/swift/Sources/Fory/AnySerializer.swift @@ -17,6 +17,18 @@ import Foundation +private let anyReferenceBytes = 4 + +@inline(__always) +private func reserveAnyReferenceArrayMemory(_ context: ReadContext, count: Int) throws { + try context.reserveCountedContainerMemory(count: count, elementBytes: anyReferenceBytes) +} + +@inline(__always) +private func reserveAnyReferenceMapMemory(_ context: ReadContext, count: Int) throws { + try context.reserveCountedContainerMemory(count: count, elementBytes: 2 * anyReferenceBytes) +} + public struct ForyAnyNullValue: Serializer { public init() {} @@ -573,7 +585,7 @@ public func readListOfAny( guard let wrapped else { return nil } - try context.reserveReferenceArrayMemory(count: wrapped.count) + try reserveAnyReferenceArrayMemory(context, count: wrapped.count) return wrapped.map { $0.anyValueForCollection() } } @@ -608,7 +620,7 @@ public func readMapStringToAny( guard let wrapped else { return nil } - try context.reserveReferenceMapMemory(count: wrapped.count) + try reserveAnyReferenceMapMemory(context, count: wrapped.count) var map: [String: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { @@ -648,7 +660,7 @@ public func readMapInt32ToAny( guard let wrapped else { return nil } - try context.reserveReferenceMapMemory(count: wrapped.count) + try reserveAnyReferenceMapMemory(context, count: wrapped.count) var map: [Int32: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { @@ -688,7 +700,7 @@ public func readMapAnyHashableToAny( guard let wrapped else { return nil } - try context.reserveReferenceMapMemory(count: wrapped.count) + try reserveAnyReferenceMapMemory(context, count: wrapped.count) var map: [AnyHashable: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { @@ -700,10 +712,10 @@ public func readMapAnyHashableToAny( func readDynamicAnyMapValue(context: ReadContext) throws -> Any { let map = try readMapAnyHashableToAny(context: context, refMode: .none) ?? [:] if map.isEmpty { - try context.reserveReferenceMapMemory(count: 0) + try reserveAnyReferenceMapMemory(context, count: 0) return [String: Any]() } - try context.reserveReferenceMapMemory(count: map.count) + try reserveAnyReferenceMapMemory(context, count: map.count) var stringMap: [String: Any] = [:] stringMap.reserveCapacity(map.count) for pair in map { @@ -717,7 +729,7 @@ func readDynamicAnyMapValue(context: ReadContext) throws -> Any { return stringMap } - try context.reserveReferenceMapMemory(count: map.count) + try reserveAnyReferenceMapMemory(context, count: map.count) var int32Map: [Int32: Any] = [:] int32Map.reserveCapacity(map.count) for pair in map { diff --git a/swift/Sources/Fory/CollectionSerializers.swift b/swift/Sources/Fory/CollectionSerializers.swift index a7b943a4b6..fca71441b4 100644 --- a/swift/Sources/Fory/CollectionSerializers.swift +++ b/swift/Sources/Fory/CollectionSerializers.swift @@ -34,6 +34,41 @@ enum MapHeader { static let declaredValueType: UInt8 = 0b0010_0000 } +private let containerReferenceBytes = 4 + +@inline(__always) +private func containerElementBytes(_ type: Element.Type) -> Int { + type.isRefType ? containerReferenceBytes : max(1, MemoryLayout.stride) +} + +@inline(__always) +private func reserveContainerArrayMemory( + _ context: ReadContext, + _ type: Element.Type, + count: Int +) throws { + try context.reserveCountedContainerMemory( + count: count, + elementBytes: containerElementBytes(type) + ) +} + +@inline(__always) +private func reserveContainerMapMemory( + _ context: ReadContext, + key: Key.Type, + value: Value.Type, + count: Int +) throws { + let keyBytes = containerElementBytes(key) + let valueBytes = containerElementBytes(value) + let (elementBytes, overflow) = keyBytes.addingReportingOverflow(valueBytes) + if overflow { + try context.reserveContainerMemory(-1) + } + try context.reserveCountedContainerMemory(count: count, elementBytes: elementBytes) +} + private func primitiveArrayTypeID(for _: Element.Type) -> TypeId? { if Element.self == UInt8.self { return .uint8Array } if Element.self == Bool.self { return .boolArray } @@ -244,7 +279,7 @@ private func preparePrimitiveArray( ) throws { try context.ensureCollectionLength(count, label: label) if chargeContainerMemory { - try context.reserveArrayMemory(type, count: count) + try reserveContainerArrayMemory(context, type, count: count) } } @@ -549,7 +584,7 @@ extension Array: Serializer where Element: Serializer { let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { - try context.reserveArrayMemory(Element.self, count: length) + try reserveContainerArrayMemory(context, Element.self, count: length) return [] } @@ -559,7 +594,7 @@ extension Array: Serializer where Element: Serializer { let declared = (header & CollectionHeader.declaredElementType) != 0 let sameType = (header & CollectionHeader.sameType) != 0 if !sameType { - try context.reserveArrayMemory(Element.self, count: length) + try reserveContainerArrayMemory(context, Element.self, count: length) try context.ensureRemainingBytes(length, label: "array") if trackRef { return try readArrayUninitialized(count: length) { destination in @@ -598,7 +633,7 @@ extension Array: Serializer where Element: Serializer { } let elementTypeInfo = declared ? nil : try Element.foryReadTypeInfo(context) - try context.reserveArrayMemory(Element.self, count: length) + try reserveContainerArrayMemory(context, Element.self, count: length) try context.ensureRemainingBytes(length, label: "array") return try context.withTypeInfo(elementTypeInfo, for: Element.self) { if trackRef { @@ -658,7 +693,7 @@ extension Set: Serializer where Element: Serializer & Hashable { public static func foryReadData(_ context: ReadContext) throws -> Set { let values = try [Element].foryReadData(context) - try context.reserveSetMemory(Element.self, count: values.count) + try reserveContainerArrayMemory(context, Element.self, count: values.count) return Set(values) } } @@ -886,11 +921,11 @@ extension Dictionary: Serializer where Key: Serializer & Hashable, Value: Serial let totalLength = Int(try context.buffer.readVarUInt32()) try context.ensureCollectionLength(totalLength, label: "map") if totalLength == 0 { - try context.reserveMapMemory(key: Key.self, value: Value.self, count: totalLength) + try reserveContainerMapMemory(context, key: Key.self, value: Value.self, count: totalLength) return [:] } - try context.reserveMapMemory(key: Key.self, value: Value.self, count: totalLength) + try reserveContainerMapMemory(context, key: Key.self, value: Value.self, count: totalLength) try context.ensureRemainingBytes(totalLength, label: "map") var map: [Key: Value] = [:] map.reserveCapacity(totalLength) diff --git a/swift/Sources/Fory/FieldCodecs.swift b/swift/Sources/Fory/FieldCodecs.swift index 4eabf460e7..73b7dd753e 100644 --- a/swift/Sources/Fory/FieldCodecs.swift +++ b/swift/Sources/Fory/FieldCodecs.swift @@ -17,6 +17,52 @@ import Foundation +private let fieldReferenceBytes = 4 + +@inline(__always) +private func fieldElementBytes(_ codec: ElementCodec.Type) -> Int { + codec.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) +} + +@inline(__always) +private func serializerElementBytes(_ type: Element.Type) -> Int { + type.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) +} + +@inline(__always) +private func chargeFieldArrayStorage( + _ context: ReadContext, + _ codec: ElementCodec.Type, + count: Int +) throws { + try context.reserveCountedContainerMemory(count: count, elementBytes: fieldElementBytes(codec)) +} + +@inline(__always) +private func reserveSerializerArrayMemory( + _ context: ReadContext, + _ type: Element.Type, + count: Int +) throws { + try context.reserveCountedContainerMemory(count: count, elementBytes: serializerElementBytes(type)) +} + +@inline(__always) +private func chargeFieldMapStorage( + _ context: ReadContext, + key: KeyCodec.Type, + value: ValueCodec.Type, + count: Int +) throws { + let keyBytes = fieldElementBytes(key) + let valueBytes = fieldElementBytes(value) + let (elementBytes, overflow) = keyBytes.addingReportingOverflow(valueBytes) + if overflow { + try context.reserveContainerMemory(-1) + } + try context.reserveCountedContainerMemory(count: count, elementBytes: elementBytes) +} + public protocol FieldCodec { associatedtype Value @@ -842,7 +888,7 @@ public enum SetFieldCodec: FieldCodec where ElementCod public static func readPayload(_ context: ReadContext) throws -> Value { let values = try readCollectionPayload(context, elementCodec: ElementCodec.self) - try context.reserveFieldSetMemory(ElementCodec.self, count: values.count) + try chargeFieldArrayStorage(context, ElementCodec.self, count: values.count) return Set(values) } } @@ -962,11 +1008,11 @@ where KeyCodec.Value: Hashable { let totalLength = Int(try context.buffer.readVarUInt32()) try context.ensureCollectionLength(totalLength, label: "map") if totalLength == 0 { - try context.reserveFieldMapMemory(key: KeyCodec.self, value: ValueCodec.self, count: totalLength) + try chargeFieldMapStorage(context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) return [:] } - try context.reserveFieldMapMemory(key: KeyCodec.self, value: ValueCodec.self, count: totalLength) + try chargeFieldMapStorage(context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) try context.ensureRemainingBytes(totalLength, label: "map") var map: Value = [:] map.reserveCapacity(totalLength) @@ -1331,7 +1377,7 @@ private func writeUIntArrayPayload(_ value: [UInt], _ context: WriteContext) { private func readIntArrayPayload(_ context: ReadContext, chargeContainerMemory: Bool = false) throws -> [Int] { let count = try readPackedArrayElementCount(context, width: 8, label: "int64_array") if chargeContainerMemory { - try context.reserveArrayMemory(Int.self, count: count) + try reserveSerializerArrayMemory(context, Int.self, count: count) } var values: [Int] = [] values.reserveCapacity(count) @@ -1344,7 +1390,7 @@ private func readIntArrayPayload(_ context: ReadContext, chargeContainerMemory: private func readUIntArrayPayload(_ context: ReadContext, chargeContainerMemory: Bool = false) throws -> [UInt] { let count = try readPackedArrayElementCount(context, width: 8, label: "uint64_array") if chargeContainerMemory { - try context.reserveArrayMemory(UInt.self, count: count) + try reserveSerializerArrayMemory(context, UInt.self, count: count) } var values: [UInt] = [] values.reserveCapacity(count) @@ -1602,7 +1648,7 @@ private func readCollectionPayload( let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { - try context.reserveFieldArrayMemory(ElementCodec.self, count: length) + try chargeFieldArrayStorage(context, ElementCodec.self, count: length) return [] } @@ -1617,7 +1663,7 @@ private func readCollectionPayload( let sameType = (header & CollectionHeader.sameType) != 0 var result: [ElementCodec.Value] = [] - try context.reserveFieldArrayMemory(ElementCodec.self, count: length) + try chargeFieldArrayStorage(context, ElementCodec.self, count: length) try context.ensureRemainingBytes(length, label: "array") result.reserveCapacity(length) @@ -1703,7 +1749,7 @@ private func readListPayloadAsArrayPayload( let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { - try context.reserveFieldArrayMemory(ElementCodec.self, count: length) + try chargeFieldArrayStorage(context, ElementCodec.self, count: length) return [] } @@ -1731,7 +1777,7 @@ private func readListPayloadAsArrayPayload( } try context.ensureRemainingBytes(length, label: "array") var result: [ElementCodec.Value] = [] - try context.reserveFieldArrayMemory(ElementCodec.self, count: length) + try chargeFieldArrayStorage(context, ElementCodec.self, count: length) result.reserveCapacity(length) return try ElementCodec.withTypeInfo(elementTypeInfo, context) { for _ in 0..(_ type: Element.Type, count: Int) throws { - try reserveArrayMemory(count: count, elementBytes: containerElementBytes(type)) - } - - @inline(__always) - func reserveFieldArrayMemory( - _ codec: ElementCodec.Type, - count: Int - ) throws { - try reserveArrayMemory(count: count, elementBytes: fieldElementBytes(codec)) - } - - @inline(__always) - func reserveReferenceArrayMemory(count: Int) throws { - try reserveArrayMemory(count: count, elementBytes: Self.referenceBytes) - } - - @inline(__always) - func reserveSetMemory(_ type: Element.Type, count: Int) throws { - try reserveSetMemory(count: count, elementBytes: containerElementBytes(type)) - } - - @inline(__always) - func reserveFieldSetMemory( - _ codec: ElementCodec.Type, - count: Int - ) throws { - try reserveSetMemory(count: count, elementBytes: fieldElementBytes(codec)) - } - - @inline(__always) - func reserveMapMemory( - key _: Key.Type, - value _: Value.Type, - count: Int - ) throws { - try reserveMapMemory( - count: count, - keyBytes: containerElementBytes(Key.self), - valueBytes: containerElementBytes(Value.self) - ) - } - - @inline(__always) - func reserveFieldMapMemory( - key _: KeyCodec.Type, - value _: ValueCodec.Type, - count: Int - ) throws { - try reserveMapMemory( - count: count, - keyBytes: fieldElementBytes(KeyCodec.self), - valueBytes: fieldElementBytes(ValueCodec.self) - ) - } - - @inline(__always) - func reserveReferenceMapMemory(count: Int) throws { - try reserveMapMemory(count: count, keyBytes: Self.referenceBytes, valueBytes: Self.referenceBytes) - } - - @inline(__always) - private func reserveArrayMemory(count: Int, elementBytes: Int) throws { - if count == 0 { - try reserveContainerMemory(Self.containerFixedBytes) - return - } - try reserveCountedContainerMemory( - count: count, - fixedBytes: Self.containerFixedBytes + Self.arrayHeaderBytes, - elementBytes: elementBytes - ) - } - - @inline(__always) - private func reserveSetMemory(count: Int, elementBytes: Int) throws { - if count == 0 { - try reserveContainerMemory(Self.containerFixedBytes) - return - } - try reserveCountedContainerMemory( - count: count, - fixedBytes: Self.containerFixedBytes + Self.arrayHeaderBytes, - elementBytes: elementBytes + Self.collectionEntryOverheadBytes + Self.referenceBytes * 2 - ) - } - - @inline(__always) - private func reserveMapMemory(count: Int, keyBytes: Int, valueBytes: Int) throws { - if count == 0 { - try reserveContainerMemory(Self.containerFixedBytes) - return + func reserveContainerMemory(_ bytes: Int) throws { + if bytes < 0 { + try throwContainerMemoryOverflow() } - try reserveCountedContainerMemory( - count: count, - fixedBytes: Self.containerFixedBytes + Self.arrayHeaderBytes * 2, - elementBytes: keyBytes + valueBytes + Self.mapEntryOverheadBytes + Self.referenceBytes - ) - } - - @inline(__always) - private func reserveContainerMemory(_ bytes: Int) throws { if bytes > remainingContainerMemoryBytes { try throwContainerMemoryExceeded(bytes: bytes) } @@ -185,25 +82,17 @@ public final class ReadContext { } @inline(__always) - private func reserveCountedContainerMemory( + func reserveCountedContainerMemory( count: Int, - fixedBytes: Int, elementBytes: Int ) throws { - if count > (Int.max - fixedBytes) / elementBytes { + if count < 0 || elementBytes < 0 { try throwContainerMemoryOverflow() } - try reserveContainerMemory(count * elementBytes + fixedBytes) - } - - @inline(__always) - private func containerElementBytes(_ type: Element.Type) -> Int { - type.isRefType ? Self.referenceBytes : max(1, MemoryLayout.stride) - } - - @inline(__always) - private func fieldElementBytes(_ codec: ElementCodec.Type) -> Int { - codec.isRefType ? Self.referenceBytes : max(1, MemoryLayout.stride) + if elementBytes != 0 && count > Int.max / elementBytes { + try throwContainerMemoryOverflow() + } + try reserveContainerMemory(count * elementBytes) } @inline(never) diff --git a/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift b/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift index 3650f55f48..18b9908157 100644 --- a/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift +++ b/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift @@ -56,16 +56,14 @@ private func makeBudgetFory(maxContainerMemoryBytes: Int64 = -1) -> Fory { return fory } +private let testReferenceBytes = 4 + private func elementBytes(_ type: Element.Type) -> Int { - type.isRefType ? ReadContext.referenceBytes : max(1, MemoryLayout.stride) + type.isRefType ? testReferenceBytes : max(1, MemoryLayout.stride) } private func arrayBudget(_ type: Element.Type, count: Int) -> Int { - if count == 0 { - return ReadContext.containerFixedBytes - } - return ReadContext.containerFixedBytes + ReadContext.arrayHeaderBytes + - count * elementBytes(type) + count * elementBytes(type) } private func mapBudget( @@ -73,14 +71,7 @@ private func mapBudget( value: Value.Type, count: Int ) -> Int { - if count == 0 { - return ReadContext.containerFixedBytes - } - return ReadContext.containerFixedBytes + ReadContext.arrayHeaderBytes * 2 + - count * ( - elementBytes(key) + elementBytes(value) + - ReadContext.mapEntryOverheadBytes + ReadContext.referenceBytes - ) + count * (elementBytes(key) + elementBytes(value)) } private func expectInvalidData(_ body: () throws -> Void) { @@ -94,33 +85,31 @@ private func expectInvalidData(_ body: () throws -> Void) { } @Test -func knownLengthAutoBudgetRejectsNestedEmptyArrays() throws { - let count = 6_000 - let value = Array(repeating: [String](), count: count) - let bytes = try makeBudgetFory().serialize(value) - let autoLimit = bytes.count * 8 + ReadContext.knownContainerBudgetSlackBytes - let required = arrayBudget([String].self, count: count) + - count * arrayBudget(String.self, count: 0) - #expect(required > autoLimit) +func knownLengthAutoBudgetUsesInputBytes() throws { + let expected = 17 * 8 + ReadContext.knownContainerBudgetSlackBytes + let config = Config(trackRef: false, compatible: false) + let context = ReadContext( + buffer: ByteBuffer(), + typeResolver: TypeResolver(config: config), + config: config + ) + try context.initContainerMemoryBudgetKnown(rootBytes: 17) + try context.reserveContainerMemory(expected) expectInvalidData { - let _: [[String]] = try makeBudgetFory().deserialize(bytes) + try context.reserveContainerMemory(testReferenceBytes) } - - let decoded: [[String]] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required)).deserialize(bytes) - #expect(decoded.count == count) } @Test func byteBufferRootUsesKnownLengthAutoBudget() throws { - let count = 6_000 + let count = 6 let value = Array(repeating: [String](), count: count) let bytes = try makeBudgetFory().serialize(value) let buffer = ByteBuffer(data: bytes) - expectInvalidData { - let _: [[String]] = try makeBudgetFory().deserialize(from: buffer) - } + let decoded: [[String]] = try makeBudgetFory().deserialize(from: buffer) + #expect(decoded.count == count) } @Test @@ -200,16 +189,11 @@ func stringBinaryAndPrimitiveDenseArrayOwnersAreSkipped() throws { } @Test -func dynamicAnyEmptyMapChargesFixedCost() throws { +func dynamicAnyEmptyMapHasNoDynamicStorage() throws { let value = [:] as [AnyHashable: Any] let bytes = try makeBudgetFory().serialize(value as Any) - let required = ReadContext.containerFixedBytes * 3 - expectInvalidData { - let _: Any = try makeBudgetFory(maxContainerMemoryBytes: Int64(required - 1)) - .deserialize(bytes) - } - let decoded: Any = try makeBudgetFory(maxContainerMemoryBytes: Int64(required)) + let decoded: Any = try makeBudgetFory(maxContainerMemoryBytes: 1) .deserialize(bytes) #expect((decoded as? [String: Any])?.isEmpty == true) } From 21134f41267215de0efff16570bae22eb05836a8 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 30 Jun 2026 14:16:19 +0800 Subject: [PATCH 06/54] fix: enforce generated container memory budgets --- .../src/Fory.Generator/ForyModelGenerator.cs | 22 ++++++++ csharp/src/Fory/ReadContext.cs | 1 - .../Fory.Tests/ContainerMemoryBudgetTests.cs | 52 ++++++++++++++++++- .../kotlin/CollectionSerializerTest.kt | 2 +- .../scala/CollectionSerializerTest.scala | 10 ++-- .../scala/ScalaXlangSerializerTest.scala | 4 +- 6 files changed, 81 insertions(+), 10 deletions(-) diff --git a/csharp/src/Fory.Generator/ForyModelGenerator.cs b/csharp/src/Fory.Generator/ForyModelGenerator.cs index de535233c2..d819197470 100644 --- a/csharp/src/Fory.Generator/ForyModelGenerator.cs +++ b/csharp/src/Fory.Generator/ForyModelGenerator.cs @@ -1517,6 +1517,8 @@ private static void EmitReadPackedArrayPayload( } else { + string elementBytesExpr = ContainerElementBytesExpr(PackedArrayElementTypeName(codec.TypeId)); + sb.AppendLine($"{indent}context.ReserveCountedContainerMemory({countVar}, {elementBytesExpr});"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new({countVar});"); } @@ -1556,6 +1558,7 @@ private static void EmitReadCollectionPayload( string sameTypeVar = $"__forySameType{id++}"; string declaredVar = $"__foryDeclared{id++}"; sb.AppendLine($"{indent}int {lengthVar} = checked((int)context.Reader.ReadVarUInt32());"); + sb.AppendLine($"{indent}context.ReserveCountedContainerMemory({lengthVar}, {ContainerElementBytesExpr(element)});"); sb.AppendLine($"{indent}if ({lengthVar} != 0)"); sb.AppendLine($"{indent}{{"); sb.AppendLine($"{indent} context.Reader.CheckBound({lengthVar});"); @@ -1661,6 +1664,7 @@ private static void EmitReadMapPayload( FieldCodecModel value = codec.Generics[1]; string totalVar = $"__foryTotal{id++}"; sb.AppendLine($"{indent}int {totalVar} = checked((int)context.Reader.ReadVarUInt32());"); + sb.AppendLine($"{indent}context.ReserveCountedContainerMemory({totalVar}, {ContainerMapElementBytesExpr(key, value)});"); sb.AppendLine($"{indent}if ({totalVar} != 0)"); sb.AppendLine($"{indent}{{"); sb.AppendLine($"{indent} context.Reader.CheckBound({totalVar});"); @@ -1819,6 +1823,24 @@ private static string ElementTypeName(string arrayTypeName) : "object"; } + private static string ContainerElementBytesExpr(FieldCodecModel codec) + { + return ContainerElementBytesExpr( + codec.Nullable && !codec.NullableValueType + ? StripNullableForTypeOf(codec.TypeName) + : codec.TypeName); + } + + private static string ContainerElementBytesExpr(string typeName) + { + return $"(typeof({typeName}).IsValueType ? global::System.Runtime.CompilerServices.Unsafe.SizeOf<{typeName}>() : 4)"; + } + + private static string ContainerMapElementBytesExpr(FieldCodecModel key, FieldCodecModel value) + { + return $"((long){ContainerElementBytesExpr(key)} + {ContainerElementBytesExpr(value)})"; + } + private static string PackedArrayElementTypeName(uint typeId) { return typeId switch diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index 865e2046dc..f1e0ff8fa8 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -23,7 +23,6 @@ public sealed class ReadContext { private const int MinRemoteTypeMetaLimit = 8192; internal const long KnownContainerBudgetSlackBytes = 64 * 1024; - internal const long UnknownContainerBudgetBytes = 128L * 1024 * 1024; private readonly ReusableArray _typeMetaRefs = new(); private readonly UInt64Map _typeMetasByHeader = new(); diff --git a/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs index 40894a379e..6dcd9d9c96 100644 --- a/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs +++ b/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs @@ -19,6 +19,7 @@ using System.Runtime.CompilerServices; using Apache.Fory; using ForyRuntime = Apache.Fory.Fory; +using S = Apache.Fory.Schema.Types; namespace Apache.Fory.Tests; @@ -42,6 +43,27 @@ public sealed class BudgetArrayHolder public BudgetItem[] Values { get; set; } = []; } +[ForyStruct] +public sealed class GeneratedSchemaListBudget +{ + [ForyField(Type = typeof(S.List))] + public List Values { get; set; } = []; +} + +[ForyStruct] +public sealed class GeneratedPackedListBudget +{ + [ForyField(Type = typeof(S.Array))] + public List Values { get; set; } = []; +} + +[ForyStruct] +public sealed class GeneratedSchemaMapBudget +{ + [ForyField(Type = typeof(S.Map))] + public Dictionary Values { get; set; } = []; +} + public sealed class ContainerMemoryBudgetTests { private const int ReferenceBytes = 4; @@ -57,7 +79,10 @@ private static ForyRuntime NewFory(long maxContainerMemoryBytes = -1) .Build() .Register(1001) .Register(1002) - .Register(1003); + .Register(1003) + .Register(1004) + .Register(1005) + .Register(1006); } private static byte[] Serialize(T value) @@ -163,6 +188,31 @@ public void ReferenceArrayAndInlineValueListAreCharged() Assert.Equal(ints, NewFory(listRequired).Deserialize>(intBytes)); } + [Fact] + public void GeneratedSchemaContainersAreCharged() + { + GeneratedSchemaListBudget list = new() { Values = [1, 2, 3, 4, 5, 6] }; + byte[] listBytes = Serialize(list); + long listRequired = ListBudget(list.Values.Count); + Assert.Throws(() => NewFory(listRequired - 1).Deserialize(listBytes)); + Assert.Equal(list.Values, NewFory(listRequired).Deserialize(listBytes).Values); + + GeneratedPackedListBudget packed = new() { Values = [1, 2, 3, 4, 5, 6] }; + byte[] packedBytes = Serialize(packed); + long packedRequired = ListBudget(packed.Values.Count); + Assert.Throws(() => NewFory(packedRequired - 1).Deserialize(packedBytes)); + Assert.Equal(packed.Values, NewFory(packedRequired).Deserialize(packedBytes).Values); + + GeneratedSchemaMapBudget map = new() + { + Values = new Dictionary { [1] = 1, [2] = 2, [3] = 3 }, + }; + byte[] mapBytes = Serialize(map); + long mapRequired = MapBudget(map.Values.Count); + Assert.Throws(() => NewFory(mapRequired - 1).Deserialize(mapBytes)); + Assert.Equal(map.Values, NewFory(mapRequired).Deserialize(mapBytes).Values); + } + [Fact] public void DenseStringBinaryAndPrimitiveArraysAreSkipped() { diff --git a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt index 1d17f39b91..92749a2506 100644 --- a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt +++ b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt @@ -46,7 +46,7 @@ class CollectionSerializerTest { .build() try { - reader.deserialize(writer.serialize(ArrayDeque())) + reader.deserialize(writer.serialize(ArrayDeque(listOf(1, 2, 3, 4, 5, 6)))) fail("Expected container memory budget failure") } catch (ignored: InsecureException) {} } diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala index fc386639dd..70078534d2 100644 --- a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala @@ -105,19 +105,19 @@ class CollectionSerializerTest extends AnyWordSpec with Matchers { builder.build() } - "charge scala collection fixed cost" in { + "charge scala collection storage" in { val writer = runtime() val reader = runtime(maxContainerMemoryBytes = 23) intercept[InsecureException] { - reader.deserialize(writer.serialize(List.empty[String])) + reader.deserialize(writer.serialize(List.fill(6)("v"))) } } - "charge scala map fixed cost" in { + "charge scala map storage" in { val writer = runtime() - val reader = runtime(maxContainerMemoryBytes = 47) + val reader = runtime(maxContainerMemoryBytes = 23) intercept[InsecureException] { - reader.deserialize(writer.serialize(Map("k" -> "v"))) + reader.deserialize(writer.serialize(Map("a" -> 1, "b" -> 2, "c" -> 3))) } } } diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala index b637bc4f6c..f8f979cd6c 100644 --- a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala @@ -134,10 +134,10 @@ class ScalaXlangSerializerTest extends AnyWordSpec with Matchers { .build() intercept[InsecureException] { - reader.deserialize(writer.serialize(List.empty[String])) + reader.deserialize(writer.serialize(List.fill(6)("v"))) } intercept[InsecureException] { - reader.deserialize(writer.serialize(Map("k" -> 1))) + reader.deserialize(writer.serialize(Map("a" -> 1, "b" -> 2, "c" -> 3))) } } } From d0b27b19bddb723f22cebf969f317a87d3514644 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 30 Jun 2026 14:47:12 +0800 Subject: [PATCH 07/54] fix: budget Swift Any materialization --- swift/Sources/Fory/ReadContext.swift | 22 ++++- .../ContainerMemoryBudgetTests.swift | 91 +++++++++++++++++++ 2 files changed, 112 insertions(+), 1 deletion(-) diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index ce44719551..088b623608 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -18,6 +18,7 @@ import Foundation private let typeMetaSizeMask = 0xFF +private let materializedAnyReferenceBytes = 4 public final class ReadContext { static let knownContainerBudgetSlackBytes = 64 * 1024 @@ -782,7 +783,14 @@ extension ReadContext { refMode: refMode, readTypeInfo: readTypeInfo ) - return wrapped?.map { $0.anyValueForCollection() } + guard let wrapped else { + return nil + } + try reserveCountedContainerMemory( + count: wrapped.count, + elementBytes: materializedAnyReferenceBytes + ) + return wrapped.map { $0.anyValueForCollection() } } public func readMapStringToAny( @@ -797,6 +805,10 @@ extension ReadContext { guard let wrapped else { return nil } + try reserveCountedContainerMemory( + count: wrapped.count, + elementBytes: 2 * materializedAnyReferenceBytes + ) var map: [String: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { @@ -817,6 +829,10 @@ extension ReadContext { guard let wrapped else { return nil } + try reserveCountedContainerMemory( + count: wrapped.count, + elementBytes: 2 * materializedAnyReferenceBytes + ) var map: [Int32: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { @@ -837,6 +853,10 @@ extension ReadContext { guard let wrapped else { return nil } + try reserveCountedContainerMemory( + count: wrapped.count, + elementBytes: 2 * materializedAnyReferenceBytes + ) var map: [AnyHashable: Any] = [:] map.reserveCapacity(wrapped.count) for pair in wrapped { diff --git a/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift b/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift index 18b9908157..a9b9f15081 100644 --- a/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift +++ b/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift @@ -198,6 +198,97 @@ func dynamicAnyEmptyMapHasNoDynamicStorage() throws { #expect((decoded as? [String: Any])?.isEmpty == true) } +@Test +func publicAnyArrayBudget() throws { + let value: [Any] = [Int32(1), Int32(2), Int32(3)] + let bytes = try makeBudgetFory().serialize(value) + let wrappedBudget = arrayBudget(SerializableAny.self, count: value.count) + let finalBudget = value.count * testReferenceBytes + + expectInvalidData { + let _: [Any] = try makeBudgetFory(maxContainerMemoryBytes: Int64(wrappedBudget)) + .deserialize(bytes, as: [Any].self) + } + let decoded = try makeBudgetFory(maxContainerMemoryBytes: Int64(wrappedBudget + finalBudget)) + .deserialize(bytes, as: [Any].self) + #expect(decoded.count == value.count) +} + +@Test +func publicAnyMapBudget() throws { + let stringMap: [String: Any] = ["a": Int32(1), "b": Int32(2), "c": Int32(3)] + let stringBytes = try makeBudgetFory().serialize(stringMap) + let stringWrapped = mapBudget( + key: String.self, + value: SerializableAny.self, + count: stringMap.count + ) + let stringFinal = stringMap.count * 2 * testReferenceBytes + expectInvalidData { + let _: [String: Any] = try makeBudgetFory(maxContainerMemoryBytes: Int64(stringWrapped)) + .deserialize(stringBytes, as: [String: Any].self) + } + let decodedString = try makeBudgetFory(maxContainerMemoryBytes: Int64(stringWrapped + stringFinal)) + .deserialize(stringBytes, as: [String: Any].self) + #expect(decodedString.count == stringMap.count) + + let intMap: [Int32: Any] = [1: Int32(10), 2: Int32(20), 3: Int32(30)] + let intBytes = try makeBudgetFory().serialize(intMap) + let intWrapped = mapBudget( + key: Int32.self, + value: SerializableAny.self, + count: intMap.count + ) + let intFinal = intMap.count * 2 * testReferenceBytes + expectInvalidData { + let _: [Int32: Any] = try makeBudgetFory(maxContainerMemoryBytes: Int64(intWrapped)) + .deserialize(intBytes, as: [Int32: Any].self) + } + let decodedInt = try makeBudgetFory(maxContainerMemoryBytes: Int64(intWrapped + intFinal)) + .deserialize(intBytes, as: [Int32: Any].self) + #expect(decodedInt.count == intMap.count) + + let anyHashableMap: [AnyHashable: Any] = [ + AnyHashable("a"): Int32(1), + AnyHashable(Int32(2)): Int32(2), + AnyHashable(true): Int32(3) + ] + let anyHashableBytes = try makeBudgetFory().serialize(anyHashableMap) + let anyHashableWrapped = mapBudget( + key: AnyHashable.self, + value: SerializableAny.self, + count: anyHashableMap.count + ) + let anyHashableFinal = anyHashableMap.count * 2 * testReferenceBytes + expectInvalidData { + let _: [AnyHashable: Any] = try makeBudgetFory( + maxContainerMemoryBytes: Int64(anyHashableWrapped) + ).deserialize(anyHashableBytes, as: [AnyHashable: Any].self) + } + let decodedAnyHashable = try makeBudgetFory( + maxContainerMemoryBytes: Int64(anyHashableWrapped + anyHashableFinal) + ).deserialize(anyHashableBytes, as: [AnyHashable: Any].self) + #expect(decodedAnyHashable.count == anyHashableMap.count) +} + +@Test +func dynamicAnyArrayBudget() throws { + let list: [Any] = [Int32(1), "two", Int32(3)] + let value: Any = list + let bytes = try makeBudgetFory().serialize(value) + let count = list.count + let wrappedBudget = arrayBudget(SerializableAny.self, count: count) + let finalBudget = count * testReferenceBytes + + expectInvalidData { + let _: Any = try makeBudgetFory(maxContainerMemoryBytes: Int64(wrappedBudget)) + .deserialize(bytes, as: Any.self) + } + let decoded = try makeBudgetFory(maxContainerMemoryBytes: Int64(wrappedBudget + finalBudget)) + .deserialize(bytes, as: Any.self) + #expect((decoded as? [Any])?.count == count) +} + @Test func byteAvailabilityCheckStillRejectsLargeLength() throws { let buffer = ByteBuffer() From fb7a7d23704f5cc1658d02eb9e9cd49cce613049 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 30 Jun 2026 14:57:07 +0800 Subject: [PATCH 08/54] fix: expose C# container budget accounting to generated code --- csharp/src/Fory/ReadContext.cs | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index f1e0ff8fa8..b27af464bd 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -89,7 +89,14 @@ internal void InitContainerBudgetKnown(int rootBytes) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void ReserveContainerMemory(long bytes) + /// + /// Reserves estimated container-owned memory for the current root deserialization. + /// + /// + /// Serializer owners compute container-specific formulas and pass raw bytes here. This + /// accounting does not replace byte-availability checks before backing allocation. + /// + public void ReserveContainerMemory(long bytes) { long remaining = _remainingContainerMemoryBytes; if ((ulong)bytes > (ulong)remaining) @@ -101,7 +108,15 @@ internal void ReserveContainerMemory(long bytes) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void ReserveCountedContainerMemory(int count, long elementBytes) + /// + /// Reserves multiplied by estimated + /// container-owned bytes for the current root deserialization. + /// + /// + /// This helper owns only overflow-safe arithmetic; concrete serializers and generated + /// serializers still own the collection, array, and map storage formulas. + /// + public void ReserveCountedContainerMemory(int count, long elementBytes) { if (count < 0 || elementBytes < 0) { From 04c12206dcdb4560e7d84cde5dcde939016442fe Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 30 Jun 2026 15:04:07 +0800 Subject: [PATCH 09/54] fix: keep compatible array reader native-image safe --- .../fory/serializer/CompatibleCollectionArrayReader.java | 4 +++- .../org/apache/fory/serializer/ContainerMemoryBudgetTest.java | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java index 38ed3a99d8..49ae18fc71 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java @@ -59,7 +59,9 @@ import org.apache.fory.type.Types; final class CompatibleCollectionArrayReader { - private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + // This compatible reader may be reached during native-image analysis. Use the settled + // reference-slot fallback instead of touching MemoryBuffer from class initialization. + private static final int REFERENCE_BYTES = 4; static final int READ_LIST_TO_ARRAY = 1; static final int READ_ARRAY_TO_LIST = 2; diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java index 2ed3fca56d..1b34a9a600 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java @@ -257,5 +257,4 @@ private static MemoryBuffer objectArraySizeBuffer(int numElements) { private static MemoryBuffer trimBuffer(MemoryBuffer buffer) { return MemoryBuffer.fromByteArray(buffer.getBytes(0, buffer.writerIndex())); } - } From ee85d4fee68541904907b89f687252604fee3e74 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 30 Jun 2026 20:55:54 +0800 Subject: [PATCH 10/54] perf: add container memory budget --- cpp/fory/serialization/context.cc | 28 ++++++- cpp/fory/serialization/context.h | 71 ++++++++-------- cpp/fory/serialization/fory.h | 17 ++-- cpp/fory/serialization/serializer_traits.h | 81 +++++++++++++++++++ .../src/main/java/org/apache/fory/Fory.java | 26 +++--- .../org/apache/fory/builder/CodecUtils.java | 48 ++++++++--- .../org/apache/fory/context/ReadContext.java | 7 +- .../org/apache/fory/context/WriteContext.java | 30 +++++++ .../org/apache/fory/memory/MemoryBuffer.java | 7 -- .../apache/fory/resolver/TypeResolver.java | 4 +- .../fory/serializer/ArraySerializers.java | 2 +- .../collection/ChildContainerSerializers.java | 6 +- .../collection/CollectionLikeSerializer.java | 9 ++- .../collection/CollectionSerializers.java | 34 ++++---- .../GuavaCollectionSerializers.java | 12 +-- .../ImmutableCollectionSerializers.java | 6 +- .../collection/MapLikeSerializer.java | 9 ++- .../serializer/collection/MapSerializers.java | 16 ++-- .../collection/SubListSerializers.java | 2 +- .../org/apache/fory/memory/MemoryBuffer.java | 5 -- .../serializer/ContainerMemoryBudgetTest.java | 2 +- 21 files changed, 289 insertions(+), 133 deletions(-) diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index 132847c99b..f88c19f9e1 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -749,6 +749,31 @@ bool ReadContext::reserve_counted_container_checked(uint32_t length, return reserve_container_memory(static_cast(length) * elem_bytes); } +bool ReadContext::init_explicit_container_budget(int64_t configured) { + const uint64_t limit = static_cast(configured); + if constexpr (sizeof(size_t) < sizeof(uint64_t)) { + if (FORY_PREDICT_FALSE(limit > static_cast( + std::numeric_limits::max()))) { + return set_container_memory_error( + "max_container_memory_bytes does not fit size_t"); + } + } + remaining_container_memory_bytes_ = static_cast(limit); + container_budget_state_ = kContainerBudgetReady; + return true; +} + +bool ReadContext::materialize_container_budget() { + switch (container_budget_state_) { + case kContainerBudgetPendingKnown: + return init_container_budget_known(pending_container_root_bytes_); + case kContainerBudgetPendingUnknown: + return init_container_budget_unknown(); + default: + return true; + } +} + bool ReadContext::set_container_memory_error(const std::string &message) { set_error(Error::invalid_data(message)); return false; @@ -767,8 +792,7 @@ bool ReadContext::set_container_memory_exceeded(size_t bytes, set_error(Error::invalid_data( "estimated container memory request " + std::to_string(bytes) + " bytes exceeds max_container_memory_bytes remaining budget " + - std::to_string(remaining) + " bytes out of effective limit " + - std::to_string(container_memory_limit_bytes_) + " bytes")); + std::to_string(remaining) + " bytes")); return false; } diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index 09d321710d..cb8a6a378e 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -506,20 +506,11 @@ class ReadContext { } FORY_ALWAYS_INLINE bool init_container_budget_known(size_t root_bytes) { - size_t limit = 0; - if (config_->max_container_memory_bytes > 0) { - const uint64_t configured = - static_cast(config_->max_container_memory_bytes); - if constexpr (sizeof(size_t) < sizeof(uint64_t)) { - if (FORY_PREDICT_FALSE( - configured > - static_cast(std::numeric_limits::max()))) { - return set_container_memory_error( - "max_container_memory_bytes does not fit size_t"); - } - } - limit = static_cast(configured); - } else { + const int64_t configured = config_->max_container_memory_bytes; + if (FORY_PREDICT_FALSE(configured > 0)) { + return init_explicit_container_budget(configured); + } + if constexpr (sizeof(size_t) <= sizeof(uint32_t)) { constexpr size_t max_root_bytes = (std::numeric_limits::max() - kKnownContainerBudgetSlackBytes) / kKnownContainerBudgetMultiplier; @@ -527,37 +518,39 @@ class ReadContext { return set_container_memory_error( "root input size overflows automatic container memory budget"); } - limit = root_bytes * kKnownContainerBudgetMultiplier + - kKnownContainerBudgetSlackBytes; } - container_memory_limit_bytes_ = limit; - remaining_container_memory_bytes_ = limit; + remaining_container_memory_bytes_ = + root_bytes * kKnownContainerBudgetMultiplier + + kKnownContainerBudgetSlackBytes; + container_budget_state_ = kContainerBudgetReady; return true; } FORY_ALWAYS_INLINE bool init_container_budget_unknown() { - size_t limit = 0; - if (config_->max_container_memory_bytes > 0) { - const uint64_t configured = - static_cast(config_->max_container_memory_bytes); - if constexpr (sizeof(size_t) < sizeof(uint64_t)) { - if (FORY_PREDICT_FALSE( - configured > - static_cast(std::numeric_limits::max()))) { - return set_container_memory_error( - "max_container_memory_bytes does not fit size_t"); - } - } - limit = static_cast(configured); - } else { - limit = kUnknownContainerBudgetBytes; + const int64_t configured = config_->max_container_memory_bytes; + if (FORY_PREDICT_FALSE(configured > 0)) { + return init_explicit_container_budget(configured); } - container_memory_limit_bytes_ = limit; - remaining_container_memory_bytes_ = limit; + remaining_container_memory_bytes_ = kUnknownContainerBudgetBytes; + container_budget_state_ = kContainerBudgetReady; return true; } + FORY_ALWAYS_INLINE void defer_container_budget_known(size_t root_bytes) { + pending_container_root_bytes_ = root_bytes; + container_budget_state_ = kContainerBudgetPendingKnown; + } + + FORY_ALWAYS_INLINE void defer_container_budget_unknown() { + container_budget_state_ = kContainerBudgetPendingUnknown; + } + FORY_ALWAYS_INLINE bool reserve_container_memory(size_t bytes) { + if (FORY_PREDICT_FALSE(container_budget_state_ != kContainerBudgetReady)) { + if (FORY_PREDICT_FALSE(!materialize_container_budget())) { + return false; + } + } const size_t remaining = remaining_container_memory_bytes_; if (FORY_PREDICT_FALSE(bytes > remaining)) { return set_container_memory_exceeded(bytes, remaining); @@ -737,12 +730,17 @@ class ReadContext { static constexpr size_t kKnownContainerBudgetSlackBytes = 64 * 1024; static constexpr size_t kUnknownContainerBudgetBytes = 128ULL * 1024ULL * 1024ULL; + static constexpr uint8_t kContainerBudgetReady = 0; + static constexpr uint8_t kContainerBudgetPendingKnown = 1; + static constexpr uint8_t kContainerBudgetPendingUnknown = 2; FORY_NOINLINE Result check_remote_type_meta_limit(const TypeMeta &type_meta); void record_remote_type_meta(const std::string &type_key); FORY_NOINLINE bool reserve_counted_container_checked(uint32_t length, size_t elem_bytes); + FORY_NOINLINE bool init_explicit_container_budget(int64_t configured); + FORY_NOINLINE bool materialize_container_budget(); FORY_NOINLINE bool set_container_memory_error(const std::string &message); FORY_NOINLINE bool set_container_memory_overflow(uint32_t length, size_t elem_bytes); @@ -757,7 +755,8 @@ class ReadContext { std::unique_ptr type_resolver_; RefReader ref_reader_; uint32_t current_dyn_depth_; - size_t container_memory_limit_bytes_ = std::numeric_limits::max(); + uint8_t container_budget_state_ = kContainerBudgetReady; + size_t pending_container_root_bytes_ = 0; size_t remaining_container_memory_bytes_ = std::numeric_limits::max(); // Meta sharing state (for compatible mode) diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index 6d26c3bfa7..1ef660cc90 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -875,15 +875,7 @@ class Fory : public BaseFory { template FORY_ALWAYS_INLINE Result deserialize_buffer(Buffer &buffer) { - const bool budget_ok = - unknown_root - ? read_ctx_->init_container_budget_unknown() - : read_ctx_->init_container_budget_known(buffer.remaining_size()); - if (FORY_PREDICT_FALSE(!budget_ok)) { - Error error = read_ctx_->take_error(); - read_ctx_->reset(); - return Unexpected(std::move(error)); - } + const size_t root_bytes = unknown_root ? 0 : buffer.remaining_size(); Error header_error; const uint8_t header = buffer.read_uint8(header_error); @@ -897,6 +889,13 @@ class Fory : public BaseFory { } read_ctx_->attach(buffer); + if constexpr (needs_container_budget_v) { + if constexpr (unknown_root) { + read_ctx_->defer_container_budget_unknown(); + } else { + read_ctx_->defer_container_budget_known(root_bytes); + } + } ReadContextGuard guard(*read_ctx_); return deserialize_impl(buffer); } diff --git a/cpp/fory/serialization/serializer_traits.h b/cpp/fory/serialization/serializer_traits.h index ad07c4aa6d..e26ba5c899 100644 --- a/cpp/fory/serialization/serializer_traits.h +++ b/cpp/fory/serialization/serializer_traits.h @@ -19,6 +19,7 @@ #pragma once +#include "fory/meta/field.h" #include "fory/meta/field_info.h" #include "fory/meta/type_index.h" #include "fory/meta/type_traits.h" @@ -244,6 +245,86 @@ struct is_fory_serializable< template inline constexpr bool is_fory_serializable_v = is_fory_serializable::value; +// ============================================================================ +// Container budget reachability +// ============================================================================ + +template +struct needs_container_budget : std::false_type {}; + +template <> struct needs_container_budget : std::false_type {}; +template <> struct needs_container_budget : std::false_type {}; +template <> struct needs_container_budget : std::false_type {}; +template <> struct needs_container_budget : std::false_type {}; +template <> struct needs_container_budget : std::false_type {}; +template <> struct needs_container_budget : std::false_type {}; +template <> struct needs_container_budget : std::false_type {}; +template <> struct needs_container_budget : std::false_type {}; +template <> struct needs_container_budget : std::false_type {}; +template <> struct needs_container_budget : std::false_type {}; +template <> struct needs_container_budget : std::false_type {}; +template <> +struct needs_container_budget : std::false_type {}; + +template +struct needs_container_budget< + T, std::enable_if_t || is_list_v || is_deque_v || + is_forward_list_v || is_set_like_v || + is_map_like_v>> : std::true_type {}; + +template +struct needs_container_budget, void> + : std::bool_constant>>::value> {}; + +template +struct needs_container_budget, void> + : std::bool_constant>>::value> {}; + +template +struct needs_container_budget, void> : std::true_type {}; + +template +struct needs_container_budget, void> : std::true_type {}; + +template +struct needs_container_budget, void> + : std::bool_constant<(needs_container_budget>>::value || + ...)> {}; + +template +struct needs_container_budget, void> + : std::bool_constant<(needs_container_budget>>::value || + ...)> {}; + +template +constexpr bool struct_needs_container_budget_impl(std::index_sequence) { + return ( + needs_container_budget< + std::remove_cv_t>>>>::value || + ...); +} + +template +struct needs_container_budget>> { +private: + using FieldDescriptor = + decltype(::fory::meta::fory_field_info(std::declval())); + using Ptrs = typename FieldDescriptor::PtrsType; + +public: + static constexpr bool value = struct_needs_container_budget_impl( + std::make_index_sequence{}); +}; + +template +inline constexpr bool needs_container_budget_v = + needs_container_budget>>::value; + // ============================================================================ // Generic Type Detection // ============================================================================ diff --git a/java/fory-core/src/main/java/org/apache/fory/Fory.java b/java/fory-core/src/main/java/org/apache/fory/Fory.java index 5b2bfef3ff..b961e9c582 100644 --- a/java/fory-core/src/main/java/org/apache/fory/Fory.java +++ b/java/fory-core/src/main/java/org/apache/fory/Fory.java @@ -410,7 +410,7 @@ private ForyException processCopyError(Throwable e) { @Override public Object deserialize(byte[] bytes) { - return deserialize(MemoryUtils.wrap(bytes), (Iterable) null); + return deserialize(MemoryUtils.wrap(bytes), (Iterable) null, false, bytes.length); } @Override @@ -420,17 +420,17 @@ public Object deserialize(ByteBuffer byteBuffer) { @Override public T deserialize(byte[] bytes, Class type) { - return deserialize(MemoryUtils.wrap(bytes), type); + return deserialize(MemoryUtils.wrap(bytes), type, false, bytes.length); } @Override public T deserialize(MemoryBuffer buffer, Class type) { - return deserialize(buffer, type, false); + return deserialize(buffer, type, false, buffer.remaining()); } - private T deserialize(MemoryBuffer buffer, Class type, boolean unknownLengthInput) { + private T deserialize( + MemoryBuffer buffer, Class type, boolean unknownLengthInput, int rootInputBytes) { ensureRegistrationFinished(); - int rootInputBytes = buffer.remaining(); byte bitmap = buffer.readByte(); if (bitmap != headerBitmap) { checkHeaderBitmapWithoutOutOfBand(bitmap); @@ -456,7 +456,7 @@ private T deserialize(MemoryBuffer buffer, Class type, boolean unknownLen @Override public T deserialize(ForyInputStream inputStream, Class type) { try { - return deserialize(inputStream.getBuffer(), type, true); + return deserialize(inputStream.getBuffer(), type, true, 0); } finally { inputStream.shrinkBuffer(); } @@ -464,7 +464,7 @@ public T deserialize(ForyInputStream inputStream, Class type) { @Override public T deserialize(ForyReadableChannel channel, Class type) { - return deserialize(channel.getBuffer(), type, true); + return deserialize(channel.getBuffer(), type, true, 0); } @Override @@ -492,13 +492,15 @@ public Object deserialize(MemoryBuffer buffer) { */ @Override public Object deserialize(MemoryBuffer buffer, Iterable outOfBandBuffers) { - return deserialize(buffer, outOfBandBuffers, false); + return deserialize(buffer, outOfBandBuffers, false, buffer.remaining()); } private Object deserialize( - MemoryBuffer buffer, Iterable outOfBandBuffers, boolean unknownLengthInput) { + MemoryBuffer buffer, + Iterable outOfBandBuffers, + boolean unknownLengthInput, + int rootInputBytes) { ensureRegistrationFinished(); - int rootInputBytes = buffer.remaining(); byte bitmap = buffer.readByte(); boolean peerOutOfBandEnabled = false; if (bitmap != headerBitmap) { @@ -547,7 +549,7 @@ public Object deserialize(ForyInputStream inputStream) { public Object deserialize(ForyInputStream inputStream, Iterable outOfBandBuffers) { try { MemoryBuffer buf = inputStream.getBuffer(); - return deserialize(buf, outOfBandBuffers, true); + return deserialize(buf, outOfBandBuffers, true, 0); } finally { inputStream.shrinkBuffer(); } @@ -561,7 +563,7 @@ public Object deserialize(ForyReadableChannel channel) { @Override public Object deserialize(ForyReadableChannel channel, Iterable outOfBandBuffers) { MemoryBuffer buf = channel.getBuffer(); - return deserialize(buf, outOfBandBuffers, true); + return deserialize(buf, outOfBandBuffers, true, 0); } @SuppressWarnings("unchecked") diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/CodecUtils.java b/java/fory-core/src/main/java/org/apache/fory/builder/CodecUtils.java index c9092f6723..838d4cc34a 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/CodecUtils.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/CodecUtils.java @@ -19,7 +19,6 @@ package org.apache.fory.builder; -import java.util.Collections; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import org.apache.fory.Fory; @@ -27,7 +26,10 @@ import org.apache.fory.codegen.CompileUnit; import org.apache.fory.collection.Tuple3; import org.apache.fory.meta.TypeDef; +import org.apache.fory.platform.AndroidSupport; import org.apache.fory.platform.GraalvmSupport; +import org.apache.fory.platform.JdkVersion; +import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.reflect.TypeRef; import org.apache.fory.resolver.TypeResolver; import org.apache.fory.serializer.Serializer; @@ -118,12 +120,6 @@ public static Class loadOrGenCompatibleLayerCodecClass @SuppressWarnings("unchecked") static Class loadOrGenCodecClass( Class beanClass, Fory fory, BaseObjectCodecBuilder codecBuilder) { - // use genCodeFunc to avoid gen code repeatedly - CompileUnit compileUnit = - new CompileUnit( - CodeGenerator.getPackage(beanClass), - codecBuilder.codecClassName(beanClass), - codecBuilder::genCode); CodeGenerator codeGenerator; ClassLoader beanClassClassLoader = beanClass.getClassLoader() == null @@ -134,15 +130,41 @@ static Class loadOrGenCodecClass( } TypeResolver typeResolver = fory.getTypeResolver(); codeGenerator = getCodeGenerator(beanClassClassLoader, typeResolver); - ClassLoader classLoader = - codeGenerator.compile( - Collections.singletonList(compileUnit), compileState -> compileState.lock.lock()); - String className = codecBuilder.codecQualifiedClassName(beanClass); + Class neighborClass = codecNeighbor(beanClass, beanClassClassLoader); + codecBuilder.setSamePackageAccess(neighborClass != null); + // use genCodeFunc to avoid gen code repeatedly + CompileUnit compileUnit = + new CompileUnit( + CodeGenerator.getPackage(beanClass), + codecBuilder.codecClassName(beanClass), + codecBuilder::genCode, + neighborClass); + return (Class) + codeGenerator.compileAndLoad(compileUnit, compileState -> compileState.lock.lock()); + } + + private static Class codecNeighbor(Class beanClass, ClassLoader beanClassClassLoader) { + // Hidden generated serializers are only a JDK25+ path for source-non-public bean classes. + // Source-public beans keep the normal generated class so split helpers remain private; JDK25 + // VarHandle field access is independent of hidden class definition. + if (AndroidSupport.IS_ANDROID + || JdkVersion.MAJOR_VERSION < 25 + || beanClass.getClassLoader() == null + || CodeGenerator.sourcePublicAccessible(beanClass)) { + return null; + } + if (!CodeGenerator.getPackage(beanClass).equals(ReflectionUtils.getPackage(beanClass))) { + return null; + } try { - return (Class) classLoader.loadClass(className); + // A generated serializer defined in the bean loader must resolve Fory runtime classes there. + if (beanClassClassLoader.loadClass(Fory.class.getName()) == Fory.class) { + return beanClass; + } } catch (ClassNotFoundException e) { - throw new IllegalStateException("Impossible because we just compiled class", e); + // The composed-loader path remains the owner when the bean loader cannot see Fory directly. } + return null; } private static CodeGenerator getCodeGenerator( diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index 5c61cab2ff..53db4131a4 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -640,7 +640,12 @@ public T readRef(Serializer serializer) { return (T) readNonRef(serializer); } - /** Reads the root object for one deserialization operation. */ + /** + * Reads the root object for one deserialization operation. + * + *

Root no-ref deserialization owns the null marker and type metadata directly; using the + * generic ref-reader path here makes scalar roots pay reference dispatch they can never use. + */ public Object readRootRef() { if (trackingRef) { return readRef(rootTypeInfoHolder); diff --git a/java/fory-core/src/main/java/org/apache/fory/context/WriteContext.java b/java/fory-core/src/main/java/org/apache/fory/context/WriteContext.java index bf74c0a67a..9147d87409 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/WriteContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/WriteContext.java @@ -555,6 +555,36 @@ public void writeRootRef(Object obj) { writeData(typeInfo, obj); } + /** + * Writes the root object for one serialization operation. + * + *

Root no-ref serialization still owns the null marker and type metadata, but it must not pay + * generic {@link RefWriter} dispatch or uncached type-info lookup when reference tracking is + * disabled. + */ + public void writeRootRef(Object obj) { + if (trackingRef) { + writeRef(obj, rootTypeInfoHolder); + return; + } + MemoryBuffer buffer = this.buffer; + if (obj == null) { + buffer.writeByte(Fory.NULL_FLAG); + return; + } + buffer.writeByte(Fory.NOT_NULL_VALUE_FLAG); + TypeResolver resolver = typeResolver; + TypeInfo typeInfo = resolver.getTypeInfo(obj.getClass(), rootTypeInfoHolder); + if (crossLanguage && typeInfo.getType() == UnknownStruct.class) { + depth++; + typeInfo.getSerializer().write(this, obj); + depth--; + return; + } + resolver.writeTypeInfo(this, typeInfo); + writeData(typeInfo, obj); + } + /** * Writes a non-null, first-seen object together with its type metadata. * diff --git a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java index 9e82d2ba3f..c0f002ca98 100644 --- a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java +++ b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java @@ -77,7 +77,6 @@ public final class MemoryBuffer { private static final int LONG_ARRAY_OFFSET; private static final int FLOAT_ARRAY_OFFSET; private static final int DOUBLE_ARRAY_OFFSET; - private static final int OBJECT_ARRAY_INDEX_SCALE; // GraalVM native-image recognizes arrayBaseOffset only when the call stores directly into the // target static field. Keep these assignments in this shape so native images recompute heap array @@ -92,7 +91,6 @@ public final class MemoryBuffer { LONG_ARRAY_OFFSET = 0; FLOAT_ARRAY_OFFSET = 0; DOUBLE_ARRAY_OFFSET = 0; - OBJECT_ARRAY_INDEX_SCALE = 4; } else { BOOLEAN_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(boolean[].class); BYTE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); @@ -102,7 +100,6 @@ public final class MemoryBuffer { LONG_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(long[].class); FLOAT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(float[].class); DOUBLE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(double[].class); - OBJECT_ARRAY_INDEX_SCALE = UNSAFE.arrayIndexScale(Object[].class); } } @@ -4188,10 +4185,6 @@ public static MemoryBuffer fromDirectByteBuffer( return new MemoryBuffer(offHeapAddress, size, buffer, streamReader); } - public static int objectArrayIndexScale() { - return OBJECT_ARRAY_INDEX_SCALE > 0 ? OBJECT_ARRAY_INDEX_SCALE : 4; - } - /** * Create a heap buffer of specified initial size. The buffer will grow automatically if not * enough. diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index 275ccf1d05..8837fb8e50 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -773,7 +773,7 @@ public final TypeInfo readTypeInfo(ReadContext readContext, TypeInfo typeInfoCac break; case Types.COMPATIBLE_STRUCT: case Types.NAMED_COMPATIBLE_STRUCT: - typeInfo = readSharedClassTypeInfo(readContext, null); + typeInfo = readSharedClassTypeInfo(readContext, null, typeInfoCache); break; case Types.NAMED_ENUM: case Types.NAMED_STRUCT: @@ -782,7 +782,7 @@ public final TypeInfo readTypeInfo(ReadContext readContext, TypeInfo typeInfoCac if (!metaContextShareEnabled) { typeInfo = readTypeInfoFromBytes(readContext, typeInfoCache, typeId); } else { - typeInfo = readSharedClassTypeInfo(readContext, null); + typeInfo = readSharedClassTypeInfo(readContext, null, typeInfoCache); } break; case Types.LIST: diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java index 148930ee3b..83c2758ea7 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java @@ -45,7 +45,7 @@ * object-array paths avoid adapter allocation. */ public final class ArraySerializers { - private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + private static final int REFERENCE_BYTES = 4; private ArraySerializers() {} diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java index f58d56d08b..49c65b56cf 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ChildContainerSerializers.java @@ -249,7 +249,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public T newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); @@ -295,7 +295,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public T newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); @@ -403,7 +403,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(readContext); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); int refId = readContext.lastPreservedRefId(); Comparator comparator = (Comparator) readContext.readRef(); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java index bc71691597..99cf798c9f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java @@ -46,7 +46,7 @@ */ @SuppressWarnings({"unchecked", "rawtypes"}) public abstract class CollectionLikeSerializer extends Serializer { - private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + private static final int REFERENCE_BYTES = 4; private MethodHandle constructor; private int numElements; @@ -463,7 +463,7 @@ public T read(ReadContext readContext) { */ public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - numElements = readCollectionSize(readContext); + numElements = readCollectionSize(readContext, buffer); if (AndroidSupport.IS_ANDROID) { try { Constructor constructor = type.getDeclaredConstructor(); @@ -563,7 +563,10 @@ protected void setNumElements(int numElements) { } protected final int readCollectionSize(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); + return readCollectionSize(readContext, readContext.getBuffer()); + } + + protected final int readCollectionSize(ReadContext readContext, MemoryBuffer buffer) { int numElements = buffer.readVarUInt32Small7(); checkCollectionSize(numElements); readContext.reserveContainerMemory((long) numElements * REFERENCE_BYTES); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java index 8c3850e88b..737e79d2bb 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java @@ -74,7 +74,7 @@ */ @SuppressWarnings({"unchecked", "rawtypes"}) public class CollectionSerializers { - private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + private static final int REFERENCE_BYTES = 4; private static final Comparator NATURAL_ORDER_COMPARATOR = Comparator.naturalOrder(); @@ -129,7 +129,7 @@ public ArrayListSerializer(TypeResolver typeResolver) { @Override public ArrayList newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); ArrayList arrayList = new ArrayList(numElements); readContext.reference(arrayList); @@ -191,7 +191,7 @@ public List read(ReadContext readContext) { @Override public ArrayList newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); ArrayList arrayList = new ArrayList(numElements); readContext.reference(arrayList); @@ -207,7 +207,7 @@ public HashSetSerializer(TypeResolver typeResolver) { @Override public HashSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); HashSet hashSet = new HashSet(numElements); readContext.reference(hashSet); @@ -223,7 +223,7 @@ public LinkedHashSetSerializer(TypeResolver typeResolver) { @Override public LinkedHashSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); LinkedHashSet hashSet = new LinkedHashSet(numElements); readContext.reference(hashSet); @@ -272,7 +272,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public T newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); T collection; Comparator comparator = config.isXlang() ? null : (Comparator) readContext.readRef(); @@ -358,7 +358,7 @@ public CopyOnWriteArrayListSerializer( @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -392,7 +392,7 @@ public CopyOnWriteArraySetSerializer( @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -544,7 +544,7 @@ public CollectionSnapshot onCollectionWrite( @Override public ConcurrentSkipListSet newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); if (config.isXlang()) { ConcurrentSkipListSet skipListSet = new ConcurrentSkipListSet(); @@ -728,7 +728,7 @@ public VectorSerializer(TypeResolver typeResolver, Class cls) { @Override public Vector newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); Vector vector = new Vector<>(numElements); readContext.reference(vector); @@ -745,7 +745,7 @@ public ArrayDequeSerializer(TypeResolver typeResolver, Class cls) { @Override public ArrayDeque newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); ArrayDeque deque = new ArrayDeque(numElements); readContext.reference(deque); @@ -788,7 +788,7 @@ public void write(WriteContext writeContext, EnumSet object) { public EnumSet read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); Class elemClass = typeResolver.readTypeInfo(readContext).getType(); - int length = readCollectionSize(readContext); + int length = readCollectionSize(readContext, buffer); EnumSet object = EnumSet.noneOf(elemClass); Serializer elemSerializer = typeResolver.getSerializer(elemClass); for (int i = 0; i < length; i++) { @@ -865,7 +865,7 @@ public Collection newCollection(CopyContext copyContext, Collection collection) public PriorityQueue newCollection(ReadContext readContext) { assert !config.isXlang(); MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); PriorityQueue queue = new PriorityQueue(comparator); @@ -925,7 +925,7 @@ public CollectionSnapshot onCollectionWrite( @Override public ArrayBlockingQueue newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); checkBoundedQueueCapacity(numElements, capacity); @@ -993,7 +993,7 @@ public CollectionSnapshot onCollectionWrite( @Override public LinkedBlockingQueue newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); checkBoundedQueueCapacity(numElements, capacity); @@ -1135,7 +1135,7 @@ public XlangListDefaultSerializer(TypeResolver typeResolver, Class cls) { @Override public List newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); ArrayList list = new ArrayList(numElements); readContext.reference(list); @@ -1151,7 +1151,7 @@ public XlangSetDefaultSerializer(TypeResolver typeResolver, Class cls) { @Override public Set newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); HashSet set = new HashSet(numElements); readContext.reference(set); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java index a5aba71aaa..ceeca02efc 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java @@ -94,7 +94,7 @@ public ImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -127,7 +127,7 @@ public RegularImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); return new CollectionContainer(numElements); } @@ -161,7 +161,7 @@ public ImmutableSetSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); return new CollectionContainer<>(numElements); } @@ -203,7 +203,7 @@ public Collection onCollectionWrite(WriteContext writeContext, T value) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); return new SortedCollectionContainer(comparator, numElements); @@ -236,7 +236,7 @@ public GuavaMapSerializer(TypeResolver typeResolver, Class cls) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(readContext); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); return new MapContainer(numElements); } @@ -574,7 +574,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(readContext); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); Comparator comparator = (Comparator) readContext.readRef(); return new SortedMapContainer<>(comparator, numElements); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java index 7a9f9f017d..65c8dcdacc 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java @@ -125,7 +125,7 @@ public ImmutableListSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); if (JdkVersion.MAJOR_VERSION > 8) { return new CollectionContainer<>(numElements); @@ -186,7 +186,7 @@ public ImmutableSetSerializer(TypeResolver typeResolver, Class cls) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); if (JdkVersion.MAJOR_VERSION > 8) { return new CollectionContainer<>(numElements); @@ -247,7 +247,7 @@ public ImmutableMapSerializer(TypeResolver typeResolver, Class cls) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(readContext); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); if (JdkVersion.MAJOR_VERSION > 8) { return new JDKImmutableMapContainer(numElements); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java index adf772c257..68074ff8c1 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java @@ -58,7 +58,7 @@ @SuppressWarnings({"unchecked", "rawtypes"}) public abstract class MapLikeSerializer extends Serializer { public static final int MAX_CHUNK_SIZE = 255; - private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + private static final int REFERENCE_BYTES = 4; static final class MapTypeCache { final TypeInfoHolder keyTypeInfoWriteCache; @@ -896,7 +896,7 @@ public void onMapWriteFinish(Map map) {} */ public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - numElements = readMapSize(readContext); + numElements = readMapSize(readContext, buffer); if (AndroidSupport.IS_ANDROID) { try { Constructor constructor = type.getDeclaredConstructor(); @@ -966,7 +966,10 @@ public void setNumElements(int numElements) { } protected final int readMapSize(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); + return readMapSize(readContext, readContext.getBuffer()); + } + + protected final int readMapSize(ReadContext readContext, MemoryBuffer buffer) { int numElements = buffer.readVarUInt32Small7(); checkMapSize(numElements); if (numElements > Integer.MAX_VALUE / 2) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java index 91f9ba2d06..64234a6a47 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapSerializers.java @@ -86,7 +86,7 @@ public HashMapSerializer(TypeResolver typeResolver) { @Override public HashMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(readContext); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); HashMap hashMap = new HashMap(numElements); readContext.reference(hashMap); @@ -107,7 +107,7 @@ public LinkedHashMapSerializer(TypeResolver typeResolver) { @Override public LinkedHashMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(readContext); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); LinkedHashMap hashMap = new LinkedHashMap(numElements); readContext.reference(hashMap); @@ -146,7 +146,7 @@ public LazyMapSerializer(TypeResolver typeResolver) { @Override public LazyMap newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(readContext); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); LazyMap map = new LazyMap(numElements); readContext.reference(map); @@ -200,7 +200,7 @@ public Map onMapWrite(WriteContext writeContext, T value) { @Override public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - setNumElements(readMapSize(readContext)); + setNumElements(readMapSize(readContext, buffer)); T map; Comparator comparator = config.isXlang() ? null : (Comparator) readContext.readRef(); if (type == TreeMap.class) { @@ -322,7 +322,7 @@ public ConcurrentHashMapSerializer(TypeResolver typeResolver, Class keyType = typeResolver.readTypeInfo(readContext).getType(); EnumMap map = new EnumMap(keyType); readContext.reference(map); @@ -619,7 +619,7 @@ public Object onMapCopy(Map map) { public Map newMap(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readMapSize(readContext); + int numElements = readMapSize(readContext, buffer); setNumElements(numElements); HashMap map = new HashMap<>(numElements); readContext.reference(map); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java index c011c91277..3775f7f3b6 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/SubListSerializers.java @@ -158,7 +158,7 @@ public List read(ReadContext readContext) { @Override public Collection newCollection(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - int numElements = readCollectionSize(readContext); + int numElements = readCollectionSize(readContext, buffer); setNumElements(numElements); ArrayList list = new ArrayList(numElements); readContext.reference(list); diff --git a/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java b/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java index 4c9347b65e..52c4d0ce50 100644 --- a/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java +++ b/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java @@ -81,7 +81,6 @@ public final class MemoryBuffer { private static final int LONG_ARRAY_OFFSET = 0; private static final int FLOAT_ARRAY_OFFSET = 0; private static final int DOUBLE_ARRAY_OFFSET = 0; - private static final int OBJECT_ARRAY_INDEX_SCALE = 4; private static final VarHandle BYTE_ARRAY_CHAR = MethodHandles.byteArrayViewVarHandle(char[].class, NATIVE_ORDER); private static final VarHandle BYTE_ARRAY_SHORT = @@ -3925,10 +3924,6 @@ public static MemoryBuffer fromDirectByteBuffer( return new MemoryBuffer(offHeapAddress, size, buffer, streamReader); } - public static int objectArrayIndexScale() { - return OBJECT_ARRAY_INDEX_SCALE; - } - /** * Create a heap buffer of specified initial size. The buffer will grow automatically if not * enough. diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java index 1b34a9a600..0c5ecf3e0d 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java @@ -40,7 +40,7 @@ public class ContainerMemoryBudgetTest extends ForyTestBase { private static final long KNOWN_ROOT_MULTIPLIER = 8L; private static final long KNOWN_ROOT_SLACK_BYTES = 64L * 1024; private static final long STREAM_ROOT_BYTES = 128L * 1024 * 1024; - private static final int REFERENCE_BYTES = MemoryBuffer.objectArrayIndexScale(); + private static final int REFERENCE_BYTES = 4; @Test public void testConfigValidation() { From f9478ddefb46ed3cb455f408f436f5c619f73044 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 30 Jun 2026 21:13:38 +0800 Subject: [PATCH 11/54] fix: remove stray Java rebase artifacts --- .../org/apache/fory/builder/CodecUtils.java | 48 +++++-------------- .../org/apache/fory/context/WriteContext.java | 30 ------------ 2 files changed, 13 insertions(+), 65 deletions(-) diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/CodecUtils.java b/java/fory-core/src/main/java/org/apache/fory/builder/CodecUtils.java index 838d4cc34a..c9092f6723 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/CodecUtils.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/CodecUtils.java @@ -19,6 +19,7 @@ package org.apache.fory.builder; +import java.util.Collections; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import org.apache.fory.Fory; @@ -26,10 +27,7 @@ import org.apache.fory.codegen.CompileUnit; import org.apache.fory.collection.Tuple3; import org.apache.fory.meta.TypeDef; -import org.apache.fory.platform.AndroidSupport; import org.apache.fory.platform.GraalvmSupport; -import org.apache.fory.platform.JdkVersion; -import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.reflect.TypeRef; import org.apache.fory.resolver.TypeResolver; import org.apache.fory.serializer.Serializer; @@ -120,6 +118,12 @@ public static Class loadOrGenCompatibleLayerCodecClass @SuppressWarnings("unchecked") static Class loadOrGenCodecClass( Class beanClass, Fory fory, BaseObjectCodecBuilder codecBuilder) { + // use genCodeFunc to avoid gen code repeatedly + CompileUnit compileUnit = + new CompileUnit( + CodeGenerator.getPackage(beanClass), + codecBuilder.codecClassName(beanClass), + codecBuilder::genCode); CodeGenerator codeGenerator; ClassLoader beanClassClassLoader = beanClass.getClassLoader() == null @@ -130,41 +134,15 @@ static Class loadOrGenCodecClass( } TypeResolver typeResolver = fory.getTypeResolver(); codeGenerator = getCodeGenerator(beanClassClassLoader, typeResolver); - Class neighborClass = codecNeighbor(beanClass, beanClassClassLoader); - codecBuilder.setSamePackageAccess(neighborClass != null); - // use genCodeFunc to avoid gen code repeatedly - CompileUnit compileUnit = - new CompileUnit( - CodeGenerator.getPackage(beanClass), - codecBuilder.codecClassName(beanClass), - codecBuilder::genCode, - neighborClass); - return (Class) - codeGenerator.compileAndLoad(compileUnit, compileState -> compileState.lock.lock()); - } - - private static Class codecNeighbor(Class beanClass, ClassLoader beanClassClassLoader) { - // Hidden generated serializers are only a JDK25+ path for source-non-public bean classes. - // Source-public beans keep the normal generated class so split helpers remain private; JDK25 - // VarHandle field access is independent of hidden class definition. - if (AndroidSupport.IS_ANDROID - || JdkVersion.MAJOR_VERSION < 25 - || beanClass.getClassLoader() == null - || CodeGenerator.sourcePublicAccessible(beanClass)) { - return null; - } - if (!CodeGenerator.getPackage(beanClass).equals(ReflectionUtils.getPackage(beanClass))) { - return null; - } + ClassLoader classLoader = + codeGenerator.compile( + Collections.singletonList(compileUnit), compileState -> compileState.lock.lock()); + String className = codecBuilder.codecQualifiedClassName(beanClass); try { - // A generated serializer defined in the bean loader must resolve Fory runtime classes there. - if (beanClassClassLoader.loadClass(Fory.class.getName()) == Fory.class) { - return beanClass; - } + return (Class) classLoader.loadClass(className); } catch (ClassNotFoundException e) { - // The composed-loader path remains the owner when the bean loader cannot see Fory directly. + throw new IllegalStateException("Impossible because we just compiled class", e); } - return null; } private static CodeGenerator getCodeGenerator( diff --git a/java/fory-core/src/main/java/org/apache/fory/context/WriteContext.java b/java/fory-core/src/main/java/org/apache/fory/context/WriteContext.java index 9147d87409..bf74c0a67a 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/WriteContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/WriteContext.java @@ -555,36 +555,6 @@ public void writeRootRef(Object obj) { writeData(typeInfo, obj); } - /** - * Writes the root object for one serialization operation. - * - *

Root no-ref serialization still owns the null marker and type metadata, but it must not pay - * generic {@link RefWriter} dispatch or uncached type-info lookup when reference tracking is - * disabled. - */ - public void writeRootRef(Object obj) { - if (trackingRef) { - writeRef(obj, rootTypeInfoHolder); - return; - } - MemoryBuffer buffer = this.buffer; - if (obj == null) { - buffer.writeByte(Fory.NULL_FLAG); - return; - } - buffer.writeByte(Fory.NOT_NULL_VALUE_FLAG); - TypeResolver resolver = typeResolver; - TypeInfo typeInfo = resolver.getTypeInfo(obj.getClass(), rootTypeInfoHolder); - if (crossLanguage && typeInfo.getType() == UnknownStruct.class) { - depth++; - typeInfo.getSerializer().write(this, obj); - depth--; - return; - } - resolver.writeTypeInfo(this, typeInfo); - writeData(typeInfo, obj); - } - /** * Writes a non-null, first-seen object together with its type metadata. * From 41280043d5a67a1d9d2680cf8a4e5b165ed2e5e6 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 30 Jun 2026 21:30:37 +0800 Subject: [PATCH 12/54] refactor: clarify container budget reservation names --- .../test/container_memory_budget_test.dart | 8 +-- javascript/test/containerMemoryBudget.test.ts | 8 +-- .../scala/CollectionSerializerTest.scala | 4 +- .../Sources/Fory/CollectionSerializers.swift | 32 +++++------ swift/Sources/Fory/FieldCodecs.swift | 56 +++++++++---------- 5 files changed, 54 insertions(+), 54 deletions(-) diff --git a/dart/packages/fory/test/container_memory_budget_test.dart b/dart/packages/fory/test/container_memory_budget_test.dart index e8dc46bc71..52463492f7 100644 --- a/dart/packages/fory/test/container_memory_budget_test.dart +++ b/dart/packages/fory/test/container_memory_budget_test.dart @@ -154,21 +154,21 @@ void main() { expect(_readWithBudget(value, 4), equals(value)); }); - test('charges sibling containers cumulatively', () { + test('reserves sibling containers cumulatively', () { final value = [[], [], []]; expect(() => _readWithBudget(value, 11), _throwsContainerBudget); expect(_readWithBudget(value, 12), equals(value)); }); - test('charges map entries', () { + test('reserves map entries', () { final value = {'a': 1}; expect(() => _readWithBudget(value, 7), _throwsContainerBudget); expect(_readWithBudget(value, 8), equals(value)); }); - test('charges generated list set and map reads', () { + test('reserves generated list set and map reads', () { final writer = Fory(); _registerGenerated(writer); final bytes = writer.serialize( @@ -195,7 +195,7 @@ void main() { expect(roundTrip.counts, equals({'one': 1})); }); - test('charges compatible list array materialization', () { + test('reserves compatible list array materialization', () { final listWriter = Fory(); _registerCompatibleList(listWriter); final listBytes = listWriter.serialize( diff --git a/javascript/test/containerMemoryBudget.test.ts b/javascript/test/containerMemoryBudget.test.ts index e79f8fc6f9..499d4750b6 100644 --- a/javascript/test/containerMemoryBudget.test.ts +++ b/javascript/test/containerMemoryBudget.test.ts @@ -87,7 +87,7 @@ describe('container memory budget', () => { expect(passingReader.deserialize(bytes)).toEqual({ values: [[]] }); }); - test('charges sibling containers cumulatively', () => { + test('reserves sibling containers cumulatively', () => { const typeInfo = Type.struct('budget.sibling.empty', { values: Type.list(Type.list(Type.int32({ encoding: 'fixed' }))).setId(1), }); @@ -114,14 +114,14 @@ describe('container memory budget', () => { }); }); - test('charges map entries', () => { + test('reserves map entries', () => { const bytes = serializeAny(new Map([[1, 2]])); expect(() => deserializeAny(bytes, 7)).toThrow(/maxContainerMemoryBytes/); expect(deserializeAny(bytes, 8)).toEqual(new Map([[1, 2]])); }); - test('charges generated containers', () => { + test('reserves generated containers', () => { const typeInfo = Type.struct('budget.generated', { list: Type.list(Type.int32({ encoding: 'fixed' })).setId(1), set: Type.set(Type.string()).setId(2), @@ -154,7 +154,7 @@ describe('container memory budget', () => { }); }); - test('charges compatible typed arrays', () => { + test('reserves compatible typed arrays', () => { const writerType = Type.struct(9010, { values: Type.list(Type.int32({ encoding: 'fixed' })).setId(1), }); diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala index 70078534d2..cdae015153 100644 --- a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala @@ -105,7 +105,7 @@ class CollectionSerializerTest extends AnyWordSpec with Matchers { builder.build() } - "charge scala collection storage" in { + "reserve scala collection storage" in { val writer = runtime() val reader = runtime(maxContainerMemoryBytes = 23) intercept[InsecureException] { @@ -113,7 +113,7 @@ class CollectionSerializerTest extends AnyWordSpec with Matchers { } } - "charge scala map storage" in { + "reserve scala map storage" in { val writer = runtime() val reader = runtime(maxContainerMemoryBytes = 23) intercept[InsecureException] { diff --git a/swift/Sources/Fory/CollectionSerializers.swift b/swift/Sources/Fory/CollectionSerializers.swift index fca71441b4..86d41af3a0 100644 --- a/swift/Sources/Fory/CollectionSerializers.swift +++ b/swift/Sources/Fory/CollectionSerializers.swift @@ -272,32 +272,32 @@ func writePrimitiveArray(_ value: [Element], context: Write @inline(__always) private func preparePrimitiveArray( _ context: ReadContext, - chargeContainerMemory: Bool, + reserveContainerStorage: Bool, type: Element.Type, count: Int, label: String ) throws { try context.ensureCollectionLength(count, label: label) - if chargeContainerMemory { + if reserveContainerStorage { try reserveContainerArrayMemory(context, type, count: count) } } func readPrimitiveArray( _ context: ReadContext, - chargeContainerMemory: Bool = false + reserveContainerStorage: Bool = false ) throws -> [Element] { let byteSize = Int(try context.buffer.readVarUInt32()) try context.ensureRemainingBytes(byteSize, label: "primitive_array_bytes") if Element.self == UInt8.self { - try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: byteSize, label: "uint8_array") + try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: byteSize, label: "uint8_array") let bytes = try context.buffer.readBytes(count: byteSize) return uncheckedArrayCast(bytes, to: Element.self) } if Element.self == Bool.self { - try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: byteSize, label: "bool_array") + try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: byteSize, label: "bool_array") let out = try readArrayUninitialized(count: byteSize) { destination in for index in 0..( } if Element.self == Int8.self { - try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: byteSize, label: "int8_array") + try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: byteSize, label: "int8_array") var out = Array(repeating: Int8(0), count: byteSize) try out.withUnsafeMutableBytes { rawBytes in try context.buffer.readBytes(into: rawBytes) @@ -318,7 +318,7 @@ func readPrimitiveArray( if Element.self == Int16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("int16 array byte size mismatch") } let count = byteSize / 2 - try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "int16_array") + try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: count, label: "int16_array") if hostIsLittleEndian { var out = Array(repeating: Int16(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -337,7 +337,7 @@ func readPrimitiveArray( if Element.self == Int32.self { if byteSize % 4 != 0 { throw ForyError.invalidData("int32 array byte size mismatch") } let count = byteSize / 4 - try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "int32_array") + try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: count, label: "int32_array") if hostIsLittleEndian { var out = Array(repeating: Int32(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -356,7 +356,7 @@ func readPrimitiveArray( if Element.self == UInt32.self { if byteSize % 4 != 0 { throw ForyError.invalidData("uint32 array byte size mismatch") } let count = byteSize / 4 - try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "uint32_array") + try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: count, label: "uint32_array") if hostIsLittleEndian { var out = Array(repeating: UInt32(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -375,7 +375,7 @@ func readPrimitiveArray( if Element.self == Int64.self { if byteSize % 8 != 0 { throw ForyError.invalidData("int64 array byte size mismatch") } let count = byteSize / 8 - try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "int64_array") + try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: count, label: "int64_array") if hostIsLittleEndian { var out = Array(repeating: Int64(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -394,7 +394,7 @@ func readPrimitiveArray( if Element.self == UInt64.self { if byteSize % 8 != 0 { throw ForyError.invalidData("uint64 array byte size mismatch") } let count = byteSize / 8 - try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "uint64_array") + try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: count, label: "uint64_array") if hostIsLittleEndian { var out = Array(repeating: UInt64(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -413,7 +413,7 @@ func readPrimitiveArray( if Element.self == UInt16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("uint16 array byte size mismatch") } let count = byteSize / 2 - try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "uint16_array") + try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: count, label: "uint16_array") if hostIsLittleEndian { var out = Array(repeating: UInt16(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -432,7 +432,7 @@ func readPrimitiveArray( if Element.self == Float16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("float16 array byte size mismatch") } let count = byteSize / 2 - try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "float16_array") + try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: count, label: "float16_array") let values = try readArrayUninitialized(count: count) { destination in for index in 0..( if Element.self == BFloat16.self { if byteSize % 2 != 0 { throw ForyError.invalidData("bfloat16 array byte size mismatch") } let count = byteSize / 2 - try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "bfloat16_array") + try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: count, label: "bfloat16_array") let values = try readArrayUninitialized(count: count) { destination in for index in 0..( if Element.self == Float.self { if byteSize % 4 != 0 { throw ForyError.invalidData("float32 array byte size mismatch") } let count = byteSize / 4 - try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "float32_array") + try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: count, label: "float32_array") if hostIsLittleEndian { var out = Array(repeating: Float(0), count: count) try out.withUnsafeMutableBytes { rawBytes in @@ -474,7 +474,7 @@ func readPrimitiveArray( if byteSize % 8 != 0 { throw ForyError.invalidData("float64 array byte size mismatch") } let count = byteSize / 8 - try preparePrimitiveArray(context, chargeContainerMemory: chargeContainerMemory, type: Element.self, count: count, label: "float64_array") + try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: count, label: "float64_array") if hostIsLittleEndian { var out = Array(repeating: Double(0), count: count) try out.withUnsafeMutableBytes { rawBytes in diff --git a/swift/Sources/Fory/FieldCodecs.swift b/swift/Sources/Fory/FieldCodecs.swift index 73b7dd753e..d79f06df96 100644 --- a/swift/Sources/Fory/FieldCodecs.swift +++ b/swift/Sources/Fory/FieldCodecs.swift @@ -30,7 +30,7 @@ private func serializerElementBytes(_ type: Element.Type) - } @inline(__always) -private func chargeFieldArrayStorage( +private func reserveFieldArrayStorage( _ context: ReadContext, _ codec: ElementCodec.Type, count: Int @@ -48,7 +48,7 @@ private func reserveSerializerArrayMemory( } @inline(__always) -private func chargeFieldMapStorage( +private func reserveFieldMapStorage( _ context: ReadContext, key: KeyCodec.Type, value: ValueCodec.Type, @@ -888,7 +888,7 @@ public enum SetFieldCodec: FieldCodec where ElementCod public static func readPayload(_ context: ReadContext) throws -> Value { let values = try readCollectionPayload(context, elementCodec: ElementCodec.self) - try chargeFieldArrayStorage(context, ElementCodec.self, count: values.count) + try reserveFieldArrayStorage(context, ElementCodec.self, count: values.count) return Set(values) } } @@ -1008,11 +1008,11 @@ where KeyCodec.Value: Hashable { let totalLength = Int(try context.buffer.readVarUInt32()) try context.ensureCollectionLength(totalLength, label: "map") if totalLength == 0 { - try chargeFieldMapStorage(context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) + try reserveFieldMapStorage(context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) return [:] } - try chargeFieldMapStorage(context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) + try reserveFieldMapStorage(context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) try context.ensureRemainingBytes(totalLength, label: "map") var map: Value = [:] map.reserveCapacity(totalLength) @@ -1374,9 +1374,9 @@ private func writeUIntArrayPayload(_ value: [UInt], _ context: WriteContext) { } } -private func readIntArrayPayload(_ context: ReadContext, chargeContainerMemory: Bool = false) throws -> [Int] { +private func readIntArrayPayload(_ context: ReadContext, reserveContainerStorage: Bool = false) throws -> [Int] { let count = try readPackedArrayElementCount(context, width: 8, label: "int64_array") - if chargeContainerMemory { + if reserveContainerStorage { try reserveSerializerArrayMemory(context, Int.self, count: count) } var values: [Int] = [] @@ -1387,9 +1387,9 @@ private func readIntArrayPayload(_ context: ReadContext, chargeContainerMemory: return values } -private func readUIntArrayPayload(_ context: ReadContext, chargeContainerMemory: Bool = false) throws -> [UInt] { +private func readUIntArrayPayload(_ context: ReadContext, reserveContainerStorage: Bool = false) throws -> [UInt] { let count = try readPackedArrayElementCount(context, width: 8, label: "uint64_array") - if chargeContainerMemory { + if reserveContainerStorage { try reserveSerializerArrayMemory(context, UInt.self, count: count) } var values: [UInt] = [] @@ -1435,49 +1435,49 @@ private func readCompatiblePackedArrayPayload( elementCodec _: ElementCodec.Type ) throws -> [ElementCodec.Value] { if ElementCodec.self == BoolCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Bool], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Bool], to: ElementCodec.Value.self) } if ElementCodec.self == Int8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Int8], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Int8], to: ElementCodec.Value.self) } if ElementCodec.self == Int16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Int16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Int16], to: ElementCodec.Value.self) } if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Int32], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Int32], to: ElementCodec.Value.self) } if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self || ElementCodec.self == Int64TaggedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Int64], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Int64], to: ElementCodec.Value.self) } if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self || ElementCodec.self == IntTaggedCodec.self { - return uncheckedPackedArrayCast(try readIntArrayPayload(context, chargeContainerMemory: true), to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readIntArrayPayload(context, reserveContainerStorage: true), to: ElementCodec.Value.self) } if ElementCodec.self == UInt8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [UInt8], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [UInt8], to: ElementCodec.Value.self) } if ElementCodec.self == UInt16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [UInt16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [UInt16], to: ElementCodec.Value.self) } if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [UInt32], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [UInt32], to: ElementCodec.Value.self) } if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self || ElementCodec.self == UInt64TaggedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [UInt64], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [UInt64], to: ElementCodec.Value.self) } if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self || ElementCodec.self == UIntTaggedCodec.self { - return uncheckedPackedArrayCast(try readUIntArrayPayload(context, chargeContainerMemory: true), to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readUIntArrayPayload(context, reserveContainerStorage: true), to: ElementCodec.Value.self) } if ElementCodec.self == Float16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Float16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Float16], to: ElementCodec.Value.self) } if ElementCodec.self == BFloat16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [BFloat16], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [BFloat16], to: ElementCodec.Value.self) } if ElementCodec.self == FloatCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Float], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Float], to: ElementCodec.Value.self) } if ElementCodec.self == DoubleCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, chargeContainerMemory: true) as [Double], to: ElementCodec.Value.self) + return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Double], to: ElementCodec.Value.self) } throw ForyError.invalidData("unsupported compatible array-to-list field element codec \(ElementCodec.self)") } @@ -1648,7 +1648,7 @@ private func readCollectionPayload( let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { - try chargeFieldArrayStorage(context, ElementCodec.self, count: length) + try reserveFieldArrayStorage(context, ElementCodec.self, count: length) return [] } @@ -1663,7 +1663,7 @@ private func readCollectionPayload( let sameType = (header & CollectionHeader.sameType) != 0 var result: [ElementCodec.Value] = [] - try chargeFieldArrayStorage(context, ElementCodec.self, count: length) + try reserveFieldArrayStorage(context, ElementCodec.self, count: length) try context.ensureRemainingBytes(length, label: "array") result.reserveCapacity(length) @@ -1749,7 +1749,7 @@ private func readListPayloadAsArrayPayload( let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { - try chargeFieldArrayStorage(context, ElementCodec.self, count: length) + try reserveFieldArrayStorage(context, ElementCodec.self, count: length) return [] } @@ -1777,7 +1777,7 @@ private func readListPayloadAsArrayPayload( } try context.ensureRemainingBytes(length, label: "array") var result: [ElementCodec.Value] = [] - try chargeFieldArrayStorage(context, ElementCodec.self, count: length) + try reserveFieldArrayStorage(context, ElementCodec.self, count: length) result.reserveCapacity(length) return try ElementCodec.withTypeInfo(elementTypeInfo, context) { for _ in 0.. Date: Tue, 30 Jun 2026 21:33:32 +0800 Subject: [PATCH 13/54] refactor: remove redundant Java size readers --- .../fory/serializer/collection/CollectionLikeSerializer.java | 4 ---- .../fory/serializer/collection/CollectionSerializers.java | 5 +++-- .../serializer/collection/GuavaCollectionSerializers.java | 3 ++- .../apache/fory/serializer/collection/MapLikeSerializer.java | 4 ---- 4 files changed, 5 insertions(+), 11 deletions(-) diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java index 99cf798c9f..10c39dc49f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java @@ -562,10 +562,6 @@ protected void setNumElements(int numElements) { this.numElements = numElements; } - protected final int readCollectionSize(ReadContext readContext) { - return readCollectionSize(readContext, readContext.getBuffer()); - } - protected final int readCollectionSize(ReadContext readContext, MemoryBuffer buffer) { int numElements = buffer.readVarUInt32Small7(); checkCollectionSize(numElements); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java index 737e79d2bb..5fcbffd1e5 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java @@ -337,7 +337,8 @@ public void write(WriteContext writeContext, List value) { @Override public List read(ReadContext readContext) { if (config.isXlang()) { - int numElements = readCollectionSize(readContext); + MemoryBuffer buffer = readContext.getBuffer(); + int numElements = readCollectionSize(readContext, buffer); if (numElements != 0) { throw new DeserializationException( "Empty list body must have zero elements but got " + numElements); @@ -998,7 +999,7 @@ public LinkedBlockingQueue newCollection(ReadContext readContext) { int capacity = buffer.readVarUInt32Small7(); checkBoundedQueueCapacity(numElements, capacity); // LinkedBlockingQueue capacity is a logical bound, not preallocated backing storage. The - // current node storage is already charged by readCollectionSize(numElements). + // current node storage is already reserved by readCollectionSize(numElements). LinkedBlockingQueue queue = new LinkedBlockingQueue<>(capacity); readContext.reference(queue); return queue; diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java index ceeca02efc..52ce79a937 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java @@ -264,7 +264,8 @@ public T onMapRead(Map map) { @Override public T read(ReadContext readContext) { - int size = readMapSize(readContext); + MemoryBuffer buffer = readContext.getBuffer(); + int size = readMapSize(readContext, buffer); Map map = new HashMap(); readElements(readContext, size, map); return xnewInstance(map); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java index 68074ff8c1..4c1fc65ff9 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java @@ -965,10 +965,6 @@ public void setNumElements(int numElements) { this.numElements = numElements; } - protected final int readMapSize(ReadContext readContext) { - return readMapSize(readContext, readContext.getBuffer()); - } - protected final int readMapSize(ReadContext readContext, MemoryBuffer buffer) { int numElements = buffer.readVarUInt32Small7(); checkMapSize(numElements); From e5bec05e4fef5ba3ae1b1c85a56cb4a3f8a1e93e Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Tue, 30 Jun 2026 23:57:41 +0800 Subject: [PATCH 14/54] feat: add root graph memory budget --- .agents/languages/cpp.md | 33 +- .agents/languages/csharp.md | 20 +- .agents/languages/dart.md | 15 +- .agents/languages/go.md | 22 +- .agents/languages/java.md | 28 +- .agents/languages/javascript.md | 16 +- .agents/languages/python.md | 17 +- .agents/languages/rust.md | 23 +- .agents/languages/swift.md | 19 +- AGENTS.md | 4 +- cpp/fory/serialization/BUILD | 4 +- cpp/fory/serialization/CMakeLists.txt | 8 +- .../serialization/collection_serializer.h | 10 +- cpp/fory/serialization/config.h | 4 +- cpp/fory/serialization/context.cc | 47 +- cpp/fory/serialization/context.h | 95 +- cpp/fory/serialization/fory.h | 21 +- ...et_test.cc => graph_memory_budget_test.cc} | 199 +- cpp/fory/serialization/map_serializer.h | 5 +- cpp/fory/serialization/serializer_traits.h | 110 +- .../serialization/smart_ptr_serializers.h | 51 + cpp/fory/serialization/type_resolver.h | 12 + .../src/Fory.Generator/ForyModelGenerator.cs | 64 +- csharp/src/Fory/CollectionSerializers.cs | 4 +- csharp/src/Fory/Config.cs | 26 +- csharp/src/Fory/DictionarySerializers.cs | 3 +- csharp/src/Fory/Fory.cs | 7 +- csharp/src/Fory/GraphMemory.cs | 72 + csharp/src/Fory/NullableKeyDictionary.cs | 3 +- .../Fory/PrimitiveDictionarySerializers.cs | 5 +- csharp/src/Fory/ReadContext.cs | 46 +- csharp/src/Fory/TypeInfo.cs | 30 +- ...dgetTests.cs => GraphMemoryBudgetTests.cs} | 94 +- .../fory/lib/src/codegen/fory_generator.dart | 12 + dart/packages/fory/lib/src/config.dart | 12 +- .../fory/lib/src/context/read_context.dart | 42 +- dart/packages/fory/lib/src/fory.dart | 12 +- .../serializer/collection_serializers.dart | 40 +- .../lib/src/serializer/map_serializers.dart | 2 +- ...est.dart => graph_memory_budget_test.dart} | 107 +- docs/guide/cpp/configuration.md | 26 +- docs/guide/csharp/configuration.md | 13 +- docs/guide/dart/configuration.md | 21 +- docs/guide/go/configuration.md | 17 +- docs/guide/java/configuration.md | 11 +- docs/guide/javascript/configuration.md | 29 +- docs/guide/python/configuration.md | 18 +- docs/guide/rust/configuration.md | 24 +- docs/guide/swift/configuration.md | 16 +- docs/security/deserialization.md | 87 +- .../xlang_implementation_guide.md | 60 +- go/fory/README.md | 6 +- go/fory/array.go | 2 +- go/fory/codegen/decoder.go | 26 +- go/fory/codegen/generator.go | 22 +- go/fory/container_memory_budget_test.go | 227 -- go/fory/field_serializer.go | 2 +- go/fory/fory.go | 52 +- go/fory/graph_memory_budget_test.go | 295 ++ go/fory/map.go | 4 +- go/fory/map_primitive.go | 36 +- go/fory/pointer.go | 6 + go/fory/reader.go | 140 +- go/fory/set.go | 8 +- go/fory/slice.go | 4 +- go/fory/slice_dyn.go | 8 +- go/fory/slice_primitive.go | 2 +- go/fory/slice_primitive_list.go | 8 +- go/fory/stream.go | 10 +- go/fory/tests/structs_fory_gen.go | 66 +- .../StaticSerializerSourceWriter.java | 2 + .../fory/builder/ObjectCodecBuilder.java | 1 + .../builder/StaticCompatibleCodecBuilder.java | 2 + .../java/org/apache/fory/config/Config.java | 14 +- .../org/apache/fory/config/ForyBuilder.java | 16 +- .../org/apache/fory/context/ReadContext.java | 46 +- .../serializer/AbstractObjectSerializer.java | 39 + .../fory/serializer/ArraySerializers.java | 3 +- .../CompatibleCollectionArrayReader.java | 5 +- .../CompatibleLayerSerializerBase.java | 1 + .../fory/serializer/CompatibleSerializer.java | 1 + .../fory/serializer/ExceptionSerializers.java | 4 + .../fory/serializer/ObjectSerializer.java | 1 + .../serializer/ObjectStreamSerializer.java | 1 + .../serializer/UnknownClassSerializers.java | 11 +- .../collection/CollectionLikeSerializer.java | 3 +- .../collection/CollectionSerializers.java | 2 +- .../collection/MapLikeSerializer.java | 3 +- ...etTest.java => GraphMemoryBudgetTest.java} | 117 +- javascript/packages/core/lib/context.ts | 227 +- javascript/packages/core/lib/fory.ts | 64 +- .../packages/core/lib/gen/collection.ts | 87 +- javascript/packages/core/lib/gen/ext.ts | 20 +- javascript/packages/core/lib/gen/map.ts | 69 +- javascript/packages/core/lib/gen/struct.ts | 209 +- javascript/packages/core/lib/type.ts | 2 +- javascript/test/containerMemoryBudget.test.ts | 226 -- javascript/test/graphMemoryBudget.test.ts | 310 ++ .../ksp/KotlinSerializerSourceWriter.kt | 5 + .../serializer/kotlin/CollectionSerializer.kt | 2 +- .../kotlin/CollectionSerializerTest.kt | 6 +- python/pyfory/_fory.py | 18 +- python/pyfory/collection.pxi | 100 +- python/pyfory/collection.py | 5 +- python/pyfory/context.pxi | 82 +- python/pyfory/context.py | 60 +- python/pyfory/serialization.pyx | 50 +- python/pyfory/serializer.py | 13 +- python/pyfory/struct.pxi | 5 + python/pyfory/struct.py | 5 + ..._budget.py => test_graph_memory_budget.py} | 74 +- rust/fory-core/src/config.rs | 12 +- rust/fory-core/src/context.rs | 70 +- rust/fory-core/src/fory.rs | 16 +- rust/fory-core/src/resolver/type_resolver.rs | 20 + rust/fory-core/src/serializer/arc.rs | 4 + rust/fory-core/src/serializer/array.rs | 12 + rust/fory-core/src/serializer/box_.rs | 4 + rust/fory-core/src/serializer/codec.rs | 30 +- rust/fory-core/src/serializer/collection.rs | 14 +- rust/fory-core/src/serializer/core.rs | 25 + rust/fory-core/src/serializer/heap.rs | 4 + rust/fory-core/src/serializer/list.rs | 19 + rust/fory-core/src/serializer/map.rs | 24 +- rust/fory-core/src/serializer/rc.rs | 4 + rust/fory-core/src/serializer/set.rs | 8 + rust/fory-core/src/serializer/tuple.rs | 10 + rust/fory-derive/src/object/serializer.rs | 22 + rust/tests/tests/mod.rs | 2 +- ..._budget.rs => test_graph_memory_budget.rs} | 104 +- .../scala/internal/ForySerializerMacros.scala | 21 +- .../scala/CollectionSerializer.scala | 2 +- .../fory/serializer/scala/MapSerializer.scala | 2 +- .../scala/XlangCollectionSerializer.scala | 4 +- .../scala/CollectionSerializerTest.scala | 12 +- .../scala/ScalaXlangSerializerTest.scala | 4 +- swift/Sources/Fory/AnySerializer.swift | 1208 +++---- .../Sources/Fory/CollectionSerializers.swift | 1832 +++++----- swift/Sources/Fory/FieldCodecs.swift | 3093 +++++++++-------- swift/Sources/Fory/FieldSkipper.swift | 697 ++-- swift/Sources/Fory/Fory.swift | 84 +- swift/Sources/Fory/ReadContext.swift | 176 +- .../ForyObjectMacroReadGeneration.swift | 1097 +++--- .../ContainerMemoryBudgetTests.swift | 307 -- swift/Tests/ForyTests/ForySwiftTests.swift | 82 +- .../ForyTests/GraphMemoryBudgetTests.swift | 359 ++ 146 files changed, 7721 insertions(+), 6164 deletions(-) rename cpp/fory/serialization/{container_memory_budget_test.cc => graph_memory_budget_test.cc} (56%) create mode 100644 csharp/src/Fory/GraphMemory.cs rename csharp/tests/Fory.Tests/{ContainerMemoryBudgetTests.cs => GraphMemoryBudgetTests.cs} (66%) rename dart/packages/fory/test/{container_memory_budget_test.dart => graph_memory_budget_test.dart} (68%) delete mode 100644 go/fory/container_memory_budget_test.go create mode 100644 go/fory/graph_memory_budget_test.go rename java/fory-core/src/test/java/org/apache/fory/serializer/{ContainerMemoryBudgetTest.java => GraphMemoryBudgetTest.java} (69%) delete mode 100644 javascript/test/containerMemoryBudget.test.ts create mode 100644 javascript/test/graphMemoryBudget.test.ts rename python/pyfory/tests/{test_container_memory_budget.py => test_graph_memory_budget.py} (73%) rename rust/tests/tests/{test_container_memory_budget.rs => test_graph_memory_budget.rs} (67%) delete mode 100644 swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift create mode 100644 swift/Tests/ForyTests/GraphMemoryBudgetTests.swift diff --git a/.agents/languages/cpp.md b/.agents/languages/cpp.md index 30f780e266..be74b77b22 100644 --- a/.agents/languages/cpp.md +++ b/.agents/languages/cpp.md @@ -17,23 +17,24 @@ Load this file when changing `cpp/`, Cython build plumbing, or C++ xlang behavio - Do not redesign alias-based or low-level public type shapes to add convenience methods unless the user explicitly asks for that API change. - For cross-language feature ports, match protocol behavior but use idiomatic C++ ownership and layering instead of mirroring Java structure literally. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. -- Root deserialization container budgets are owned by `ReadContext` and initialized by the root - `Fory::deserialize` overload. Keep `max_container_memory_bytes` as `-1 / auto` or a positive - explicit limit; known byte roots use `inputBytes * 8 + 64 KiB`, while stream roots use fixed - `128 MiB`. Reserve estimated container-owned memory before allocation but preserve existing +- Root deserialization graph budgets are owned by `ReadContext` and initialized by the root + `Fory::deserialize` overload. Keep `max_graph_memory_bytes` as `-1 / auto` or a positive explicit + limit; known byte roots use `inputBytes * 8 + 64 KiB`, while stream roots use fixed `128 MiB`. + Reserve estimated shallow graph-owner memory before allocation while preserving existing byte-availability checks and their non-empty metadata ordering. `ReadContext` may expose only raw - byte reservation and generic counted-byte arithmetic; collection/map formulas belong in serializer - owners. Empty containers with no dynamic - backing storage normally charge zero. Skip only dedicated string, - binary, primitive vector, and primitive dense-array owners; `std::vector` is the C++ - standard-container exception and should charge rounded packed-bit storage. General - `std::vector` for non-primitive `T` is inline container storage and must be charged. -- C++ container budget formulas must be portable lower-bound estimates, not STL heap-layout - accounting. Generic collection-like containers charge `count_or_capacity * sizeof(value_type)`, - map-like containers charge `count * (sizeof(key_type) + sizeof(mapped_type))`, and set-like - containers charge `count * sizeof(key_type)`. Do not charge standalone `sizeof(Container)` and do - not add guessed node/header/debug-STL overhead, red-black-tree fields, allocator probing, - object-layout inspection, generic per-entry pointer overhead, or unordered bucket-table guesses. + byte reservation and generic counted-byte arithmetic; collection, map, array, struct, and object + formulas belong in serializer owners. Skip dedicated string, binary, primitive scalar, primitive + vector, and primitive dense-array leaf owners; `std::vector` charges rounded packed-bit + storage. General `std::vector` for non-primitive `T` is inline value storage and must be + reserved by the vector owner. +- C++ graph budget formulas must be portable lower-bound estimates, not STL heap-layout accounting. + Generic collection-like containers reserve `count_or_capacity * sizeof(value_type)`, map-like + containers reserve `count * (sizeof(key_type) + sizeof(mapped_type))`, and set-like containers + reserve `count * sizeof(key_type)`. Root struct/product owners and smart-pointer/box allocation + owners reserve shallow self storage exactly once; nested value serializers reserve only dynamic + storage they allocate, not their own inline self storage again. Do not add guessed + node/header/debug-STL overhead, red-black-tree fields, allocator probing, object-layout + inspection, generic per-entry pointer overhead, or unordered bucket-table guesses. ## Key Paths diff --git a/.agents/languages/csharp.md b/.agents/languages/csharp.md index c29acd822b..7d15edfaf5 100644 --- a/.agents/languages/csharp.md +++ b/.agents/languages/csharp.md @@ -12,13 +12,19 @@ Load this file when changing `csharp/` or C# xlang behavior. - Generated C# gRPC service companions are compiler-owned files that depend on application-provided gRPC packages, not `csharp/src/Fory`. Keep gRPC package references out of the Fory runtime package. - C# generated schema modules are source-file owners. Service companions must use that module's `ThreadSafeFory` and must not introduce namespace-owned aliases or duplicate serializer registration paths. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. -- Root deserialization container memory budget state belongs to `ReadContext`. C# public roots are memory-backed today, so auto uses known input length. `ReadContext` may expose only raw byte reservation and generic counted-byte arithmetic; concrete serializers and generated serializers must compute list/array/map byte formulas before calling it. -- For C# container budget formulas, distinguish inline value storage from reference storage: use - cheap value-type size for `List`/`T[]` value paths and the 4-byte reference fallback for - reference paths. Maps charge key plus value storage, linked/hash/tree conversions must not add - guessed node or entry overhead, and empty containers with no backing storage normally charge zero. - Dedicated string, binary, and primitive dense-array serializers stay skipped and rely on byte - availability checks. +- Root deserialization graph memory budget state belongs to `ReadContext`. C# public roots are + memory-backed today, so auto uses known input length. `ReadContext` may expose only raw byte + reservation and generic counted-byte arithmetic; concrete serializers and generated serializers + must compute list, array, map, struct, and object byte formulas before calling it. +- For C# graph budget formulas, distinguish inline value storage from reference storage: use cheap + value-type size for `List`/`T[]` value paths and the 4-byte reference fallback for reference + paths. Class/reference serializers reserve their own shallow self cost plus field storage when + materialized; struct/value serializers do not unconditionally charge self storage because root, + field, list, array, map, set, or box owners reserve the inline storage they own. Maps reserve key + plus value storage, linked/hash/tree conversions must not add guessed node or entry overhead, and + independently materialized collection/map/array owners reserve nonzero shallow self cost. + Dedicated string, binary, primitive scalar, and primitive dense-array serializers stay skipped and + rely on byte availability checks. - When extending C# tests from Java references, prioritize xlang spec behavior and the public C# contract before adding complex Java-specific parity cases. ## Commands diff --git a/.agents/languages/dart.md b/.agents/languages/dart.md index 085cede2a1..5a80cb98e5 100644 --- a/.agents/languages/dart.md +++ b/.agents/languages/dart.md @@ -14,7 +14,20 @@ Load this file when changing `dart/`. - Keep root numeric wrapper defaults separate from generated field metadata. Root wrapper resolution belongs in the builtin resolver, while annotations and generated metadata choose fixed, tagged, or declared-field encodings. - Dart 64-bit carriers are optimized for each platform. Do not replace native extension-type wrappers with allocation-heavy classes or route web/native hot paths through `BigInt` unless the user approves a representation change. - In `Buffer`, cursor, serializer, and generated-code hot paths, prefer direct byte/local integer operations and conditional import/export files over callbacks, records, holder objects, wrapper round-trips, or runtime platform branches. -- Root deserialization container memory budgets are owned by `ReadContext`; `maxContainerMemoryBytes` defaults to `-1 / auto`, positive explicit values win, and Dart auto uses `buffer.readableBytes * 8 + 64 KiB` because roots are memory-backed. `ReadContext` may expose only raw byte reservation and generic counted-byte arithmetic; list/set/map/object-array formulas belong in serializer owners. Charge Dart list/set/object-array reference slots, map key/value slots, compatible list-to-array inline storage, and compatible array-to-list materialization before allocation. Empty containers with no backing storage normally charge zero. Skip only dedicated string, binary, `BoolList`, and typed-array dense owner paths with byte checks. Do not add stream bytes-read accounting, per-element accounting, or extra hot-path allocations for this budget. +- Root deserialization graph memory budgets are owned by `ReadContext`; + `maxGraphMemoryBytes` defaults to `-1 / auto`, positive explicit values win, + and Dart auto uses `buffer.readableBytes * 8 + 64 KiB` because roots are + memory-backed. `ReadContext` may expose only raw byte reservation and generic + counted-byte arithmetic; list, set, map, array, struct, and object formulas + belong in serializer owners. Reserve Dart list/set/object-array reference + slots plus nonzero owner self cost, map key/value slots plus nonzero owner + self cost, compatible list-to-array inline storage, compatible array-to-list + materialization, and generated object reads before allocation. Object/struct + owners reserve nonzero shallow self memory plus shallow field storage. Skip + only dedicated string, binary, primitive scalar, `BoolList`, and typed-array + dense owner paths with byte checks. Do not add stream bytes-read accounting, + per-element accounting, extra hot-path allocations, or stale narrower-scope + formulas. - Do not add parallel header-low/header-high slot caches or multi-slot recent caches in TypeMeta hot paths to chase benchmark gaps. Header-cache hits must use the concrete checked cache owner directly; if a hit hint is needed, cache one TypeInfo/TypeMeta object and compare the validated header identity on that object, not separate low/high header fields or benchmark-pattern state. - If Dart TypeMeta cache ownership changes, keep the invariant in a source comment near the hit path: a checked metadata-cache hit skips the body and must not grow low-bit sentinels, accepted-header fields, parallel header slots, or benchmark-pattern state. - Dart expected-type TypeDef reads should compare the expected `TypeInfo` object's cached local TypeDef header before consulting the parsed-metadata map. A match is a direct local-schema hit: skip the remote body, add the expected type to the per-read shared type table, and do not publish to `ParsedTypeMetaCache`, record a remote schema version, or parse/hash the body. diff --git a/.agents/languages/go.md b/.agents/languages/go.md index e0aa050d14..fcffe384af 100644 --- a/.agents/languages/go.md +++ b/.agents/languages/go.md @@ -7,18 +7,20 @@ Load this file when changing `go/fory/` or Go xlang behavior. - Run Go commands from within `go/fory/`. - Changes under `go/` must pass formatting and tests. - The Go implementation focuses on reflection-based and codegen-based serialization. -- Root deserialization container memory budgets are owned by `ReadContext`. - `WithMaxContainerMemoryBytes` defaults to `-1 / auto`; byte-slice roots use +- Root deserialization graph memory budgets are owned by `ReadContext`. + `WithMaxGraphMemoryBytes` defaults to `-1 / auto`; byte-slice roots use `inputBytes * 8 + 64 KiB`, and `DeserializeFromReader`/`DeserializeFromStream` use fixed `128 MiB`. `ReadContext` may expose only raw byte reservation and - generic counted-byte arithmetic; slice/map formulas belong in handwritten or - generated serializer owners. Charge Go slices as `len * elemBytes`, maps as - `len * (keyBytes + valueBytes)`, map-backed sets, LIST-encoded inline/value - slices, and generated container reads before allocation. Empty containers with - no backing storage normally charge zero. Fixed arrays are caller-owned and - normally not charged; `arrayDynSerializer` charges its temporary slice. Skip - only dedicated string, binary, BufferObject, primitive ARRAY slice, and - primitive array owners with byte checks. + generic counted-byte arithmetic; slice, map, array, struct, and object + formulas belong in handwritten or generated serializer owners. Reserve Go + slices as `len * elemBytes`, maps as `len * (keyBytes + valueBytes)`, + map-backed sets, and LIST-encoded inline/value slices in the owner that + allocates that storage. Root struct owners, pointer allocations, and generated + allocation entries reserve shallow value storage exactly once; nested inline + struct serializers do not charge their own self storage again. Fixed arrays + are caller-owned unless a read path materializes a temporary owner. Skip + dedicated string, binary, BufferObject, primitive scalar, primitive ARRAY + slice, and primitive array owners with byte checks. - Set `FORY_PANIC_ON_ERROR=1` when debugging a failing Go test so you get the full call stack. - Do not set `FORY_PANIC_ON_ERROR=1` when running the full Go test suite, because some tests assert on error contents. diff --git a/.agents/languages/java.md b/.agents/languages/java.md index 0e84f65453..2150b74d8d 100644 --- a/.agents/languages/java.md +++ b/.agents/languages/java.md @@ -14,19 +14,21 @@ Load this file when changing anything under `java/` or when Java drives a cross- values; use qualified names only when a real name conflict requires it. - If you run temporary tests with `java -cp`, run `mvn -T16 install -DskipTests` first so local Fory jars are current. - `WriteContext`, `ReadContext`, and `CopyContext` must stay explicit. Do not reintroduce `ThreadLocal` or ambient runtime-context patterns. -- Java root deserialization container memory budgeting belongs to `ReadContext` - and is initialized by `Fory` root APIs. Public config is - `maxContainerMemoryBytes` with `-1` auto, positive explicit override, - known-length auto `inputBytes * 8 + 64 KiB`, and stream/unknown auto - `128 MiB`. `ReadContext` may expose only raw byte reservation and generic - counted-byte arithmetic; collection/map/object-array formulas belong in the - concrete serializer owner. Java collection/object-array paths charge reference slots only, and - maps charge two reference slots per entry. Fixed/header, map table, and map - entry overhead are not charged unless a future owner documents a conservative - independent lower-bound signal. Preserve existing `checkReadableBytes` guards - before backing allocation or capacity reservation. Do not add nested - serializer-path `try/finally`, per-element work, or dynamic stream bytes-read - accounting for this budget. +- Java root deserialization graph memory budgeting belongs to `ReadContext` + and is initialized by `Fory` root APIs. Public config is `maxGraphMemoryBytes` + with `-1` auto, positive explicit override, known-length auto + `inputBytes * 8 + 64 KiB`, and stream/unknown auto `128 MiB`. `ReadContext` + may expose only raw byte reservation and generic counted-byte arithmetic; + collection, map, array, struct, and object formulas belong in the concrete + serializer or generated serializer owner. Java collection, map, and + object-array owners reserve nonzero shallow self cost plus reference storage; + referenced object serializers reserve their own nonzero shallow self memory + plus shallow field storage when materialized. + Reference fields use the 4-byte fallback when the JVM reference size is not + queried cheaply; primitive fields use their encoded storage width. Preserve + existing `checkReadableBytes` guards before backing allocation or capacity + reservation. Do not add nested serializer-path `try/finally`, per-element + work, dynamic stream bytes-read accounting, or stale narrower-scope formulas. - Generated serializers must not retain runtime context fields. `Fory` should stay a root-operation facade rather than accumulating serializer or convenience state. - When the serializer class and constructor shape are known at the call site, prefer direct constructor lambdas or direct instantiation over reflective `Serializers.newSerializer(...)`. - For GraalVM, use `fory codegen` to generate serializers when building native images. Do not add reflection configuration except for JDK `proxy`. diff --git a/.agents/languages/javascript.md b/.agents/languages/javascript.md index afdc8519c6..c97fce2be9 100644 --- a/.agents/languages/javascript.md +++ b/.agents/languages/javascript.md @@ -14,16 +14,16 @@ Load this file when changing `javascript/`. - Runtime value carriers such as decimal or reduced-precision numeric types belong under the core `types/` ownership boundary, with imports, exports, and codegen externals updated together. - Keep `TypeInfo` as schema metadata. Compatibility-sensitive decisions belong on `TypeResolver` or explicit operations, not as retained resolver state on metadata objects. - Normalize optional boolean config values at config construction; do not carry `null` through runtime paths when it means `false`. -- JavaScript root deserialization container memory budgeting belongs to `ReadContext`. - `maxContainerMemoryBytes` uses `-1` auto, positive explicit limits, and known +- JavaScript root deserialization graph memory budgeting belongs to `ReadContext`. + `maxGraphMemoryBytes` uses `-1` auto, positive explicit limits, and known `Uint8Array` root length as `inputBytes * 8 + 64 KiB`. `ReadContext` may expose only raw byte reservation and generic counted-byte arithmetic; generated and dynamic - list/set/map readers must reserve before allocation while preserving existing - byte checks. Lists/sets/object arrays charge 4-byte reference slots, maps charge - two 4-byte references per entry, and empty containers with no backing storage - normally charge zero. Keep dedicated string, binary, and dense typed-array - owners out of this budget; compatible list-to-typed-array reads must charge - typed inline storage. + list/set/map/array/struct/object readers must reserve before allocation while preserving existing + byte checks. Lists/sets/object arrays reserve nonzero owner self cost plus 4-byte reference slots, + maps reserve nonzero owner self cost plus key/value reference storage, object/struct readers + reserve nonzero shallow self memory plus shallow field storage, and compatible + list-to-typed-array reads reserve typed inline storage. Keep dedicated string, binary, primitive + scalar, and dense typed-array leaf owners out of this budget. - Regenerated compatible read serializers are remote-schema-specific. After classification marks a field as direct, compatible scalar, or skip, generated JavaScript should emit straight-line remote-field-order code. Do not add an outer matched-id switch unless the current regenerated shape cannot preserve those semantics. - Compatible scalar codegen must decide the exact remote/local scalar pair before emitting source. Generate the concrete `reader.readXxx()` call plus inline trivial conversions such as boolean-to-string or numeric widening, and keep helpers only for semantic validation such as range checks, exactness checks, decimal parsing/formatting, and string-to-bool. Do not call a generic hot-path converter that redispatches on `remoteTypeId`, `localTypeId`, field descriptors, or field names. - Compatible scalar conversion is immediate-field-only. Recursive schema comparison for collection elements, array elements, map keys, and map values must reject scalar mismatches instead of applying the top-level scalar conversion matrix. diff --git a/.agents/languages/python.md b/.agents/languages/python.md index 0468193825..0964eb4459 100644 --- a/.agents/languages/python.md +++ b/.agents/languages/python.md @@ -13,15 +13,16 @@ Load this file when changing `python/`, Cython serialization, or Python xlang be - Cython mode owns the hot runtime path. Do not duplicate core runtime types between Python and Cython, tunnel Python facade methods into hidden Cython internals, or keep dead shims unless the user explicitly needs a compatibility module path. - Use explicit Cython fields and methods for fixed hot-path shapes. Avoid `__getattr__`, generic `object` fields, public bridge internals, or `Fory` backreferences where ownership can stay explicit. - Keep Python and Cython context/ref-tracking branch conditions and stack mutations semantically aligned unless a documented intentional difference exists. -- Root deserialization container memory budgets are owned by pure-Python and Cython `ReadContext`. - Keep `max_container_memory_bytes` public on `pyfory.Fory`/`Config`; `-1` uses known-length +- Root deserialization graph memory budgets are owned by pure-Python and Cython `ReadContext`. + Keep `max_graph_memory_bytes` public on `pyfory.Fory`/`Config`; `-1` uses known-length `inputBytes * 8 + 64 KiB` or fixed `128 MiB` for stream roots. `ReadContext` may expose only raw - byte reservation and generic counted-byte arithmetic; collection and dict formulas belong in the - pure-Python or Cython serializer owner. Lists, tuples, sets, and - object-dtype ndarray item storage charge `count * PyObject*`; dicts charge - `entryCount * 2 * PyObject*`. Fixed/header cost defaults to zero unless a path documents an - independent lower-bound owner. Keep string, bytes, `array.array`, primitive dense array, and - primitive ndarray owners skipped, and preserve byte-availability checks after budget reservation. + byte reservation and generic counted-byte arithmetic; collection, dict, array, struct, and object + formulas belong in the pure-Python or Cython serializer owner. Lists, tuples, sets, and + object-dtype ndarray item storage reserve nonzero owner self cost plus `count * PyObject*`; dicts + reserve nonzero owner self cost plus `entryCount * 2 * PyObject*`. Python object owners reserve a + nonzero shallow self cost plus shallow field/reference storage. Keep string, bytes, primitive + scalar, `array.array`, primitive dense array, and primitive ndarray owners skipped, and preserve + byte-availability checks after budget reservation. - Public value constructors should accept normal Python values. Raw-bit, raw-buffer, and memoryview entry points should be explicit low-level APIs, and packed carriers should expose the buffer protocol from the actual storage owner when appropriate. - When debugging runtime or benchmark behavior, install the local package into the exact interpreter under test instead of relying on mixed `PYTHONPATH` state. - For wheel or extension pipeline changes, derive extension-module paths from current build targets, packaging config, or wheel payload discovery rather than historical module names. diff --git a/.agents/languages/rust.md b/.agents/languages/rust.md index 1f19e27d4f..92db836eda 100644 --- a/.agents/languages/rust.md +++ b/.agents/languages/rust.md @@ -18,19 +18,22 @@ Load this file when changing `rust/` or Rust xlang behavior. - If breakage is explicitly acceptable during a Rust module refactor, rewire macros, tests, and sibling crates directly to the new boundaries instead of adding compatibility re-exports. - For panic-safety in hot paths, preserve TLS context reuse. Add scoped guards or owned fallbacks rather than per-call context allocation, and reset reused contexts at entry and successful exit. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Keep recursive matched-field shape classification owned by `fory-core/src/meta/type_meta.rs`; collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. -- Root deserialization container memory budget state belongs to `ReadContext` and is initialized by - the root `Fory` read methods before the header is consumed. Rust roots are byte-slice/`Reader` - backed, so auto budget uses `inputBytes * 8 + 64 KiB`; do not add dynamic bytes-read accounting. +- Root deserialization graph memory budget state belongs to `ReadContext` and is initialized by the + root `Fory` read methods before the header is consumed. Rust roots are byte-slice/`Reader` backed, + so auto budget uses `inputBytes * 8 + 64 KiB`; do not add dynamic bytes-read accounting. `ReadContext` may expose only raw byte reservation and generic counted-byte arithmetic; `Vec`, - collection, map, and derive codec formulas belong in their serializer owners. -- Rust `Vec` stores inline element storage, so general LIST paths charge - `len * size_of::()`, including `Vec` and `Vec`. Maps charge - `len * (size_of::() + size_of::())`. Dedicated primitive dense ARRAY `Vec` readers, - strings, binary, and primitive fixed-array owners stay skipped and keep their byte checks. + collection, map, array, struct, object, and derive codec formulas belong in their serializer + owners. +- Rust `Vec` stores inline element storage, so general LIST paths reserve + `len * size_of::()`, including `Vec` and `Vec`. Maps reserve + `len * (size_of::() + size_of::())`. Root/product/box owners reserve shallow value storage + exactly once; nested inline value serializers do not charge their own self storage again. + Dedicated primitive dense ARRAY `Vec` readers, strings, binary, primitive scalars, and + primitive fixed-array owners stay skipped and keep their byte checks. - Direct `Serializer` collection/map paths and derive `Codec` collection/map paths are separate allocation owners. Keep reservations in both before `Vec::with_capacity`, - `HashMap::with_capacity`, or collection materialization. Empty containers with no dynamic backing - normally charge zero. + `HashMap::with_capacity`, or collection materialization. Empty non-leaf owners that allocate an + independent owner object or storage reserve nonzero shallow self cost. ## Key Paths diff --git a/.agents/languages/swift.md b/.agents/languages/swift.md index 440f89384d..30182cd4ba 100644 --- a/.agents/languages/swift.md +++ b/.agents/languages/swift.md @@ -15,13 +15,20 @@ Load this file when changing `swift/` or Swift xlang behavior. - Preserve distinct temporal semantics. Timestamp values and day-only local dates should have protocol-accurate helper names and no stale aliases after a refactor. - When temporal or public-type refactors touch generated Swift code, sweep message fields, union payloads, macros, xlang harnesses, and integration fixtures together. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. -- Root deserialization container memory budget state belongs to `ReadContext`. Swift public roots are `Data` and `ByteBuffer`, so auto uses known root bytes; do not add stream bytes-read accounting or serializer-local budget state. `ReadContext` may expose only raw byte reservation and generic counted-byte arithmetic; array/set/map formulas belong in serializer and field-codec owners. -- For Swift container budget formulas, distinguish inline/value storage from reference storage: use +- Root deserialization graph memory budget state belongs to `ReadContext`. Swift public roots are + `Data` and `ByteBuffer`, so auto uses known root bytes; do not add stream bytes-read accounting or + serializer-local budget state. `ReadContext` may expose only raw byte reservation and generic + counted-byte arithmetic; array, set, map, struct, and object formulas belong in serializer and + field-codec owners. +- For Swift graph budget formulas, distinguish inline/value storage from reference storage: use `MemoryLayout.stride` for value arrays/lists/sets/maps and the 4-byte reference fallback for - `Serializer.isRefType` / `FieldCodec.isRefType` paths. Maps charge key plus value storage, and - empty containers with no backing storage normally charge zero. Dedicated `String`, `Data`/binary, - and primitive packed-array owners stay skipped, except compatible packed-array-to-list reads must - charge the target list materialization before allocation. + `Serializer.isRefType` / `FieldCodec.isRefType` paths. Class/reference paths reserve their own + shallow self cost plus field storage when materialized; value serializers do not unconditionally + charge self storage because root, field, array, set, map, box, or generated owners reserve inline + storage exactly once. Independently materialized collection/map/array owners reserve nonzero + shallow self cost plus backing/reference/inline storage. Dedicated `String`, `Data`/binary, + primitive scalar, and primitive packed-array owners stay skipped, except compatible + packed-array-to-list reads must reserve the target list materialization before allocation. ## Commands diff --git a/AGENTS.md b/AGENTS.md index 51663f9f77..81ffbd39ae 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -32,8 +32,8 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Respect ownership. Keep logic, state, and helpers in their natural owner, and do not move serializer-local, context-local, runtime-type-local, or protocol-local problems into global utilities. - Check the spec before implementation. For wire behavior and xlang mapping, use the specs as the source of truth and never copy one runtime's bug into another runtime just to make tests pass. - Do not make assumptions about runtime behavior, ownership, registration, metadata construction, protocol semantics, or test coverage. Read the current code, owning docs/specs, and relevant tests before making a design judgment or implementation decision. If the evidence is incomplete, inspect more or state the uncertainty explicitly instead of filling gaps from memory or analogy with another runtime. -- For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve backing/output capacity from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. Container memory-budget reservation is accounting only and may happen before that byte check, but it must not replace the byte check. -- Root deserialization container memory budgets estimate lower-bound container-owned storage, not exact heap accounting and not raw slots. Positive `maxContainerMemoryBytes` wins; `-1` auto uses `inputBytes * 8 + 64 KiB` for known-length roots and fixed `128 MiB` for true stream or unknown-length roots. Read context/read state owns only raw byte accounting plus generic counted-byte arithmetic such as `reserveContainerMemory(bytes)` or `reserveCountedContainerMemory(count, elementBytes)`; it must not expose collection/map/array semantic reservation APIs. Concrete serializers and generated serializers own the formulas: reference-backed containers/object arrays charge reference slots, inline/value containers charge element storage, reference-backed maps charge two references per entry, and inline/value maps charge key plus value storage. Fixed/header cost defaults to zero and is charged only for documented independent lower-bound storage not already covered by parent inline/value storage; empty containers without dynamic backing normally charge zero. Skip only dedicated string, binary, primitive array, and primitive dense-array owners. Do not guess table, bucket, node, entry, object-header, array-header, allocator, or debug-layout overhead. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. +- For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve backing/output capacity from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. Root graph memory reservation is accounting only and may happen before that byte check, but it must not replace the byte check. +- Root deserialization graph memory budgets estimate lower-bound shallow memory for materialized graph owners, not exact heap accounting and not raw element counts. Positive `maxGraphMemoryBytes` wins; `-1` auto uses `inputBytes * 8 + 64 KiB` for known-length roots and fixed `128 MiB` for true stream or unknown-length roots. Read context/read state owns only raw byte accounting plus generic counted-byte arithmetic, such as reserving `bytes` or `count * elementBytes` with overflow checks; it must not expose collection, map, array, struct, or object semantic reservation APIs. Concrete serializers and generated serializers own formulas for their allocation path. Reserve self storage exactly once at the owner that stores or allocates the value: reference/object runtimes reserve parent owner self cost plus reference storage and every referenced heap owner reserves its own shallow self cost when materialized; inline/value runtimes reserve value storage in the root, struct/product, collection, map, set, array, smart-pointer, or box owner that actually owns that storage, and nested value serializers do not unconditionally reserve their own `sizeof(T)` again. Collection, map, set, and reference-array owners reserve nonzero shallow self cost when independently materialized, plus backing/reference/inline storage. Struct, record, POJO, tuple/product, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Reference fields use 4 bytes when reference size is not cheap or reliable to query; primitive/value fields use their storage width. Parents do not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive array, and primitive dense-array leaf owners. Do not guess allocator, bucket-table, node, debug, or map-entry overhead unless it is a documented lower-bound owner allocation. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. - For remote TypeDef/TypeMeta reads, the checked metadata cache is the only owner of remote "already validated" state. Cache hit means the header was previously parsed, body/hash-validated, policy-checked, and published by that cache, so the hot path must skip the body and use cached metadata without extra validation, hashing, limit checks, exact-local checks, allocation, or policy work. A known expected local TypeDef/TypeMeta header/hash match is a local-schema hit, not a remote cache miss: it may skip the body and use the local TypeInfo/TypeMeta without schema-version counting or cache publish. Cache miss is the only path that parses and validates non-local metadata, enforces limits, performs exact-local byte comparison when needed, and publishes remote metadata to the cache. Do not add nullable accepted-header fields, sentinel headers, per-TypeInfo markers, pending metadata state, parallel header-low/header-high slots, or parallel acceptance state for this decision. If a runtime needs a metadata hit hint, cache the concrete checked metadata owner object, such as the TypeInfo, TypeDef, or TypeMeta used by that runtime, and compare its validated header identity directly. - When a user corrects a non-obvious invariant, encode it in the nearest source comment before continuing, and also update `AGENTS.md`, `.agents/**`, docs, or specs when the rule is reusable beyond one file. Do not rely only on chat history, task notes, commit messages, or benchmark logs for corrections that protect security, protocol behavior, ownership, naming, or hot-path performance. - Reject semantic hacks. Do not bypass broken semantics by deleting cases, simplifying callers, adding coercion hooks, or using workaround fallbacks; fix the underlying bug and prove it with focused tests. diff --git a/cpp/fory/serialization/BUILD b/cpp/fory/serialization/BUILD index 1102e53f1c..67df6188ae 100644 --- a/cpp/fory/serialization/BUILD +++ b/cpp/fory/serialization/BUILD @@ -110,8 +110,8 @@ cc_test( ) cc_test( - name = "container_memory_budget_test", - srcs = ["container_memory_budget_test.cc"], + name = "graph_memory_budget_test", + srcs = ["graph_memory_budget_test.cc"], deps = [ ":fory_serialization", "@googletest//:gtest", diff --git a/cpp/fory/serialization/CMakeLists.txt b/cpp/fory/serialization/CMakeLists.txt index 7ff18d0320..a3e5cd58bc 100644 --- a/cpp/fory/serialization/CMakeLists.txt +++ b/cpp/fory/serialization/CMakeLists.txt @@ -102,10 +102,10 @@ if(FORY_BUILD_TESTS) target_link_libraries(fory_serialization_map_test fory_serialization GTest::gtest GTest::gtest_main) gtest_discover_tests(fory_serialization_map_test) - add_executable(fory_serialization_container_memory_budget_test container_memory_budget_test.cc) - fory_configure_target(fory_serialization_container_memory_budget_test) - target_link_libraries(fory_serialization_container_memory_budget_test fory_serialization GTest::gtest GTest::gtest_main) - gtest_discover_tests(fory_serialization_container_memory_budget_test) + add_executable(fory_serialization_graph_memory_budget_test graph_memory_budget_test.cc) + fory_configure_target(fory_serialization_graph_memory_budget_test) + target_link_libraries(fory_serialization_graph_memory_budget_test fory_serialization GTest::gtest GTest::gtest_main) + gtest_discover_tests(fory_serialization_graph_memory_budget_test) add_executable(fory_serialization_variant_test variant_serializer_test.cc) fory_configure_target(fory_serialization_variant_test) diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index 3449e1d6fe..9eee2711cd 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -400,8 +400,7 @@ inline bool reserve_collection(Container &result, ReadContext &ctx, } constexpr size_t elem_bytes = collection_element_memory_bytes(); if (FORY_PREDICT_FALSE( - (!ctx.template reserve_counted_container_memory( - length)))) { + (!ctx.template reserve_counted_graph_memory(length)))) { return false; } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { @@ -419,9 +418,8 @@ inline bool reserve_collection(std::vector &result, if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } - const size_t packed_bytes = - (static_cast(length) + CHAR_BIT - 1) / CHAR_BIT; - if (FORY_PREDICT_FALSE(!ctx.reserve_container_memory(packed_bytes))) { + const size_t packed_bytes = (static_cast(length) + 7) / 8; + if (FORY_PREDICT_FALSE(!ctx.reserve_graph_memory(packed_bytes))) { return false; } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { @@ -436,7 +434,7 @@ inline bool reserve_empty_collection(ReadContext &ctx) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } - return ctx.reserve_container_memory(0); + return ctx.reserve_graph_memory(0); } // Helper to insert element into container (vector or set) diff --git a/cpp/fory/serialization/config.h b/cpp/fory/serialization/config.h index a59b575f71..fc46221506 100644 --- a/cpp/fory/serialization/config.h +++ b/cpp/fory/serialization/config.h @@ -52,9 +52,9 @@ struct Config { /// When enabled, avoids duplicating shared objects and handles cycles. bool track_ref = true; - /// Maximum estimated container-owned memory accepted during one root + /// Maximum estimated graph memory accepted during one root /// deserialization. `-1` selects the automatic input-shaped limit. - int64_t max_container_memory_bytes = -1; + int64_t max_graph_memory_bytes = -1; /// Maximum accepted field count in one received struct TypeMeta. uint32_t max_type_fields = 512; diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index f88c19f9e1..0af96a03c1 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -739,59 +739,58 @@ const TypeInfo *ReadContext::read_any_type_info(Error &error) { return result.value(); } -bool ReadContext::reserve_counted_container_checked(uint32_t length, - size_t elem_bytes) { +bool ReadContext::reserve_counted_graph_checked(uint32_t length, + size_t elem_bytes) { if (FORY_PREDICT_FALSE(elem_bytes != 0 && static_cast(length) > std::numeric_limits::max() / elem_bytes)) { - return set_container_memory_overflow(length, elem_bytes); + return set_graph_memory_overflow(length, elem_bytes); } - return reserve_container_memory(static_cast(length) * elem_bytes); + return reserve_graph_memory(static_cast(length) * elem_bytes); } -bool ReadContext::init_explicit_container_budget(int64_t configured) { +bool ReadContext::init_explicit_graph_budget(int64_t configured) { const uint64_t limit = static_cast(configured); if constexpr (sizeof(size_t) < sizeof(uint64_t)) { if (FORY_PREDICT_FALSE(limit > static_cast( std::numeric_limits::max()))) { - return set_container_memory_error( - "max_container_memory_bytes does not fit size_t"); + return set_graph_memory_error( + "max_graph_memory_bytes does not fit size_t"); } } - remaining_container_memory_bytes_ = static_cast(limit); - container_budget_state_ = kContainerBudgetReady; + remaining_graph_memory_bytes_ = static_cast(limit); + graph_budget_state_ = kGraphBudgetReady; return true; } -bool ReadContext::materialize_container_budget() { - switch (container_budget_state_) { - case kContainerBudgetPendingKnown: - return init_container_budget_known(pending_container_root_bytes_); - case kContainerBudgetPendingUnknown: - return init_container_budget_unknown(); +bool ReadContext::materialize_graph_budget() { + switch (graph_budget_state_) { + case kGraphBudgetPendingKnown: + return init_graph_budget_known(pending_graph_root_bytes_); + case kGraphBudgetPendingUnknown: + return init_graph_budget_unknown(); default: return true; } } -bool ReadContext::set_container_memory_error(const std::string &message) { +bool ReadContext::set_graph_memory_error(const std::string &message) { set_error(Error::invalid_data(message)); return false; } -bool ReadContext::set_container_memory_overflow(uint32_t length, - size_t elem_bytes) { +bool ReadContext::set_graph_memory_overflow(uint32_t length, + size_t elem_bytes) { set_error(Error::invalid_data( - "container memory estimate overflows: length=" + std::to_string(length) + + "graph memory estimate overflows: length=" + std::to_string(length) + " elementBytes=" + std::to_string(elem_bytes))); return false; } -bool ReadContext::set_container_memory_exceeded(size_t bytes, - size_t remaining) { +bool ReadContext::set_graph_memory_exceeded(size_t bytes, size_t remaining) { set_error(Error::invalid_data( - "estimated container memory request " + std::to_string(bytes) + - " bytes exceeds max_container_memory_bytes remaining budget " + + "estimated graph memory request " + std::to_string(bytes) + + " bytes exceeds max_graph_memory_bytes remaining budget " + std::to_string(remaining) + " bytes")); return false; } @@ -804,7 +803,7 @@ void ReadContext::reset() { } reading_type_infos_.clear(); current_dyn_depth_ = 0; - // Root deserialization initializes the container budget before reading the + // Root deserialization initializes the graph budget before reading the // header; direct ReadContext users start with the unlimited sentinel fields. // Leave those fields untouched here so root guard cleanup stays store-light. if (meta_string_table_active_) { diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index cb8a6a378e..12fc33ccd8 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -505,69 +505,68 @@ class ReadContext { } } - FORY_ALWAYS_INLINE bool init_container_budget_known(size_t root_bytes) { - const int64_t configured = config_->max_container_memory_bytes; + FORY_ALWAYS_INLINE bool init_graph_budget_known(size_t root_bytes) { + const int64_t configured = config_->max_graph_memory_bytes; if (FORY_PREDICT_FALSE(configured > 0)) { - return init_explicit_container_budget(configured); + return init_explicit_graph_budget(configured); } if constexpr (sizeof(size_t) <= sizeof(uint32_t)) { - constexpr size_t max_root_bytes = (std::numeric_limits::max() - - kKnownContainerBudgetSlackBytes) / - kKnownContainerBudgetMultiplier; + constexpr size_t max_root_bytes = + (std::numeric_limits::max() - kKnownGraphBudgetSlackBytes) / + kKnownGraphBudgetMultiplier; if (FORY_PREDICT_FALSE(root_bytes > max_root_bytes)) { - return set_container_memory_error( - "root input size overflows automatic container memory budget"); + return set_graph_memory_error( + "root input size overflows automatic graph memory budget"); } } - remaining_container_memory_bytes_ = - root_bytes * kKnownContainerBudgetMultiplier + - kKnownContainerBudgetSlackBytes; - container_budget_state_ = kContainerBudgetReady; + remaining_graph_memory_bytes_ = + root_bytes * kKnownGraphBudgetMultiplier + kKnownGraphBudgetSlackBytes; + graph_budget_state_ = kGraphBudgetReady; return true; } - FORY_ALWAYS_INLINE bool init_container_budget_unknown() { - const int64_t configured = config_->max_container_memory_bytes; + FORY_ALWAYS_INLINE bool init_graph_budget_unknown() { + const int64_t configured = config_->max_graph_memory_bytes; if (FORY_PREDICT_FALSE(configured > 0)) { - return init_explicit_container_budget(configured); + return init_explicit_graph_budget(configured); } - remaining_container_memory_bytes_ = kUnknownContainerBudgetBytes; - container_budget_state_ = kContainerBudgetReady; + remaining_graph_memory_bytes_ = kUnknownGraphBudgetBytes; + graph_budget_state_ = kGraphBudgetReady; return true; } - FORY_ALWAYS_INLINE void defer_container_budget_known(size_t root_bytes) { - pending_container_root_bytes_ = root_bytes; - container_budget_state_ = kContainerBudgetPendingKnown; + FORY_ALWAYS_INLINE void defer_graph_budget_known(size_t root_bytes) { + pending_graph_root_bytes_ = root_bytes; + graph_budget_state_ = kGraphBudgetPendingKnown; } - FORY_ALWAYS_INLINE void defer_container_budget_unknown() { - container_budget_state_ = kContainerBudgetPendingUnknown; + FORY_ALWAYS_INLINE void defer_graph_budget_unknown() { + graph_budget_state_ = kGraphBudgetPendingUnknown; } - FORY_ALWAYS_INLINE bool reserve_container_memory(size_t bytes) { - if (FORY_PREDICT_FALSE(container_budget_state_ != kContainerBudgetReady)) { - if (FORY_PREDICT_FALSE(!materialize_container_budget())) { + FORY_ALWAYS_INLINE bool reserve_graph_memory(size_t bytes) { + if (FORY_PREDICT_FALSE(graph_budget_state_ != kGraphBudgetReady)) { + if (FORY_PREDICT_FALSE(!materialize_graph_budget())) { return false; } } - const size_t remaining = remaining_container_memory_bytes_; + const size_t remaining = remaining_graph_memory_bytes_; if (FORY_PREDICT_FALSE(bytes > remaining)) { - return set_container_memory_exceeded(bytes, remaining); + return set_graph_memory_exceeded(bytes, remaining); } - remaining_container_memory_bytes_ = remaining - bytes; + remaining_graph_memory_bytes_ = remaining - bytes; return true; } template - FORY_ALWAYS_INLINE bool reserve_counted_container_memory(uint32_t length) { + FORY_ALWAYS_INLINE bool reserve_counted_graph_memory(uint32_t length) { constexpr size_t kMaxLength = static_cast(std::numeric_limits::max()); if constexpr (elem_bytes <= std::numeric_limits::max() / kMaxLength) { - return reserve_container_memory(static_cast(length) * elem_bytes); + return reserve_graph_memory(static_cast(length) * elem_bytes); } else { - return reserve_counted_container_checked(length, elem_bytes); + return reserve_counted_graph_checked(length, elem_bytes); } } @@ -726,26 +725,24 @@ class ReadContext { inline const Config &config() const { return *config_; } private: - static constexpr size_t kKnownContainerBudgetMultiplier = 8; - static constexpr size_t kKnownContainerBudgetSlackBytes = 64 * 1024; - static constexpr size_t kUnknownContainerBudgetBytes = - 128ULL * 1024ULL * 1024ULL; - static constexpr uint8_t kContainerBudgetReady = 0; - static constexpr uint8_t kContainerBudgetPendingKnown = 1; - static constexpr uint8_t kContainerBudgetPendingUnknown = 2; + static constexpr size_t kKnownGraphBudgetMultiplier = 8; + static constexpr size_t kKnownGraphBudgetSlackBytes = 64 * 1024; + static constexpr size_t kUnknownGraphBudgetBytes = 128ULL * 1024ULL * 1024ULL; + static constexpr uint8_t kGraphBudgetReady = 0; + static constexpr uint8_t kGraphBudgetPendingKnown = 1; + static constexpr uint8_t kGraphBudgetPendingUnknown = 2; FORY_NOINLINE Result check_remote_type_meta_limit(const TypeMeta &type_meta); void record_remote_type_meta(const std::string &type_key); - FORY_NOINLINE bool reserve_counted_container_checked(uint32_t length, - size_t elem_bytes); - FORY_NOINLINE bool init_explicit_container_budget(int64_t configured); - FORY_NOINLINE bool materialize_container_budget(); - FORY_NOINLINE bool set_container_memory_error(const std::string &message); - FORY_NOINLINE bool set_container_memory_overflow(uint32_t length, + FORY_NOINLINE bool reserve_counted_graph_checked(uint32_t length, size_t elem_bytes); - FORY_NOINLINE bool set_container_memory_exceeded(size_t bytes, - size_t remaining); + FORY_NOINLINE bool init_explicit_graph_budget(int64_t configured); + FORY_NOINLINE bool materialize_graph_budget(); + FORY_NOINLINE bool set_graph_memory_error(const std::string &message); + FORY_NOINLINE bool set_graph_memory_overflow(uint32_t length, + size_t elem_bytes); + FORY_NOINLINE bool set_graph_memory_exceeded(size_t bytes, size_t remaining); // Error state - accumulated during deserialization, checked at the end Error error_; @@ -755,9 +752,9 @@ class ReadContext { std::unique_ptr type_resolver_; RefReader ref_reader_; uint32_t current_dyn_depth_; - uint8_t container_budget_state_ = kContainerBudgetReady; - size_t pending_container_root_bytes_ = 0; - size_t remaining_container_memory_bytes_ = std::numeric_limits::max(); + uint8_t graph_budget_state_ = kGraphBudgetReady; + size_t pending_graph_root_bytes_ = 0; + size_t remaining_graph_memory_bytes_ = std::numeric_limits::max(); // Meta sharing state (for compatible mode) // Persistent cache storage for TypeInfo objects keyed by meta header. diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index 1ef660cc90..bdd86d25b4 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -109,13 +109,13 @@ class ForyBuilder { return *this; } - /// Set maximum estimated container-owned memory for one root deserialization. + /// Set maximum estimated graph memory for one root deserialization. /// /// Use `-1` for automatic limits. Positive values are explicit byte limits. - ForyBuilder &max_container_memory_bytes(int64_t max_bytes) { + ForyBuilder &max_graph_memory_bytes(int64_t max_bytes) { FORY_CHECK(max_bytes == -1 || max_bytes > 0) - << "max_container_memory_bytes must be positive or -1 for auto"; - config_.max_container_memory_bytes = max_bytes; + << "max_graph_memory_bytes must be positive or -1 for auto"; + config_.max_graph_memory_bytes = max_bytes; return *this; } @@ -889,11 +889,18 @@ class Fory : public BaseFory { } read_ctx_->attach(buffer); - if constexpr (needs_container_budget_v) { + if constexpr (needs_graph_budget_v) { if constexpr (unknown_root) { - read_ctx_->defer_container_budget_unknown(); + read_ctx_->defer_graph_budget_unknown(); } else { - read_ctx_->defer_container_budget_known(root_bytes); + read_ctx_->defer_graph_budget_known(root_bytes); + } + constexpr size_t root_owner_bytes = graph_value_owner_self_bytes(); + if constexpr (root_owner_bytes != 0) { + if (FORY_PREDICT_FALSE( + !read_ctx_->reserve_graph_memory(root_owner_bytes))) { + return Unexpected(read_ctx_->take_error()); + } } } ReadContextGuard guard(*read_ctx_); diff --git a/cpp/fory/serialization/container_memory_budget_test.cc b/cpp/fory/serialization/graph_memory_budget_test.cc similarity index 56% rename from cpp/fory/serialization/container_memory_budget_test.cc rename to cpp/fory/serialization/graph_memory_budget_test.cc index e255a98a26..e212d9ee56 100644 --- a/cpp/fory/serialization/container_memory_budget_test.cc +++ b/cpp/fory/serialization/graph_memory_budget_test.cc @@ -20,7 +20,6 @@ #include "fory/serialization/fory.h" #include "gtest/gtest.h" #include -#include #include #include #include @@ -52,6 +51,12 @@ struct BudgetItem { FORY_STRUCT(BudgetItem, id, name); }; +struct BudgetEmpty { + bool operator==(const BudgetEmpty &) const { return true; } + + FORY_STRUCT(BudgetEmpty); +}; + struct BudgetSiblings { std::vector left; std::vector right; @@ -74,17 +79,17 @@ struct BudgetFixedArrayOwner { FORY_STRUCT(BudgetFixedArrayOwner, prefix, items); }; -template -auto with_fory(int64_t max_container_memory_bytes, Fn &&fn) { +template auto with_fory(int64_t max_graph_memory_bytes, Fn &&fn) { auto fory = Fory::builder() .xlang(true) .compatible(false) .track_ref(false) - .max_container_memory_bytes(max_container_memory_bytes) + .max_graph_memory_bytes(max_graph_memory_bytes) .build(); fory.register_struct(1); fory.register_struct(2); fory.register_struct(3); + fory.register_struct(4); return std::forward(fn)(fory); } @@ -95,8 +100,9 @@ template std::vector serialize_value(const T &value) { } size_t nested_empty_budget(size_t count) { + using Outer = std::vector>; using Inner = std::vector; - return count * sizeof(Inner); + return sizeof(Outer) + count * sizeof(Inner); } template @@ -104,11 +110,13 @@ void expect_budget_boundary(const T &value, size_t required) { ASSERT_GT(required, 0u); auto bytes = serialize_value(value); - auto small_result = - with_fory(static_cast(required - 1), - [&](Fory &fory) { return fory.deserialize(bytes); }); - ASSERT_FALSE(small_result.ok()); - EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + if (required > 1) { + auto small_result = + with_fory(static_cast(required - 1), + [&](Fory &fory) { return fory.deserialize(bytes); }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + } auto exact_result = with_fory(static_cast(required), @@ -117,20 +125,20 @@ void expect_budget_boundary(const T &value, size_t required) { EXPECT_EQ(exact_result.value(), value); } -TEST(ContainerMemoryBudgetTest, KnownLengthAutoBudget) { +TEST(GraphMemoryBudgetTest, KnownLengthAutoBudget) { Config config; - config.max_container_memory_bytes = -1; + config.max_graph_memory_bytes = -1; ReadContext context(config, std::make_unique()); constexpr size_t root_bytes = 17; const size_t expected = root_bytes * 8 + kKnownBudgetSlack; - ASSERT_TRUE(context.init_container_budget_known(root_bytes)); - ASSERT_TRUE(context.reserve_container_memory(expected)); - ASSERT_FALSE(context.reserve_container_memory(1)); + ASSERT_TRUE(context.init_graph_budget_known(root_bytes)); + ASSERT_TRUE(context.reserve_graph_memory(expected)); + ASSERT_FALSE(context.reserve_graph_memory(1)); EXPECT_EQ(context.take_error().code(), ErrorCode::InvalidData); } -TEST(ContainerMemoryBudgetTest, StreamAutoBudget) { +TEST(GraphMemoryBudgetTest, StreamAutoBudget) { constexpr size_t count = 10000; std::vector> value(count); auto bytes = serialize_value(value); @@ -153,10 +161,11 @@ TEST(ContainerMemoryBudgetTest, StreamAutoBudget) { EXPECT_EQ(stream_result.value(), value); } -TEST(ContainerMemoryBudgetTest, ExplicitOverride) { +TEST(GraphMemoryBudgetTest, ExplicitOverride) { std::vector value(8); auto bytes = serialize_value(value); - const size_t required = value.size() * sizeof(BudgetItem); + const size_t required = + sizeof(std::vector) + value.size() * sizeof(BudgetItem); auto small_result = with_fory(static_cast(required - 1), [&](Fory &fory) { @@ -173,10 +182,87 @@ TEST(ContainerMemoryBudgetTest, ExplicitOverride) { EXPECT_EQ(exact_result.value(), value); } -TEST(ContainerMemoryBudgetTest, NestedEmptyContainersUseParentStorage) { +TEST(GraphMemoryBudgetTest, SmartPointerStructOwners) { + auto shared_value = std::make_shared(); + shared_value->id = 7; + shared_value->name = "shared"; + auto shared_bytes = serialize_value(shared_value); + constexpr size_t shared_required = + sizeof(std::shared_ptr) + sizeof(BudgetItem); + + auto shared_small = + with_fory(static_cast(shared_required - 1), [&](Fory &fory) { + return fory.deserialize>(shared_bytes); + }); + ASSERT_FALSE(shared_small.ok()); + EXPECT_EQ(shared_small.error().code(), ErrorCode::InvalidData); + + auto shared_exact = + with_fory(static_cast(shared_required), [&](Fory &fory) { + return fory.deserialize>(shared_bytes); + }); + ASSERT_TRUE(shared_exact.ok()) << shared_exact.error().to_string(); + ASSERT_NE(shared_exact.value(), nullptr); + EXPECT_EQ(*shared_exact.value(), *shared_value); + + auto unique_value = std::make_unique(); + unique_value->id = 9; + unique_value->name = "unique"; + auto unique_bytes = serialize_value(unique_value); + + constexpr size_t unique_required = + sizeof(std::unique_ptr) + sizeof(BudgetItem); + auto unique_small = + with_fory(static_cast(unique_required - 1), [&](Fory &fory) { + return fory.deserialize>(unique_bytes); + }); + ASSERT_FALSE(unique_small.ok()); + EXPECT_EQ(unique_small.error().code(), ErrorCode::InvalidData); + + auto unique_exact = + with_fory(static_cast(unique_required), [&](Fory &fory) { + return fory.deserialize>(unique_bytes); + }); + ASSERT_TRUE(unique_exact.ok()) << unique_exact.error().to_string(); + ASSERT_NE(unique_exact.value(), nullptr); + EXPECT_EQ(*unique_exact.value(), *unique_value); +} + +TEST(GraphMemoryBudgetTest, SmartPointerVectorOwner) { + auto value = std::make_shared>(3); + auto bytes = serialize_value(value); + const size_t required = sizeof(std::shared_ptr>) + + sizeof(std::vector) + + value->size() * sizeof(BudgetItem); + + auto small_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>>( + bytes); + }); + ASSERT_FALSE(small_result.ok()); + EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); + + auto exact_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>>( + bytes); + }); + ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + ASSERT_NE(exact_result.value(), nullptr); + EXPECT_EQ(*exact_result.value(), *value); +} + +TEST(GraphMemoryBudgetTest, EmptyStructRootChargesOwner) { + BudgetEmpty value; + expect_budget_boundary(value, sizeof(BudgetEmpty)); +} + +TEST(GraphMemoryBudgetTest, NestedEmptyContainersUseParentStorage) { std::vector> value(1); auto bytes = serialize_value(value); - const size_t required = sizeof(std::vector); + const size_t required = sizeof(std::vector>) + + sizeof(std::vector); auto small_result = with_fory(static_cast(required - 1), [&](Fory &fory) { @@ -193,77 +279,85 @@ TEST(ContainerMemoryBudgetTest, NestedEmptyContainersUseParentStorage) { EXPECT_EQ(exact_result.value(), value); } -TEST(ContainerMemoryBudgetTest, SiblingCumulativeBudget) { +TEST(GraphMemoryBudgetTest, SiblingCumulativeBudget) { BudgetSiblings value; value.left.resize(16); value.right.resize(16); auto bytes = serialize_value(value); + const size_t root_owner = sizeof(BudgetSiblings); const size_t one_vector = value.left.size() * sizeof(BudgetItem); auto small_result = - with_fory(static_cast(one_vector), [&](Fory &fory) { + with_fory(static_cast(root_owner + one_vector), [&](Fory &fory) { return fory.deserialize(bytes); }); ASSERT_FALSE(small_result.ok()); EXPECT_EQ(small_result.error().code(), ErrorCode::InvalidData); - auto enough_result = - with_fory(static_cast(one_vector * 2), [&](Fory &fory) { - return fory.deserialize(bytes); - }); + auto enough_result = with_fory( + static_cast(root_owner + one_vector * 2), + [&](Fory &fory) { return fory.deserialize(bytes); }); ASSERT_TRUE(enough_result.ok()) << enough_result.error().to_string(); EXPECT_EQ(enough_result.value(), value); } -TEST(ContainerMemoryBudgetTest, MapBudget) { +TEST(GraphMemoryBudgetTest, MapBudget) { std::map value{{"a", 1}, {"b", 2}, {"c", 3}}; const size_t entry_bytes = sizeof(std::string) + sizeof(int32_t); - const size_t required = value.size() * entry_bytes; + const size_t required = + sizeof(std::map) + value.size() * entry_bytes; expect_budget_boundary(value, required); } -TEST(ContainerMemoryBudgetTest, CollectionLowerBounds) { +TEST(GraphMemoryBudgetTest, CollectionLowerBounds) { std::deque deque_value(4); - expect_budget_boundary(deque_value, deque_value.size() * sizeof(BudgetItem)); + expect_budget_boundary(deque_value, + sizeof(std::deque) + + deque_value.size() * sizeof(BudgetItem)); std::list list_value(4); - expect_budget_boundary(list_value, list_value.size() * sizeof(BudgetItem)); + expect_budget_boundary(list_value, + sizeof(std::list) + + list_value.size() * sizeof(BudgetItem)); std::forward_list forward_value(4); - expect_budget_boundary(forward_value, size_t{4} * sizeof(BudgetItem)); + expect_budget_boundary(forward_value, sizeof(std::forward_list) + + size_t{4} * sizeof(BudgetItem)); } -TEST(ContainerMemoryBudgetTest, VectorBoolUsesPackedStorage) { +TEST(GraphMemoryBudgetTest, VectorBoolChargesPackedStorage) { std::vector value(33); value[0] = true; value[32] = true; - const size_t packed_bytes = (value.size() + CHAR_BIT - 1) / CHAR_BIT; - const size_t required = packed_bytes; - ASSERT_LT(required, value.size()); - - expect_budget_boundary(value, required); + expect_budget_boundary(value, size_t{5}); } -TEST(ContainerMemoryBudgetTest, OrderedSetAndMapLowerBounds) { +TEST(GraphMemoryBudgetTest, OrderedSetAndMapLowerBounds) { std::set set_value{1, 2, 3, 4}; - expect_budget_boundary(set_value, set_value.size() * sizeof(int32_t)); + expect_budget_boundary(set_value, sizeof(std::set) + + set_value.size() * sizeof(int32_t)); std::map map_value{{"a", 1}, {"b", 2}}; - expect_budget_boundary( - map_value, map_value.size() * (sizeof(std::string) + sizeof(int32_t))); + expect_budget_boundary(map_value, + sizeof(std::map) + + map_value.size() * + (sizeof(std::string) + sizeof(int32_t))); } -TEST(ContainerMemoryBudgetTest, UnorderedContainersLowerBounds) { +TEST(GraphMemoryBudgetTest, UnorderedContainersLowerBounds) { std::unordered_set set_value{1, 2, 3, 4}; - expect_budget_boundary(set_value, set_value.size() * sizeof(int32_t)); + expect_budget_boundary(set_value, sizeof(std::unordered_set) + + set_value.size() * sizeof(int32_t)); std::unordered_map map_value{{"a", 1}, {"b", 2}}; - expect_budget_boundary( - map_value, map_value.size() * (sizeof(std::string) + sizeof(int32_t))); + expect_budget_boundary(map_value, + sizeof(std::unordered_map) + + map_value.size() * + (sizeof(std::string) + sizeof(int32_t))); } -TEST(ContainerMemoryBudgetTest, ArrayHasNoStandaloneReservation) { +TEST(GraphMemoryBudgetTest, ArrayHasNoStandaloneReservation) { std::array value{{1, 2, 3, 4}}; auto bytes = serialize_value(value); auto result = with_fory(1, [&](Fory &fory) { @@ -273,18 +367,19 @@ TEST(ContainerMemoryBudgetTest, ArrayHasNoStandaloneReservation) { EXPECT_EQ(result.value(), value); } -TEST(ContainerMemoryBudgetTest, FixedInlineOwnerChargesNestedVector) { +TEST(GraphMemoryBudgetTest, FixedInlineOwnerChargesNestedVector) { BudgetFixedArrayOwner value; value.prefix = {{1, 2, 3, 4}}; value.items.resize(3); - const size_t required = value.items.size() * sizeof(BudgetItem); + const size_t required = + sizeof(BudgetFixedArrayOwner) + value.items.size() * sizeof(BudgetItem); expect_budget_boundary(value, required); } -TEST(ContainerMemoryBudgetTest, DensePathsSkipped) { +TEST(GraphMemoryBudgetTest, DensePathsSkipped) { { - std::string value = "container-budget-string"; + std::string value = "graph-budget-string"; auto bytes = serialize_value(value); auto result = with_fory( 1, [&](Fory &fory) { return fory.deserialize(bytes); }); @@ -311,7 +406,7 @@ TEST(ContainerMemoryBudgetTest, DensePathsSkipped) { } } -TEST(ContainerMemoryBudgetTest, ByteCheckStillRejectsLargeLength) { +TEST(GraphMemoryBudgetTest, ByteCheckStillRejectsLargeLength) { Config config; auto resolver = std::make_unique(); ReadContext ctx(config, std::move(resolver)); diff --git a/cpp/fory/serialization/map_serializer.h b/cpp/fory/serialization/map_serializer.h index c07a5f0a44..82b41c322e 100644 --- a/cpp/fory/serialization/map_serializer.h +++ b/cpp/fory/serialization/map_serializer.h @@ -98,8 +98,7 @@ inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { "map entry memory estimate overflows"); constexpr size_t elem_bytes = sizeof(Key) + sizeof(Value); if (FORY_PREDICT_FALSE( - (!ctx.template reserve_counted_container_memory( - length)))) { + (!ctx.template reserve_counted_graph_memory(length)))) { return false; } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { @@ -113,7 +112,7 @@ template inline bool reserve_empty_map(ReadContext &ctx) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } - return ctx.reserve_container_memory(0); + return ctx.reserve_graph_memory(0); } /// write chunk size at header offset diff --git a/cpp/fory/serialization/serializer_traits.h b/cpp/fory/serialization/serializer_traits.h index e26ba5c899..59aa361d25 100644 --- a/cpp/fory/serialization/serializer_traits.h +++ b/cpp/fory/serialization/serializer_traits.h @@ -24,7 +24,9 @@ #include "fory/meta/type_index.h" #include "fory/meta/type_traits.h" #include "fory/type/type.h" +#include "fory/util/macros.h" #include +#include #include #include #include @@ -246,84 +248,104 @@ template inline constexpr bool is_fory_serializable_v = is_fory_serializable::value; // ============================================================================ -// Container budget reachability +// Graph budget reachability // ============================================================================ template -struct needs_container_budget : std::false_type {}; - -template <> struct needs_container_budget : std::false_type {}; -template <> struct needs_container_budget : std::false_type {}; -template <> struct needs_container_budget : std::false_type {}; -template <> struct needs_container_budget : std::false_type {}; -template <> struct needs_container_budget : std::false_type {}; -template <> struct needs_container_budget : std::false_type {}; -template <> struct needs_container_budget : std::false_type {}; -template <> struct needs_container_budget : std::false_type {}; -template <> struct needs_container_budget : std::false_type {}; -template <> struct needs_container_budget : std::false_type {}; -template <> struct needs_container_budget : std::false_type {}; -template <> -struct needs_container_budget : std::false_type {}; - -template -struct needs_container_budget< +struct needs_graph_budget : std::false_type {}; + +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; +template <> struct needs_graph_budget : std::false_type {}; + +template +struct needs_graph_budget< T, std::enable_if_t || is_list_v || is_deque_v || is_forward_list_v || is_set_like_v || is_map_like_v>> : std::true_type {}; template -struct needs_container_budget, void> - : std::bool_constant, void> + : std::bool_constant>>::value> {}; template -struct needs_container_budget, void> - : std::bool_constant, void> + : std::bool_constant>>::value> {}; template -struct needs_container_budget, void> : std::true_type {}; +struct needs_graph_budget, void> : std::true_type {}; template -struct needs_container_budget, void> : std::true_type {}; +struct needs_graph_budget, void> : std::true_type {}; template -struct needs_container_budget, void> - : std::bool_constant<(needs_container_budget>>::value || - ...)> {}; +struct needs_graph_budget, void> : std::true_type {}; template -struct needs_container_budget, void> - : std::bool_constant<(needs_container_budget, void> + : std::bool_constant<(needs_graph_budget>>::value || ...)> {}; template -constexpr bool struct_needs_container_budget_impl(std::index_sequence) { +constexpr bool struct_needs_graph_budget_impl(std::index_sequence) { return ( - needs_container_budget< + needs_graph_budget< std::remove_cv_t>>>>::value || ...); } template -struct needs_container_budget>> { -private: - using FieldDescriptor = - decltype(::fory::meta::fory_field_info(std::declval())); - using Ptrs = typename FieldDescriptor::PtrsType; +struct needs_graph_budget>> + : std::true_type {}; -public: - static constexpr bool value = struct_needs_container_budget_impl( - std::make_index_sequence{}); -}; +template +inline constexpr bool needs_graph_budget_v = + needs_graph_budget>>::value; + +template struct is_dense_primitive_vector : std::false_type {}; + +template +struct is_dense_primitive_vector> + : std::bool_constant> {}; template -inline constexpr bool needs_container_budget_v = - needs_container_budget>>::value; +inline constexpr bool is_dense_primitive_vector_v = is_dense_primitive_vector< + std::remove_cv_t>>::value; + +template constexpr size_t graph_value_owner_self_bytes() { + using Value = std::remove_cv_t>; + if constexpr (!needs_graph_budget_v || + is_dense_primitive_vector_v) { + return size_t{0}; + } else if constexpr (std::is_empty_v) { + return size_t{1}; + } else { + return sizeof(Value); + } +} + +template +FORY_ALWAYS_INLINE bool reserve_allocated_value_owner(Context &ctx) { + constexpr size_t bytes = graph_value_owner_self_bytes(); + if constexpr (bytes == 0) { + return true; + } else { + return ctx.reserve_graph_memory(bytes); + } +} // ============================================================================ // Generic Type Detection diff --git a/cpp/fory/serialization/smart_ptr_serializers.h b/cpp/fory/serialization/smart_ptr_serializers.h index 2ffe0ca98d..6796fb2e4a 100644 --- a/cpp/fory/serialization/smart_ptr_serializers.h +++ b/cpp/fory/serialization/smart_ptr_serializers.h @@ -444,6 +444,9 @@ template struct Serializer> { "Cannot use monomorphic deserialization for abstract type")); return nullptr; } else { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, false); if (ctx.has_error()) { return nullptr; @@ -454,6 +457,9 @@ template struct Serializer> { } else { // T is guaranteed to be a value type (not pointer or nullable wrapper) // by static_assert, so no inner ref metadata needed. + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, read_type); if (ctx.has_error()) { return nullptr; @@ -546,6 +552,9 @@ template struct Serializer> { } else { // For circular references: pre-allocate and store BEFORE reading if (is_first_occurrence) { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } auto result = std::make_shared(); ctx.ref_reader().store_shared_ref_at(reserved_ref_id, result); T value = Serializer::read(ctx, RefMode::None, false); @@ -555,6 +564,9 @@ template struct Serializer> { *result = std::move(value); return result; } else { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, false); if (ctx.has_error()) { return nullptr; @@ -572,6 +584,9 @@ template struct Serializer> { // references (like self_ref pointing back to the parent) to resolve. if (is_first_occurrence) { // Pre-allocate with default construction and store immediately + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } auto result = std::make_shared(); ctx.ref_reader().store_shared_ref_at(reserved_ref_id, result); // Read struct data - forward refs can now find this object @@ -584,6 +599,9 @@ template struct Serializer> { return result; } else { // Not first occurrence, just read and wrap + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, read_type); if (ctx.has_error()) { return nullptr; @@ -610,6 +628,9 @@ template struct Serializer> { return std::shared_ptr(obj_ptr); } else { // T is guaranteed to be a value type by static_assert. + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read_with_type_info(ctx, RefMode::None, type_info); if (ctx.has_error()) { @@ -689,6 +710,9 @@ template struct Serializer> { // For circular references: pre-allocate and store BEFORE reading const bool is_first_occurrence = flag == REF_VALUE_FLAG; if (is_first_occurrence) { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } auto result = std::make_shared(); ctx.ref_reader().store_shared_ref_at(reserved_ref_id, result); T value = @@ -699,6 +723,9 @@ template struct Serializer> { *result = std::move(value); return result; } else { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read_with_type_info(ctx, RefMode::None, type_info); if (ctx.has_error()) { @@ -710,6 +737,9 @@ template struct Serializer> { } static inline std::shared_ptr read_data(ReadContext &ctx) { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read_data(ctx); if (ctx.has_error()) { return nullptr; @@ -894,6 +924,9 @@ template struct Serializer> { "Cannot use monomorphic deserialization for abstract type")); return nullptr; } else { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, false); if (ctx.has_error()) { return nullptr; @@ -903,6 +936,9 @@ template struct Serializer> { } } else { // T is guaranteed to be a value type by static_assert. + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, read_type); if (ctx.has_error()) { return nullptr; @@ -960,6 +996,9 @@ template struct Serializer> { "Cannot use monomorphic deserialization for abstract type")); return nullptr; } else { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, false); if (ctx.has_error()) { return nullptr; @@ -969,6 +1008,9 @@ template struct Serializer> { } } else { // T is guaranteed to be a value type by static_assert. + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read(ctx, RefMode::None, read_type); if (ctx.has_error()) { return nullptr; @@ -994,6 +1036,9 @@ template struct Serializer> { return std::unique_ptr(obj_ptr); } else { // T is guaranteed to be a value type by static_assert. + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read_with_type_info(ctx, RefMode::None, type_info); if (ctx.has_error()) { @@ -1037,6 +1082,9 @@ template struct Serializer> { return std::unique_ptr(obj_ptr); } else { // T is guaranteed to be a value type by static_assert. + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read_with_type_info(ctx, RefMode::None, type_info); if (ctx.has_error()) { @@ -1047,6 +1095,9 @@ template struct Serializer> { } static inline std::unique_ptr read_data(ReadContext &ctx) { + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } T value = Serializer::read_data(ctx); if (ctx.has_error()) { return nullptr; diff --git a/cpp/fory/serialization/type_resolver.h b/cpp/fory/serialization/type_resolver.h index 1a14a8570f..fa3b61f5af 100644 --- a/cpp/fory/serialization/type_resolver.h +++ b/cpp/fory/serialization/type_resolver.h @@ -1573,6 +1573,9 @@ template inline std::any any_read_adapter(ReadContext &ctx) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return std::any(); } + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return std::any(); + } if constexpr (std::is_copy_constructible::value) { return std::any(std::move(value)); } @@ -2127,6 +2130,9 @@ void *TypeResolver::harness_read_adapter(ReadContext &ctx, RefMode ref_mode, if (FORY_PREDICT_FALSE(ctx.has_error())) { return nullptr; } + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } return new T(std::move(value)); } @@ -2151,6 +2157,9 @@ void *TypeResolver::harness_read_data_adapter(ReadContext &ctx) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return nullptr; } + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } return new T(std::move(value)); } @@ -2173,6 +2182,9 @@ void *TypeResolver::harness_read_compatible_adapter(ReadContext &ctx, if (FORY_PREDICT_FALSE(ctx.has_error())) { return nullptr; } + if (FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return nullptr; + } return new T(std::move(value)); } diff --git a/csharp/src/Fory.Generator/ForyModelGenerator.cs b/csharp/src/Fory.Generator/ForyModelGenerator.cs index d819197470..68ba979c72 100644 --- a/csharp/src/Fory.Generator/ForyModelGenerator.cs +++ b/csharp/src/Fory.Generator/ForyModelGenerator.cs @@ -212,6 +212,11 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" private static bool __ForyRefMetaMatches;"); sb.AppendLine( $" private const bool __ForyAllFieldsBuiltIn = {BoolLiteral(model.SortedMembers.All(m => m.DynamicAnyKind == DynamicAnyKind.None && m.Classification.IsBuiltIn))};"); + if (model.Kind == DeclKind.Class) + { + sb.AppendLine($" private static readonly long __ForyGraphMemoryBytes = {ModelGraphMemoryExpr(model)};"); + } + sb.AppendLine( " private static global::System.Collections.Generic.IReadOnlyList? __ForyNoRefTypeMetaFields;"); sb.AppendLine( @@ -448,6 +453,11 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(); sb.AppendLine($" private {model.TypeName} ReadDataWithoutTypeMeta(global::Apache.Fory.ReadContext context)"); sb.AppendLine(" {"); + if (model.Kind == DeclKind.Class) + { + sb.AppendLine(" context.ReserveGraphMemory(__ForyGraphMemoryBytes);"); + } + sb.AppendLine($" {model.TypeName} valueNoTypeMeta = new {model.TypeName}();"); if (model.Kind == DeclKind.Class) { @@ -482,6 +492,11 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" }"); sb.AppendLine(); sb.AppendLine(" global::Apache.Fory.TypeMeta typeMeta = maybeTypeMeta;"); + if (model.Kind == DeclKind.Class) + { + sb.AppendLine(" context.ReserveGraphMemory(__ForyGraphMemoryBytes);"); + } + sb.AppendLine($" {model.TypeName} value = new {model.TypeName}();"); if (model.Kind == DeclKind.Class) { @@ -592,6 +607,11 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" }"); sb.AppendLine(" }"); sb.AppendLine(); + if (model.Kind == DeclKind.Class) + { + sb.AppendLine(" context.ReserveGraphMemory(__ForyGraphMemoryBytes);"); + } + sb.AppendLine($" {model.TypeName} valueSchema = new {model.TypeName}();"); if (model.Kind == DeclKind.Class) { @@ -1165,12 +1185,12 @@ private static void EmitReadCompatibleListArrayPayload( $"(typeof({elementTypeName}).IsValueType ? global::System.Runtime.CompilerServices.Unsafe.SizeOf<{elementTypeName}>() : 4)"; if (codec.CarrierKind == CarrierKind.Array) { - sb.AppendLine($"{indent}context.ReserveCountedContainerMemory({lengthVar}, {elementBytesExpr});"); + sb.AppendLine($"{indent}context.ReserveGraphMemory(1L + (long){lengthVar} * {elementBytesExpr});"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new {ElementTypeName(codec.TypeName)}[{lengthVar}];"); } else { - sb.AppendLine($"{indent}context.ReserveCountedContainerMemory({lengthVar}, {elementBytesExpr});"); + sb.AppendLine($"{indent}context.ReserveGraphMemory(1L + (long){lengthVar} * {elementBytesExpr});"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new({lengthVar});"); } @@ -1517,8 +1537,8 @@ private static void EmitReadPackedArrayPayload( } else { - string elementBytesExpr = ContainerElementBytesExpr(PackedArrayElementTypeName(codec.TypeId)); - sb.AppendLine($"{indent}context.ReserveCountedContainerMemory({countVar}, {elementBytesExpr});"); + string elementBytesExpr = GraphElementBytesExpr(PackedArrayElementTypeName(codec.TypeId)); + sb.AppendLine($"{indent}context.ReserveGraphMemory(1L + (long){countVar} * {elementBytesExpr});"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new({countVar});"); } @@ -1558,7 +1578,7 @@ private static void EmitReadCollectionPayload( string sameTypeVar = $"__forySameType{id++}"; string declaredVar = $"__foryDeclared{id++}"; sb.AppendLine($"{indent}int {lengthVar} = checked((int)context.Reader.ReadVarUInt32());"); - sb.AppendLine($"{indent}context.ReserveCountedContainerMemory({lengthVar}, {ContainerElementBytesExpr(element)});"); + sb.AppendLine($"{indent}context.ReserveGraphMemory(1L + (long){lengthVar} * {GraphElementBytesExpr(element)});"); sb.AppendLine($"{indent}if ({lengthVar} != 0)"); sb.AppendLine($"{indent}{{"); sb.AppendLine($"{indent} context.Reader.CheckBound({lengthVar});"); @@ -1664,7 +1684,7 @@ private static void EmitReadMapPayload( FieldCodecModel value = codec.Generics[1]; string totalVar = $"__foryTotal{id++}"; sb.AppendLine($"{indent}int {totalVar} = checked((int)context.Reader.ReadVarUInt32());"); - sb.AppendLine($"{indent}context.ReserveCountedContainerMemory({totalVar}, {ContainerMapElementBytesExpr(key, value)});"); + sb.AppendLine($"{indent}context.ReserveGraphMemory(1L + (long){totalVar} * {GraphMapElementBytesExpr(key, value)});"); sb.AppendLine($"{indent}if ({totalVar} != 0)"); sb.AppendLine($"{indent}{{"); sb.AppendLine($"{indent} context.Reader.CheckBound({totalVar});"); @@ -1823,22 +1843,44 @@ private static string ElementTypeName(string arrayTypeName) : "object"; } - private static string ContainerElementBytesExpr(FieldCodecModel codec) + private static string GraphElementBytesExpr(FieldCodecModel codec) { - return ContainerElementBytesExpr( + return GraphElementBytesExpr( codec.Nullable && !codec.NullableValueType ? StripNullableForTypeOf(codec.TypeName) : codec.TypeName); } - private static string ContainerElementBytesExpr(string typeName) + private static string GraphElementBytesExpr(string typeName) { return $"(typeof({typeName}).IsValueType ? global::System.Runtime.CompilerServices.Unsafe.SizeOf<{typeName}>() : 4)"; } - private static string ContainerMapElementBytesExpr(FieldCodecModel key, FieldCodecModel value) + private static string GraphMapElementBytesExpr(FieldCodecModel key, FieldCodecModel value) + { + return $"((long){GraphElementBytesExpr(key)} + {GraphElementBytesExpr(value)})"; + } + + private static string ModelGraphMemoryExpr(TypeModel model) + { + System.Collections.Generic.List parts = new() { "1L" }; + foreach (MemberModel member in model.SortedMembers) + { + parts.Add(FieldGraphMemoryExpr(member)); + } + + return string.Join(" + ", parts); + } + + private static string FieldGraphMemoryExpr(MemberModel member) { - return $"((long){ContainerElementBytesExpr(key)} + {ContainerElementBytesExpr(value)})"; + if (member.Classification.IsPrimitive && member.Classification.PrimitiveSize > 0) + { + return $"{member.Classification.PrimitiveSize}L"; + } + + string typeName = StripNullableForTypeOf(member.TypeName); + return $"(typeof({typeName}).IsValueType ? global::System.Runtime.CompilerServices.Unsafe.SizeOf<{typeName}>() : 4L)"; } private static string PackedArrayElementTypeName(uint typeId) diff --git a/csharp/src/Fory/CollectionSerializers.cs b/csharp/src/Fory/CollectionSerializers.cs index d9baf94203..72e6169c57 100644 --- a/csharp/src/Fory/CollectionSerializers.cs +++ b/csharp/src/Fory/CollectionSerializers.cs @@ -32,6 +32,7 @@ internal static class CollectionBits internal static class CollectionCodec { + private const int CollectionBytes = 1; private const int ReferenceBytes = 4; [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -40,7 +41,7 @@ internal static class CollectionCodec [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static void ReserveElementStorage(ReadContext context, int count) { - context.ReserveCountedContainerMemory(count, ElementBytes()); + context.ReserveGraphMemory(CollectionBytes + (long)count * ElementBytes()); } private static bool NeedsCompatibleElementTypeMeta(TypeInfo typeInfo, WriteContext context) @@ -536,7 +537,6 @@ public override void WriteData(WriteContext context, in T[] value, bool hasGener public override T[] ReadData(ReadContext context) { List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); - CollectionCodec.ReserveElementStorage(context, values.Count); return values.ToArray(); } } diff --git a/csharp/src/Fory/Config.cs b/csharp/src/Fory/Config.cs index 1947bac29c..207ca40b77 100644 --- a/csharp/src/Fory/Config.cs +++ b/csharp/src/Fory/Config.cs @@ -28,7 +28,7 @@ internal Config( bool compatible, bool checkStructVersion, int maxDepth, - long maxContainerMemoryBytes, + long maxGraphMemoryBytes, int maxTypeFields, int maxTypeMetaBytes, int maxSchemaVersionsPerType, @@ -38,11 +38,11 @@ internal Config( { throw new ArgumentOutOfRangeException(nameof(maxDepth), "MaxDepth must be greater than 0."); } - if (maxContainerMemoryBytes != -1 && maxContainerMemoryBytes <= 0) + if (maxGraphMemoryBytes != -1 && maxGraphMemoryBytes <= 0) { throw new ArgumentOutOfRangeException( - nameof(maxContainerMemoryBytes), - "MaxContainerMemoryBytes must be positive or -1 for auto."); + nameof(maxGraphMemoryBytes), + "MaxGraphMemoryBytes must be positive or -1 for auto."); } if (maxTypeFields <= 0) { @@ -65,7 +65,7 @@ internal Config( Compatible = compatible; CheckStructVersion = checkStructVersion; MaxDepth = maxDepth; - MaxContainerMemoryBytes = maxContainerMemoryBytes; + MaxGraphMemoryBytes = maxGraphMemoryBytes; MaxTypeFields = maxTypeFields; MaxTypeMetaBytes = maxTypeMetaBytes; MaxSchemaVersionsPerType = maxSchemaVersionsPerType; @@ -93,9 +93,9 @@ internal Config( public int MaxDepth { get; } ///

- /// Gets the maximum estimated container-owned memory accepted during one root deserialization. + /// Gets the maximum estimated graph memory accepted during one root deserialization. /// - public long MaxContainerMemoryBytes { get; } + public long MaxGraphMemoryBytes { get; } /// /// Gets the maximum accepted field count in one received struct TypeMeta. @@ -127,7 +127,7 @@ public sealed class ForyBuilder private bool? _compatible; private bool _checkStructVersion; private int _maxDepth = 20; - private long _maxContainerMemoryBytes = -1; + private long _maxGraphMemoryBytes = -1; private int _maxTypeFields = 512; private int _maxTypeMetaBytes = 4096; private int _maxSchemaVersionsPerType = 10; @@ -184,19 +184,19 @@ public ForyBuilder MaxDepth(int value) } /// - /// Sets the maximum estimated container-owned memory accepted during one root deserialization. + /// Sets the maximum estimated graph memory accepted during one root deserialization. /// Use -1 for the automatic root-size-based limit, or a positive byte limit. /// - public ForyBuilder MaxContainerMemoryBytes(long value) + public ForyBuilder MaxGraphMemoryBytes(long value) { if (value != -1 && value <= 0) { throw new ArgumentOutOfRangeException( nameof(value), - "MaxContainerMemoryBytes must be positive or -1 for auto."); + "MaxGraphMemoryBytes must be positive or -1 for auto."); } - _maxContainerMemoryBytes = value; + _maxGraphMemoryBytes = value; return this; } @@ -266,7 +266,7 @@ private Config BuildConfig() compatible: compatible, checkStructVersion: compatible ? false : _checkStructVersion, maxDepth: _maxDepth, - maxContainerMemoryBytes: _maxContainerMemoryBytes, + maxGraphMemoryBytes: _maxGraphMemoryBytes, maxTypeFields: _maxTypeFields, maxTypeMetaBytes: _maxTypeMetaBytes, maxSchemaVersionsPerType: _maxSchemaVersionsPerType, diff --git a/csharp/src/Fory/DictionarySerializers.cs b/csharp/src/Fory/DictionarySerializers.cs index 444c2c6fbd..3be9465a49 100644 --- a/csharp/src/Fory/DictionarySerializers.cs +++ b/csharp/src/Fory/DictionarySerializers.cs @@ -34,6 +34,7 @@ public abstract class DictionaryLikeSerializer : Seri where TDictionary : class, IDictionary where TKey : notnull { + private const int MapBytes = 1; private const int ReferenceBytes = 4; private static readonly long MapElementBytes = (long)ElementBytes() + ElementBytes(); @@ -43,7 +44,7 @@ public abstract class DictionaryLikeSerializer : Seri [MethodImpl(MethodImplOptions.AggressiveInlining)] private static void ReserveMapStorage(ReadContext context, int count) { - context.ReserveCountedContainerMemory(count, MapElementBytes); + context.ReserveGraphMemory(MapBytes + count * MapElementBytes); } public override TDictionary DefaultValue => null!; diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs index edcfdb13b2..6591682b7b 100644 --- a/csharp/src/Fory/Fory.cs +++ b/csharp/src/Fory/Fory.cs @@ -190,7 +190,7 @@ public T Deserialize(ReadOnlySpan payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); - _readContext.InitContainerBudgetKnown(payload.Length); + _readContext.InitGraphBudgetKnown(payload.Length); T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -211,7 +211,7 @@ public T Deserialize(byte[] payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); - _readContext.InitContainerBudgetKnown(payload.Length); + _readContext.InitGraphBudgetKnown(payload.Length); T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -232,7 +232,7 @@ public T Deserialize(ref ReadOnlySequence payload) byte[] bytes = payload.ToArray(); ByteReader reader = _readContext.Reader; reader.Reset(bytes); - _readContext.InitContainerBudgetKnown(bytes.Length); + _readContext.InitGraphBudgetKnown(bytes.Length); T value = DeserializeFromReader(reader); payload = payload.Slice(reader.Cursor); return value; @@ -282,6 +282,7 @@ private T DeserializeFromReader(ByteReader reader) Serializer serializer = _typeResolver.GetSerializer(); ReadContext readContext = _readContext; readContext.ResetFor(reader); + GraphMemory.ReserveRootValue(readContext); RefMode refMode = Config.TrackRef ? RefMode.Tracking : RefMode.NullOnly; T value = serializer.Read(readContext, refMode, true); readContext.RefReader.Reset(); diff --git a/csharp/src/Fory/GraphMemory.cs b/csharp/src/Fory/GraphMemory.cs new file mode 100644 index 0000000000..3997388798 --- /dev/null +++ b/csharp/src/Fory/GraphMemory.cs @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +using System.Runtime.CompilerServices; + +namespace Apache.Fory; + +internal static class GraphMemory +{ + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static long ValueOwnerBytes() => ValueOwner.Bytes; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static void ReserveRootValue(ReadContext context) + { + long bytes = ValueOwner.Bytes; + if (bytes != 0) + { + context.ReserveGraphMemory(bytes); + } + } + + private static class ValueOwner + { + internal static readonly long Bytes = Compute(); + + private static long Compute() + { + Type type = typeof(T); + if (!ShouldReserve(type)) + { + return 0; + } + + int bytes = Unsafe.SizeOf(); + return bytes == 0 ? 1 : bytes; + } + + private static bool ShouldReserve(Type type) + { + if (!type.IsValueType || + Nullable.GetUnderlyingType(type) is not null || + type.IsEnum || + type.IsPrimitive) + { + return false; + } + + return type != typeof(decimal) && + type != typeof(Half) && + type != typeof(BFloat16) && + type != typeof(DateOnly) && + type != typeof(DateTime) && + type != typeof(DateTimeOffset) && + type != typeof(TimeSpan); + } + } +} diff --git a/csharp/src/Fory/NullableKeyDictionary.cs b/csharp/src/Fory/NullableKeyDictionary.cs index ae08f775a0..cf479e6c81 100644 --- a/csharp/src/Fory/NullableKeyDictionary.cs +++ b/csharp/src/Fory/NullableKeyDictionary.cs @@ -391,6 +391,7 @@ IEnumerator IEnumerable.GetEnumerator() public sealed class NullableKeyDictionarySerializer : Serializer> { + private const int MapBytes = 1; private const int ReferenceBytes = 4; private static readonly long MapElementBytes = (long)ElementBytes() + ElementBytes(); @@ -400,7 +401,7 @@ public sealed class NullableKeyDictionarySerializer : Serializer DefaultValue => null!; diff --git a/csharp/src/Fory/PrimitiveDictionarySerializers.cs b/csharp/src/Fory/PrimitiveDictionarySerializers.cs index 594abd6833..4d56772192 100644 --- a/csharp/src/Fory/PrimitiveDictionarySerializers.cs +++ b/csharp/src/Fory/PrimitiveDictionarySerializers.cs @@ -664,6 +664,7 @@ public static void WriteMap(ReadContext context, int count) { - context.ReserveCountedContainerMemory( - count, - (long)ElementBytes() + ElementBytes()); + context.ReserveGraphMemory(MapBytes + count * ((long)ElementBytes() + ElementBytes())); } public static TMap ReadMap(ReadContext context) diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index b27af464bd..5098fdd532 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -22,7 +22,7 @@ namespace Apache.Fory; public sealed class ReadContext { private const int MinRemoteTypeMetaLimit = 8192; - internal const long KnownContainerBudgetSlackBytes = 64 * 1024; + internal const long KnownGraphBudgetSlackBytes = 64 * 1024; private readonly ReusableArray _typeMetaRefs = new(); private readonly UInt64Map _typeMetasByHeader = new(); @@ -43,8 +43,8 @@ public sealed class ReadContext private readonly Dictionary _remoteSchemaVersionsByType = []; private readonly Config _config; private int _totalAcceptedSchemaVersions; - private long _containerMemoryLimitBytes = long.MaxValue; - private long _remainingContainerMemoryBytes = long.MaxValue; + private long _graphMemoryLimitBytes = long.MaxValue; + private long _remainingGraphMemoryBytes = long.MaxValue; public ReadContext( ByteReader reader, @@ -76,73 +76,73 @@ public ReadContext( internal RefReader RefReader { get; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void InitContainerBudgetKnown(int rootBytes) + internal void InitGraphBudgetKnown(int rootBytes) { - long limit = _config.MaxContainerMemoryBytes; + long limit = _config.MaxGraphMemoryBytes; if (limit < 0) { - limit = (long)rootBytes * 8 + KnownContainerBudgetSlackBytes; + limit = (long)rootBytes * 8 + KnownGraphBudgetSlackBytes; } - _containerMemoryLimitBytes = limit; - _remainingContainerMemoryBytes = limit; + _graphMemoryLimitBytes = limit; + _remainingGraphMemoryBytes = limit; } [MethodImpl(MethodImplOptions.AggressiveInlining)] /// - /// Reserves estimated container-owned memory for the current root deserialization. + /// Reserves estimated graph memory for the current root deserialization. /// /// - /// Serializer owners compute container-specific formulas and pass raw bytes here. This + /// Serializer owners compute owner-specific formulas and pass raw bytes here. This /// accounting does not replace byte-availability checks before backing allocation. /// - public void ReserveContainerMemory(long bytes) + public void ReserveGraphMemory(long bytes) { - long remaining = _remainingContainerMemoryBytes; + long remaining = _remainingGraphMemoryBytes; if ((ulong)bytes > (ulong)remaining) { - ThrowContainerBudgetExceeded(bytes, remaining, _containerMemoryLimitBytes); + ThrowGraphBudgetExceeded(bytes, remaining, _graphMemoryLimitBytes); } - _remainingContainerMemoryBytes = remaining - bytes; + _remainingGraphMemoryBytes = remaining - bytes; } [MethodImpl(MethodImplOptions.AggressiveInlining)] /// /// Reserves multiplied by estimated - /// container-owned bytes for the current root deserialization. + /// graph-owner bytes for the current root deserialization. /// /// /// This helper owns only overflow-safe arithmetic; concrete serializers and generated /// serializers still own the collection, array, and map storage formulas. /// - public void ReserveCountedContainerMemory(int count, long elementBytes) + public void ReserveCountedGraphMemory(int count, long elementBytes) { if (count < 0 || elementBytes < 0) { - ThrowContainerBudgetOverflow(); + ThrowGraphBudgetOverflow(); } uint length = (uint)count; if (elementBytes != 0 && length > long.MaxValue / elementBytes) { - ThrowContainerBudgetOverflow(); + ThrowGraphBudgetOverflow(); } - ReserveContainerMemory((long)length * elementBytes); + ReserveGraphMemory((long)length * elementBytes); } [MethodImpl(MethodImplOptions.NoInlining)] - private static void ThrowContainerBudgetOverflow() + private static void ThrowGraphBudgetOverflow() { - throw new InvalidDataException("container memory estimate overflows"); + throw new InvalidDataException("graph memory estimate overflows"); } [MethodImpl(MethodImplOptions.NoInlining)] - private static void ThrowContainerBudgetExceeded(long bytes, long remaining, long limit) + private static void ThrowGraphBudgetExceeded(long bytes, long remaining, long limit) { throw new InvalidDataException( - $"estimated container memory request {bytes} bytes exceeds MaxContainerMemoryBytes remaining budget {remaining} bytes out of effective limit {limit} bytes"); + $"estimated graph memory request {bytes} bytes exceeds MaxGraphMemoryBytes remaining budget {remaining} bytes out of effective limit {limit} bytes"); } internal void ResetFor(ByteReader reader) diff --git a/csharp/src/Fory/TypeInfo.cs b/csharp/src/Fory/TypeInfo.cs index 8341d418c1..0b9a5f8630 100644 --- a/csharp/src/Fory/TypeInfo.cs +++ b/csharp/src/Fory/TypeInfo.cs @@ -102,6 +102,7 @@ internal static TypeInfo Create(Type type, Serializer serializer) bool evolving = ResolveStructEvolving(type, userTypeKind); bool isNullableType = !type.IsValueType || Nullable.GetUnderlyingType(type) is not null; bool isRefType = type != typeof(string) && !type.IsValueType; + long boxedValueBytes = GraphMemory.ValueOwnerBytes(); return new TypeInfo( type, serializer, @@ -118,10 +119,10 @@ internal static TypeInfo Create(Type type, Serializer serializer) namespaceName: null, typeName: null, (context, value, hasGenerics) => WriteDataObject(serializer, context, value, hasGenerics), - context => serializer.ReadData(context), + context => ReadDataObject(serializer, context, boxedValueBytes), (context, value, refMode, writeTypeInfo, hasGenerics) => WriteObject(serializer, context, value, refMode, writeTypeInfo, hasGenerics), - (context, refMode, readTypeInfo) => serializer.Read(context, refMode, readTypeInfo), + (context, refMode, readTypeInfo) => ReadObject(serializer, context, refMode, readTypeInfo, boxedValueBytes), typeMetaFields, builtInTypeId, null); @@ -178,6 +179,16 @@ private static void WriteDataObject(Serializer serializer, WriteContext co serializer.WriteData(context, CoerceRuntimeValue(serializer, value), hasGenerics); } + private static object? ReadDataObject(Serializer serializer, ReadContext context, long boxedValueBytes) + { + if (boxedValueBytes != 0) + { + context.ReserveGraphMemory(boxedValueBytes); + } + + return serializer.ReadData(context); + } + private static void WriteObject( Serializer serializer, WriteContext context, @@ -189,6 +200,21 @@ private static void WriteObject( serializer.Write(context, CoerceRuntimeValue(serializer, value), refMode, writeTypeInfo, hasGenerics); } + private static object? ReadObject( + Serializer serializer, + ReadContext context, + RefMode refMode, + bool readTypeInfo, + long boxedValueBytes) + { + if (boxedValueBytes != 0) + { + context.ReserveGraphMemory(boxedValueBytes); + } + + return serializer.Read(context, refMode, readTypeInfo); + } + private static T CoerceRuntimeValue(Serializer serializer, object? value) { if (value is T typed) diff --git a/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs similarity index 66% rename from csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs rename to csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs index 6dcd9d9c96..fee7ce963c 100644 --- a/csharp/tests/Fory.Tests/ContainerMemoryBudgetTests.cs +++ b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs @@ -30,6 +30,11 @@ public sealed class BudgetItem public string Name { get; set; } = string.Empty; } +[ForyStruct] +public sealed class BudgetEmpty +{ +} + [ForyStruct] public sealed class BudgetSiblings { @@ -43,6 +48,12 @@ public sealed class BudgetArrayHolder public BudgetItem[] Values { get; set; } = []; } +[ForyStruct] +public struct BudgetValue +{ + public int Id { get; set; } +} + [ForyStruct] public sealed class GeneratedSchemaListBudget { @@ -64,25 +75,34 @@ public sealed class GeneratedSchemaMapBudget public Dictionary Values { get; set; } = []; } -public sealed class ContainerMemoryBudgetTests +public sealed class GraphMemoryBudgetTests { private const int ReferenceBytes = 4; + private const int ObjectBytes = 1; + private const long BudgetEmptyBytes = ObjectBytes; + private const long BudgetItemBytes = ObjectBytes + 4 + ReferenceBytes; + private const long BudgetSiblingsBytes = ObjectBytes + ReferenceBytes + ReferenceBytes; + private const long BudgetArrayHolderBytes = ObjectBytes + ReferenceBytes; + private const long GeneratedGraphHolderBytes = ObjectBytes + ReferenceBytes; + private const long BudgetValueBytes = 4; private static int ElementBytes() => typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; - private static ForyRuntime NewFory(long maxContainerMemoryBytes = -1) + private static ForyRuntime NewFory(long maxGraphMemoryBytes = -1) { return ForyRuntime.Builder() .Compatible(false) .TrackRef(false) - .MaxContainerMemoryBytes(maxContainerMemoryBytes) + .MaxGraphMemoryBytes(maxGraphMemoryBytes) .Build() .Register(1001) - .Register(1002) - .Register(1003) - .Register(1004) - .Register(1005) - .Register(1006); + .Register(1002) + .Register(1003) + .Register(1004) + .Register(1005) + .Register(1006) + .Register(1007) + .Register(1008); } private static byte[] Serialize(T value) @@ -92,29 +112,29 @@ private static byte[] Serialize(T value) private static long ListBudget(int count) { - return (long)count * ElementBytes(); + return ObjectBytes + (long)count * ElementBytes(); } private static long ArrayBudget(int count) { - return (long)count * ElementBytes(); + return ObjectBytes + (long)count * ElementBytes(); } private static long MapBudget(int count) { - return (long)count * (ElementBytes() + ElementBytes()); + return ObjectBytes + (long)count * (ElementBytes() + ElementBytes()); } [Fact] public void KnownLengthAutoBudgetUsesInputBytes() { const int rootBytes = 17; - long expected = rootBytes * 8 + ReadContext.KnownContainerBudgetSlackBytes; + long expected = rootBytes * 8 + ReadContext.KnownGraphBudgetSlackBytes; ReadContext context = new(new ByteReader([]), new TypeResolver(), NewFory().Config); - context.InitContainerBudgetKnown(rootBytes); - context.ReserveContainerMemory(expected); - Assert.Throws(() => context.ReserveContainerMemory(ReferenceBytes)); + context.InitGraphBudgetKnown(rootBytes); + context.ReserveGraphMemory(expected); + Assert.Throws(() => context.ReserveGraphMemory(ReferenceBytes)); } [Fact] @@ -133,13 +153,24 @@ public void ExplicitConfigOverridesAutoBudget() { List value = Enumerable.Range(0, 8).Select(i => new BudgetItem { Id = i }).ToList(); byte[] bytes = Serialize(value); - long required = ListBudget(value.Count); + long required = ListBudget(value.Count) + value.Count * BudgetItemBytes; Assert.Throws(() => NewFory(required - 1).Deserialize>(bytes)); List result = NewFory(required).Deserialize>(bytes); Assert.Equal(value.Count, result.Count); } + [Fact] + public void EmptyObjectOwnerIsCharged() + { + List value = [new BudgetEmpty()]; + byte[] bytes = Serialize(value); + long required = ListBudget(value.Count) + value.Count * BudgetEmptyBytes; + + Assert.Throws(() => NewFory(required - 1).Deserialize>(bytes)); + Assert.Single(NewFory(required).Deserialize>(bytes)); + } + [Fact] public void SiblingContainersShareOneBudget() { @@ -149,10 +180,11 @@ public void SiblingContainersShareOneBudget() Right = Enumerable.Range(0, 16).Select(i => new BudgetItem { Id = i }).ToList(), }; byte[] bytes = Serialize(value); - long oneList = ListBudget(16); + long oneList = ListBudget(16) + 16 * BudgetItemBytes; + long required = BudgetSiblingsBytes + oneList * 2; - Assert.Throws(() => NewFory(oneList).Deserialize(bytes)); - BudgetSiblings result = NewFory(oneList * 2).Deserialize(bytes); + Assert.Throws(() => NewFory(required - 1).Deserialize(bytes)); + BudgetSiblings result = NewFory(required).Deserialize(bytes); Assert.Equal(16, result.Left.Count); Assert.Equal(16, result.Right.Count); } @@ -177,7 +209,8 @@ public void ReferenceArrayAndInlineValueListAreCharged() Values = Enumerable.Range(0, 4).Select(i => new BudgetItem { Id = i }).ToArray(), }; byte[] holderBytes = Serialize(holder); - long holderRequired = ListBudget(4) + ArrayBudget(4); + long holderRequired = + BudgetArrayHolderBytes + ArrayBudget(4) + holder.Values.Length * BudgetItemBytes; Assert.Throws(() => NewFory(holderRequired - 1).Deserialize(holderBytes)); Assert.Equal(4, NewFory(holderRequired).Deserialize(holderBytes).Values.Length); @@ -188,18 +221,33 @@ public void ReferenceArrayAndInlineValueListAreCharged() Assert.Equal(ints, NewFory(listRequired).Deserialize>(intBytes)); } + [Fact] + public void ValueStructOwnerIsChargedByHolder() + { + BudgetValue value = new() { Id = 7 }; + byte[] valueBytes = Serialize(value); + Assert.Throws(() => NewFory(BudgetValueBytes - 1).Deserialize(valueBytes)); + Assert.Equal(value.Id, NewFory(BudgetValueBytes).Deserialize(valueBytes).Id); + + List values = Enumerable.Range(0, 4).Select(i => new BudgetValue { Id = i }).ToList(); + byte[] listBytes = Serialize(values); + long listRequired = ListBudget(values.Count); + Assert.Throws(() => NewFory(listRequired - 1).Deserialize>(listBytes)); + Assert.Equal(values.Select(v => v.Id), NewFory(listRequired).Deserialize>(listBytes).Select(v => v.Id)); + } + [Fact] public void GeneratedSchemaContainersAreCharged() { GeneratedSchemaListBudget list = new() { Values = [1, 2, 3, 4, 5, 6] }; byte[] listBytes = Serialize(list); - long listRequired = ListBudget(list.Values.Count); + long listRequired = GeneratedGraphHolderBytes + ListBudget(list.Values.Count); Assert.Throws(() => NewFory(listRequired - 1).Deserialize(listBytes)); Assert.Equal(list.Values, NewFory(listRequired).Deserialize(listBytes).Values); GeneratedPackedListBudget packed = new() { Values = [1, 2, 3, 4, 5, 6] }; byte[] packedBytes = Serialize(packed); - long packedRequired = ListBudget(packed.Values.Count); + long packedRequired = GeneratedGraphHolderBytes + ListBudget(packed.Values.Count); Assert.Throws(() => NewFory(packedRequired - 1).Deserialize(packedBytes)); Assert.Equal(packed.Values, NewFory(packedRequired).Deserialize(packedBytes).Values); @@ -208,7 +256,7 @@ public void GeneratedSchemaContainersAreCharged() Values = new Dictionary { [1] = 1, [2] = 2, [3] = 3 }, }; byte[] mapBytes = Serialize(map); - long mapRequired = MapBudget(map.Values.Count); + long mapRequired = GeneratedGraphHolderBytes + MapBudget(map.Values.Count); Assert.Throws(() => NewFory(mapRequired - 1).Deserialize(mapBytes)); Assert.Equal(map.Values, NewFory(mapRequired).Deserialize(mapBytes).Values); } diff --git a/dart/packages/fory/lib/src/codegen/fory_generator.dart b/dart/packages/fory/lib/src/codegen/fory_generator.dart index cae611048a..5595fac063 100644 --- a/dart/packages/fory/lib/src/codegen/fory_generator.dart +++ b/dart/packages/fory/lib/src/codegen/fory_generator.dart @@ -628,9 +628,11 @@ final class ForyGenerator extends Generator { ..writeln() ..writeln(' @override') ..writeln(' ${structSpec.name} read(ReadContext context) {'); + final graphObjectBytes = _graphObjectBytes(structSpec); switch (structSpec.constructionModel.mode) { case _ConstructorMode.mutable: + output.writeln(' context.reserveGraphMemory($graphObjectBytes);'); output.writeln(' final value = ${structSpec.name}();'); if (_structNeedsEarlyReadReference(structSpec)) { output @@ -755,6 +757,7 @@ final class ForyGenerator extends Generator { } } final constructorInvocation = _constructorInvocation(structSpec); + output.writeln(' context.reserveGraphMemory($graphObjectBytes);'); output.writeln(' final value = $constructorInvocation;'); for (final fieldName in structSpec.constructionModel.postConstructionFieldNames) { @@ -809,6 +812,9 @@ final class ForyGenerator extends Generator { ); switch (structSpec.constructionModel.mode) { case _ConstructorMode.mutable: + output.writeln( + ' context.reserveGraphMemory(${_graphObjectBytes(structSpec)});', + ); output.writeln(' final value = ${structSpec.name}();'); if (_structNeedsEarlyReadReference(structSpec)) { output @@ -999,6 +1005,9 @@ final class ForyGenerator extends Generator { ..writeln(' }'); } final constructorInvocation = _constructorInvocation(structSpec); + output.writeln( + ' context.reserveGraphMemory(${_graphObjectBytes(structSpec)});', + ); output.writeln(' final value = $constructorInvocation;'); for (final fieldName in structSpec.constructionModel.postConstructionFieldNames) { @@ -1218,6 +1227,9 @@ final class ForyGenerator extends Generator { return '${structSpec.name}($arguments)'; } + int _graphObjectBytes(_GeneratedStructSpec structSpec) => + 1 + structSpec.fields.length * 4; + bool _isSkipped(FieldElement field) { final annotation = _fieldAnnotationOf(field); if (annotation == null) { diff --git a/dart/packages/fory/lib/src/config.dart b/dart/packages/fory/lib/src/config.dart index 6f529ecc3c..3fea4e99df 100644 --- a/dart/packages/fory/lib/src/config.dart +++ b/dart/packages/fory/lib/src/config.dart @@ -28,7 +28,7 @@ final class Config { static const int defaultMaxTypeMetaBytes = 4096; static const int defaultMaxSchemaVersionsPerType = 10; static const int defaultMaxAverageSchemaVersionsPerType = 3; - static const int defaultMaxContainerMemoryBytes = -1; + static const int defaultMaxGraphMemoryBytes = -1; /// Enables compatible struct encoding and decoding. /// @@ -57,10 +57,10 @@ final class Config { /// types. final int maxAverageSchemaVersionsPerType; - /// Maximum estimated container-owned memory per root deserialization. + /// Maximum estimated graph memory per root deserialization. /// /// `-1` means auto. Positive values are explicit byte limits. - final int maxContainerMemoryBytes; + final int maxGraphMemoryBytes; /// Creates an immutable configuration object. /// @@ -75,7 +75,7 @@ final class Config { this.maxSchemaVersionsPerType = defaultMaxSchemaVersionsPerType, this.maxAverageSchemaVersionsPerType = defaultMaxAverageSchemaVersionsPerType, - this.maxContainerMemoryBytes = defaultMaxContainerMemoryBytes, + this.maxGraphMemoryBytes = defaultMaxGraphMemoryBytes, }) : checkStructVersion = compatible ? false : checkStructVersion, assert(maxDepth > 0, 'maxDepth must be positive'), assert(maxTypeFields > 0, 'maxTypeFields must be positive'), @@ -89,7 +89,7 @@ final class Config { 'maxAverageSchemaVersionsPerType must be positive', ), assert( - maxContainerMemoryBytes == -1 || maxContainerMemoryBytes > 0, - 'maxContainerMemoryBytes must be -1 or positive', + maxGraphMemoryBytes == -1 || maxGraphMemoryBytes > 0, + 'maxGraphMemoryBytes must be -1 or positive', ); } diff --git a/dart/packages/fory/lib/src/context/read_context.dart b/dart/packages/fory/lib/src/context/read_context.dart index ba186c4f80..a47b382c33 100644 --- a/dart/packages/fory/lib/src/context/read_context.dart +++ b/dart/packages/fory/lib/src/context/read_context.dart @@ -58,8 +58,8 @@ final class ReadContext { late Buffer _buffer; final List _sharedTypes = []; int _depth = 0; - int _effectiveContainerMemoryBytes = 0; - int _remainingContainerMemoryBytes = 0; + int _effectiveGraphMemoryBytes = 0; + int _remainingGraphMemoryBytes = 0; @internal ReadContext( @@ -73,17 +73,17 @@ final class ReadContext { @pragma('vm:prefer-inline') void prepare(Buffer buffer) { _buffer = buffer; - final configured = config.maxContainerMemoryBytes; + final configured = config.maxGraphMemoryBytes; final limit = configured > 0 ? configured : buffer.readableBytes * _knownRootBudgetMultiplier + _knownRootBudgetSlackBytes; if (limit > _maxSafeBudgetBytes) { - _throwContainerMemoryOverflow(limit); + _throwGraphMemoryOverflow(limit); } - _effectiveContainerMemoryBytes = limit; - _remainingContainerMemoryBytes = limit; + _effectiveGraphMemoryBytes = limit; + _remainingGraphMemoryBytes = limit; } @internal @@ -92,8 +92,8 @@ final class ReadContext { _refReader.reset(); _metaStringReader.reset(); _depth = 0; - _effectiveContainerMemoryBytes = 0; - _remainingContainerMemoryBytes = 0; + _effectiveGraphMemoryBytes = 0; + _remainingGraphMemoryBytes = 0; } /// The active input buffer for the current operation. @@ -106,37 +106,37 @@ final class ReadContext { RefReader get refReader => _refReader; @internal - int get effectiveContainerMemoryBytes => _effectiveContainerMemoryBytes; + int get effectiveGraphMemoryBytes => _effectiveGraphMemoryBytes; @internal - int get remainingContainerMemoryBytes => _remainingContainerMemoryBytes; + int get remainingGraphMemoryBytes => _remainingGraphMemoryBytes; @internal @pragma('vm:prefer-inline') - void reserveContainerMemory(int bytes) { + void reserveGraphMemory(int bytes) { if (bytes < 0 || bytes > _maxSafeBudgetBytes) { - _throwContainerMemoryOverflow(bytes); + _throwGraphMemoryOverflow(bytes); } - final remaining = _remainingContainerMemoryBytes - bytes; + final remaining = _remainingGraphMemoryBytes - bytes; if (remaining < 0) { - _throwContainerMemoryExceeded(bytes); + _throwGraphMemoryExceeded(bytes); } - _remainingContainerMemoryBytes = remaining; + _remainingGraphMemoryBytes = remaining; } @pragma('vm:never-inline') - Never _throwContainerMemoryOverflow(int bytes) { + Never _throwGraphMemoryOverflow(int bytes) { throw StateError( - 'maxContainerMemoryBytes overflow: requested $bytes estimated container bytes.', + 'maxGraphMemoryBytes overflow: requested $bytes estimated graph bytes.', ); } @pragma('vm:never-inline') - Never _throwContainerMemoryExceeded(int bytes) { + Never _throwGraphMemoryExceeded(int bytes) { throw StateError( - 'maxContainerMemoryBytes exceeded: requested $bytes estimated container bytes, ' - '$_remainingContainerMemoryBytes remaining, effective limit ' - '$_effectiveContainerMemoryBytes.', + 'maxGraphMemoryBytes exceeded: requested $bytes estimated graph bytes, ' + '$_remainingGraphMemoryBytes remaining, effective limit ' + '$_effectiveGraphMemoryBytes.', ); } diff --git a/dart/packages/fory/lib/src/fory.dart b/dart/packages/fory/lib/src/fory.dart index 48a5f9b133..b39f3a0db1 100644 --- a/dart/packages/fory/lib/src/fory.dart +++ b/dart/packages/fory/lib/src/fory.dart @@ -62,13 +62,13 @@ final class Fory { int maxSchemaVersionsPerType = Config.defaultMaxSchemaVersionsPerType, int maxAverageSchemaVersionsPerType = Config.defaultMaxAverageSchemaVersionsPerType, - int maxContainerMemoryBytes = Config.defaultMaxContainerMemoryBytes, + int maxGraphMemoryBytes = Config.defaultMaxGraphMemoryBytes, }) { - if (maxContainerMemoryBytes != Config.defaultMaxContainerMemoryBytes && - maxContainerMemoryBytes <= 0) { + if (maxGraphMemoryBytes != Config.defaultMaxGraphMemoryBytes && + maxGraphMemoryBytes <= 0) { throw ArgumentError.value( - maxContainerMemoryBytes, - 'maxContainerMemoryBytes', + maxGraphMemoryBytes, + 'maxGraphMemoryBytes', 'must be -1 or positive', ); } @@ -80,7 +80,7 @@ final class Fory { maxTypeMetaBytes: maxTypeMetaBytes, maxSchemaVersionsPerType: maxSchemaVersionsPerType, maxAverageSchemaVersionsPerType: maxAverageSchemaVersionsPerType, - maxContainerMemoryBytes: maxContainerMemoryBytes, + maxGraphMemoryBytes: maxGraphMemoryBytes, ); _readBuffer = Buffer(); _writeBuffer = Buffer(); diff --git a/dart/packages/fory/lib/src/serializer/collection_serializers.dart b/dart/packages/fory/lib/src/serializer/collection_serializers.dart index 5d8f7234ac..344d0f85f7 100644 --- a/dart/packages/fory/lib/src/serializer/collection_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/collection_serializers.dart @@ -39,6 +39,7 @@ import 'package:fory/src/types/int64.dart'; import 'package:fory/src/types/uint64.dart'; const int _referenceBytes = 4; +const int _ownerBytes = 1; @pragma('vm:prefer-inline') void _writeDirectTypeInfoValue( @@ -385,7 +386,7 @@ final class SetSerializer extends Serializer { elementFieldType, hasPreservedRef: hasPreservedRef, ); - context.reserveContainerMemory(values.length * _referenceBytes); + context.reserveGraphMemory(_ownerBytes + values.length * _referenceBytes); return Set.of(values); } } @@ -510,7 +511,9 @@ Object _readCompatibleListAsArrayField( String fieldName, ) { final size = context.buffer.readVarUint32(); - context.reserveContainerMemory(size * _arrayElementBytes(arrayTypeId)); + context.reserveGraphMemory( + _ownerBytes + size * _arrayElementBytes(arrayTypeId), + ); if (size == 0) { return _newArrayValue(arrayTypeId, 0); } @@ -644,11 +647,11 @@ void _setArrayValue(Object target, int arrayTypeId, int index, Object? value) { Object _arrayToListValue(ReadContext context, Object? raw) { if (raw is BoolList) { - context.reserveContainerMemory(raw.length * _referenceBytes); + context.reserveGraphMemory(_ownerBytes + raw.length * _referenceBytes); return raw.toList(); } if (raw is Iterable) { - context.reserveContainerMemory(raw.length * _referenceBytes); + context.reserveGraphMemory(_ownerBytes + raw.length * _referenceBytes); return raw.toList(); } throw StateError('Expected compatible array payload.'); @@ -658,9 +661,14 @@ Object _arrayToListValue(ReadContext context, Object? raw) { List readTypedListPayload( ReadContext context, FieldType? elementFieldType, - T Function(Object? value) convert, -) { - final state = _prepareListRead(context, elementFieldType); + T Function(Object? value) convert, { + bool reserveOwner = true, +}) { + final state = _prepareListRead( + context, + elementFieldType, + reserveOwner: reserveOwner, + ); if (state.size == 0) { return List.empty(growable: false); } @@ -738,8 +746,13 @@ Set readTypedSetPayload( FieldType? elementFieldType, T Function(Object? value) convert, ) { - final values = readTypedListPayload(context, elementFieldType, convert); - context.reserveContainerMemory(values.length * _referenceBytes); + final values = readTypedListPayload( + context, + elementFieldType, + convert, + reserveOwner: false, + ); + context.reserveGraphMemory(_ownerBytes + values.length * _referenceBytes); return Set.of(values); } @@ -928,10 +941,13 @@ final class _PreparedListRead { @pragma('vm:prefer-inline') _PreparedListRead _prepareListRead( ReadContext context, - FieldType? elementFieldType, -) { + FieldType? elementFieldType, { + bool reserveOwner = true, +}) { final size = context.buffer.readVarUint32(); - context.reserveContainerMemory(size * _referenceBytes); + if (reserveOwner) { + context.reserveGraphMemory(_ownerBytes + size * _referenceBytes); + } if (size == 0) { return _PreparedListRead( size: 0, diff --git a/dart/packages/fory/lib/src/serializer/map_serializers.dart b/dart/packages/fory/lib/src/serializer/map_serializers.dart index a872837bab..f0aae14c72 100644 --- a/dart/packages/fory/lib/src/serializer/map_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/map_serializers.dart @@ -259,7 +259,7 @@ Map readTypedMapPayload( bool hasPreservedRef = false, }) { var remaining = context.buffer.readVarUint32(); - context.reserveContainerMemory(remaining * 2 * _referenceBytes); + context.reserveGraphMemory(1 + remaining * 2 * _referenceBytes); context.buffer.checkReadableBytes(remaining); final declaredKeyTypeInfo = keyFieldType == null || keyFieldType.isDynamic diff --git a/dart/packages/fory/test/container_memory_budget_test.dart b/dart/packages/fory/test/graph_memory_budget_test.dart similarity index 68% rename from dart/packages/fory/test/container_memory_budget_test.dart rename to dart/packages/fory/test/graph_memory_budget_test.dart index 52463492f7..2cb2269978 100644 --- a/dart/packages/fory/test/container_memory_budget_test.dart +++ b/dart/packages/fory/test/graph_memory_budget_test.dart @@ -27,9 +27,15 @@ import 'package:fory/src/serializer/collection_serializers.dart'; import 'package:fory/src/serializer/map_serializers.dart'; import 'package:test/test.dart'; -part 'container_memory_budget_test.fory.dart'; +part 'graph_memory_budget_test.fory.dart'; -const Matcher _throwsContainerBudget = ThrowsContainerBudget(); +const Matcher _throwsGraphBudget = ThrowsGraphBudget(); +const int _objectBytes = 1; +const int _referenceBytes = 4; + +int _objectGraphBytes(int fields) => _objectBytes + fields * _referenceBytes; +int _listGraphBytes(int count) => _objectBytes + count * _referenceBytes; +int _mapGraphBytes(int count) => _objectBytes + count * 2 * _referenceBytes; @ForyStruct() class BudgetGeneratedEnvelope { @@ -61,12 +67,12 @@ class BudgetCompatibleArrayEnvelope { Int32List values = Int32List(0); } -final class ThrowsContainerBudget extends Matcher { - const ThrowsContainerBudget(); +final class ThrowsGraphBudget extends Matcher { + const ThrowsGraphBudget(); @override Description describe(Description description) { - return description.add('throws a maxContainerMemoryBytes StateError'); + return description.add('throws a maxGraphMemoryBytes StateError'); } @override @@ -77,14 +83,14 @@ final class ThrowsContainerBudget extends Matcher { try { item(); } on StateError catch (error) { - return error.message.contains('maxContainerMemoryBytes'); + return error.message.contains('maxGraphMemoryBytes'); } return false; } } void _registerGenerated(Fory fory) { - ContainerMemoryBudgetTestForyModule.register( + GraphMemoryBudgetTestForyModule.register( fory, BudgetGeneratedEnvelope, name: 'test.BudgetGeneratedEnvelope', @@ -92,7 +98,7 @@ void _registerGenerated(Fory fory) { } void _registerCompatibleList(Fory fory) { - ContainerMemoryBudgetTestForyModule.register( + GraphMemoryBudgetTestForyModule.register( fory, BudgetCompatibleListEnvelope, name: 'test.BudgetCompatibleEnvelope', @@ -100,15 +106,15 @@ void _registerCompatibleList(Fory fory) { } void _registerCompatibleArray(Fory fory) { - ContainerMemoryBudgetTestForyModule.register( + GraphMemoryBudgetTestForyModule.register( fory, BudgetCompatibleArrayEnvelope, name: 'test.BudgetCompatibleEnvelope', ); } -ReadContext _readContext(Buffer buffer, {int maxContainerMemoryBytes = -1}) { - final config = Config(maxContainerMemoryBytes: maxContainerMemoryBytes); +ReadContext _readContext(Buffer buffer, {int maxGraphMemoryBytes = -1}) { + final config = Config(maxGraphMemoryBytes: maxGraphMemoryBytes); final resolver = TypeResolver(config); return ReadContext(config, resolver, RefReader(), MetaStringReader(resolver)) ..prepare(buffer); @@ -118,54 +124,73 @@ Uint8List _serialize(Object? value) => Fory().serialize(value); Object? _readWithBudget(Object? value, int budget) { return Fory( - maxContainerMemoryBytes: budget, + maxGraphMemoryBytes: budget, ).deserialize(_serialize(value)); } void main() { - group('container memory budget', () { + group('graph memory budget', () { test('known length auto derives from input bytes', () { final buffer = Buffer.wrap(Uint8List(17)); final context = _readContext(buffer); - expect(context.effectiveContainerMemoryBytes, equals(17 * 8 + 64 * 1024)); + expect(context.effectiveGraphMemoryBytes, equals(17 * 8 + 64 * 1024)); expect( - () => context.reserveContainerMemory(17 * 8 + 64 * 1024), + () => context.reserveGraphMemory(17 * 8 + 64 * 1024), returnsNormally, ); - expect(() => context.reserveContainerMemory(1), _throwsContainerBudget); + expect(() => context.reserveGraphMemory(1), _throwsGraphBudget); }); test('explicit config overrides auto', () { final buffer = Buffer.wrap(Uint8List(4096)); - final context = _readContext(buffer, maxContainerMemoryBytes: 31); + final context = _readContext(buffer, maxGraphMemoryBytes: 31); - expect(context.effectiveContainerMemoryBytes, equals(31)); - expect(() => context.reserveContainerMemory(31), returnsNormally); - expect(() => context.reserveContainerMemory(1), _throwsContainerBudget); - expect(() => Fory(maxContainerMemoryBytes: 0), throwsArgumentError); - expect(() => Fory(maxContainerMemoryBytes: -2), throwsArgumentError); + expect(context.effectiveGraphMemoryBytes, equals(31)); + expect(() => context.reserveGraphMemory(31), returnsNormally); + expect(() => context.reserveGraphMemory(1), _throwsGraphBudget); + expect(() => Fory(maxGraphMemoryBytes: 0), throwsArgumentError); + expect(() => Fory(maxGraphMemoryBytes: -2), throwsArgumentError); }); test('uses parent storage for nested empty containers', () { final value = [[]]; - expect(() => _readWithBudget(value, 3), _throwsContainerBudget); - expect(_readWithBudget(value, 4), equals(value)); + expect( + () => + _readWithBudget(value, _listGraphBytes(1) + _listGraphBytes(0) - 1), + _throwsGraphBudget, + ); + expect( + _readWithBudget(value, _listGraphBytes(1) + _listGraphBytes(0)), + equals(value), + ); }); test('reserves sibling containers cumulatively', () { final value = [[], [], []]; - expect(() => _readWithBudget(value, 11), _throwsContainerBudget); - expect(_readWithBudget(value, 12), equals(value)); + expect( + () => _readWithBudget( + value, + _listGraphBytes(3) + 3 * _listGraphBytes(0) - 1, + ), + _throwsGraphBudget, + ); + expect( + _readWithBudget(value, _listGraphBytes(3) + 3 * _listGraphBytes(0)), + equals(value), + ); }); test('reserves map entries', () { final value = {'a': 1}; - expect(() => _readWithBudget(value, 7), _throwsContainerBudget); - expect(_readWithBudget(value, 8), equals(value)); + expect( + () => _readWithBudget(value, _mapGraphBytes(1) - 1), + _throwsGraphBudget, + ); + expect(_readWithBudget(value, _mapGraphBytes(1)), equals(value)); }); test('reserves generated list set and map reads', () { @@ -178,14 +203,19 @@ void main() { ..counts = {'one': 1}, ); - final failingReader = Fory(maxContainerMemoryBytes: 19); + final required = + _objectGraphBytes(3) + + _listGraphBytes(1) + + _listGraphBytes(1) + + _mapGraphBytes(1); + final failingReader = Fory(maxGraphMemoryBytes: required - 1); _registerGenerated(failingReader); expect( () => failingReader.deserialize(bytes), - _throwsContainerBudget, + _throwsGraphBudget, ); - final passingReader = Fory(maxContainerMemoryBytes: 20); + final passingReader = Fory(maxGraphMemoryBytes: required); _registerGenerated(passingReader); final roundTrip = passingReader.deserialize( bytes, @@ -202,14 +232,15 @@ void main() { BudgetCompatibleListEnvelope()..values = [1, 2, 3], ); - final arrayFail = Fory(maxContainerMemoryBytes: 11); + final required = _objectGraphBytes(1) + _objectBytes + 3 * 4; + final arrayFail = Fory(maxGraphMemoryBytes: required - 1); _registerCompatibleArray(arrayFail); expect( () => arrayFail.deserialize(listBytes), - _throwsContainerBudget, + _throwsGraphBudget, ); - final arrayPass = Fory(maxContainerMemoryBytes: 12); + final arrayPass = Fory(maxGraphMemoryBytes: required); _registerCompatibleArray(arrayPass); expect( arrayPass @@ -226,14 +257,14 @@ void main() { ..values = Int32List.fromList([1, 2, 3]), ); - final listFail = Fory(maxContainerMemoryBytes: 11); + final listFail = Fory(maxGraphMemoryBytes: required - 1); _registerCompatibleList(listFail); expect( () => listFail.deserialize(arrayBytes), - _throwsContainerBudget, + _throwsGraphBudget, ); - final listPass = Fory(maxContainerMemoryBytes: 12); + final listPass = Fory(maxGraphMemoryBytes: required); _registerCompatibleList(listPass); expect( listPass.deserialize(arrayBytes).values, @@ -242,7 +273,7 @@ void main() { }); test('skips strings binary and dense typed arrays', () { - final fory = Fory(maxContainerMemoryBytes: 1); + final fory = Fory(maxGraphMemoryBytes: 1); final text = List.filled(128, 'x').join(); expect(fory.deserialize(Fory().serialize(text)), hasLength(128)); diff --git a/docs/guide/cpp/configuration.md b/docs/guide/cpp/configuration.md index 3041fe04df..bd13ff7e27 100644 --- a/docs/guide/cpp/configuration.md +++ b/docs/guide/cpp/configuration.md @@ -96,14 +96,14 @@ When enabled, avoids duplicating shared objects and handles cycles. **Default:** `true` -### max_container_memory_bytes(int64_t) +### max_graph_memory_bytes(int64_t) -Set the maximum estimated memory that container objects may reserve during one -root deserialization. +Set the maximum estimated shallow graph memory accepted during one root +deserialization. ```cpp auto fory = Fory::builder() - .max_container_memory_bytes(64 * 1024 * 1024) + .max_graph_memory_bytes(64 * 1024 * 1024) .build(); ``` @@ -112,12 +112,12 @@ automatic limit is the root input size multiplied by `8`, plus `64 KiB`. For stream roots, the automatic limit is `128 MiB` because the full root size is not known up front. Positive values always override the automatic limit. -This budget is a portable lower-bound estimate for container-owned storage such -as dynamic collection backing storage, map key/value storage, and -object/reference array slots. It is not an exact process heap limit and does -not include STL implementation details such as debug nodes, table buckets, or -allocator headers. Empty containers with no dynamic backing normally do not -consume the budget. Dedicated string, binary, and primitive dense-array payloads +This budget is a portable lower-bound estimate for shallow materialized graph +owners such as dynamic collection backing storage, map key/value storage, +object/reference array slots, and struct or object field storage. It is not an +exact process heap limit and does not include STL implementation details such as +debug nodes, table buckets, or allocator headers. Dedicated string, binary, and +primitive dense-array payloads continue to rely on their byte-availability checks instead. `std::vector` is counted as packed standard-container storage. @@ -232,7 +232,7 @@ auto fory = Fory::builder().build_thread_safe(); // Returns ThreadSafeFory | `xlang(bool)` | Use xlang mode | `true` | | `compatible(bool)` | Enable schema evolution | `true` | | `track_ref(bool)` | Enable reference tracking | `true` | -| `max_container_memory_bytes(int64_t)` | Max estimated container memory per root read | `-1` | +| `max_graph_memory_bytes(int64_t)` | Max estimated graph memory per root read | `-1` | | `max_dyn_depth(uint32_t)` | Maximum nesting depth for dynamic types | `5` | | `max_type_fields(uint32_t)` | Max fields in one received struct metadata body | `512` | | `max_type_meta_bytes(uint32_t)` | Max encoded bytes in one received metadata body | `4096` | @@ -246,8 +246,8 @@ Security-related configuration: - Register all structs and polymorphic implementations before deserializing untrusted payloads. - Use `check_struct_version(true)` with `compatible(false)` for intentional same-schema payloads. -- Leave `max_container_memory_bytes(-1)` enabled for automatic root-size-based container limits, or - set a positive value for a stricter trusted-workload envelope. +- Leave `max_graph_memory_bytes(-1)` enabled for automatic root-size-based graph limits, or set a + positive value for a stricter trusted-workload envelope. - Keep `max_dyn_depth(...)` as low as your model permits to reject unexpectedly deep polymorphic graphs. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a diff --git a/docs/guide/csharp/configuration.md b/docs/guide/csharp/configuration.md index b463e3226b..5c045316f5 100644 --- a/docs/guide/csharp/configuration.md +++ b/docs/guide/csharp/configuration.md @@ -41,7 +41,7 @@ ThreadSafeFory threadSafe = Fory.Builder().BuildThreadSafe(); | `Compatible` | `true` | Compatible schema-evolution metadata enabled | | `CheckStructVersion` | `false` | Struct schema hash checks disabled | | `MaxDepth` | `20` | Max dynamic nesting depth | -| `MaxContainerMemoryBytes` | `-1` | Auto container memory budget | +| `MaxGraphMemoryBytes` | `-1` | Auto graph memory budget | | `MaxTypeFields` | `512` | Max fields in one received struct metadata body | | `MaxTypeMetaBytes` | `4096` | Max encoded bytes in one received metadata body | | `MaxSchemaVersionsPerType` | `10` | Max remote metadata versions for one logical type | @@ -97,14 +97,13 @@ Fory fory = Fory.Builder() `value` must be greater than `0`. -### `MaxContainerMemoryBytes(long value)` +### `MaxGraphMemoryBytes(long value)` -Sets the maximum estimated lower-bound container-owned storage accepted during one root -deserialization. +Sets the maximum estimated shallow graph memory accepted during one root deserialization. ```csharp Fory fory = Fory.Builder() - .MaxContainerMemoryBytes(64L * 1024 * 1024) + .MaxGraphMemoryBytes(64L * 1024 * 1024) .Build(); ``` @@ -189,8 +188,8 @@ Security-related configuration: - Register only the expected types before deserializing untrusted payloads. - Use `CheckStructVersion(true)` with `Compatible(false)` for intentional same-schema payloads. - Set `MaxDepth(...)` to reject unexpectedly deep dynamic object graphs. -- Set `MaxContainerMemoryBytes(...)` to cap estimated lower-bound list, array, set, and map storage - during one root deserialization. +- Set `MaxGraphMemoryBytes(...)` to cap estimated shallow graph memory during one root + deserialization. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer generated or registered concrete models over broad dynamic fields for untrusted input. diff --git a/docs/guide/dart/configuration.md b/docs/guide/dart/configuration.md index d1466d83af..2c8b6baea3 100644 --- a/docs/guide/dart/configuration.md +++ b/docs/guide/dart/configuration.md @@ -38,7 +38,7 @@ final fory = Fory( maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, maxAverageSchemaVersionsPerType: 3, - maxContainerMemoryBytes: 64 * 1024 * 1024, + maxGraphMemoryBytes: 64 * 1024 * 1024, ); ``` @@ -108,13 +108,12 @@ final fory = Fory( - `maxAverageSchemaVersionsPerType` limits the average across accepted remote types. The effective global floor is `8192` schemas. -### `maxContainerMemoryBytes` +### `maxGraphMemoryBytes` -Limits estimated lower-bound container-owned storage for one root deserialization. The budget -covers Dart list/set/object-reference slots, map key/value slots, and compatible list/array -materialization. Empty containers without backing storage normally do not consume the budget. It -does not count strings, binary values, or dense typed-array payloads, which are protected by -byte-availability checks. +Limits estimated shallow graph memory for one root deserialization. The budget covers materialized +Dart lists, sets, maps, object/reference arrays, structs, objects, and compatible list/array +materialization. It does not count strings, binary values, or dense typed-array payloads, which are +protected by byte-availability checks. The default is `-1`, which means auto. Dart root inputs are memory-backed, so auto derives from the root input size: @@ -127,7 +126,7 @@ Set a positive value when a trusted workload legitimately contains compact, cont payloads: ```dart -final fory = Fory(maxContainerMemoryBytes: 256 * 1024 * 1024); +final fory = Fory(maxGraphMemoryBytes: 256 * 1024 * 1024); ``` ## Defaults @@ -141,7 +140,7 @@ final fory = Fory(maxContainerMemoryBytes: 256 * 1024 * 1024); | `maxTypeMetaBytes` | 4096 | | `maxSchemaVersionsPerType` | 10 | | `maxAverageSchemaVersionsPerType` | 3 | -| `maxContainerMemoryBytes` | -1 | +| `maxGraphMemoryBytes` | -1 | ## Xlang Notes @@ -158,8 +157,8 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkStructVersion: true` with `compatible: false` for intentional same-schema payloads. - Set `maxDepth` to reject unexpectedly deep payload shapes. -- Keep `maxContainerMemoryBytes` at the auto default for most inputs, or set an explicit positive - byte limit for known trusted container-heavy payloads. +- Keep `maxGraphMemoryBytes` at the auto default for most inputs, or set an explicit positive byte + limit for known trusted graph-heavy payloads. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer generated schemas and explicit field metadata over broad dynamic fields for untrusted input. diff --git a/docs/guide/go/configuration.md b/docs/guide/go/configuration.md index 636fb4a22c..c260312c29 100644 --- a/docs/guide/go/configuration.md +++ b/docs/guide/go/configuration.md @@ -39,7 +39,7 @@ Default settings: | MaxDepth | 20 | Maximum nesting depth | | IsXlang | true | Xlang mode enabled | | Compatible | true | Compatible schema-evolution metadata enabled | -| MaxContainerMemoryBytes | -1 | Automatic container memory limit per root read | +| MaxGraphMemoryBytes | -1 | Automatic graph memory limit per root read | | MaxTypeFields | 512 | Max fields in one received struct metadata body | | MaxTypeMetaBytes | 4096 | Max encoded bytes in one received metadata body | | MaxSchemaVersionsPerType | 10 | Max remote metadata versions for one logical type | @@ -52,7 +52,7 @@ f := fory.New( fory.WithXlang(true), fory.WithTrackRef(true), fory.WithMaxDepth(10), - fory.WithMaxContainerMemoryBytes(-1), + fory.WithMaxGraphMemoryBytes(-1), fory.WithMaxTypeFields(512), fory.WithMaxTypeMetaBytes(4096), fory.WithMaxSchemaVersionsPerType(10), @@ -129,12 +129,12 @@ f := fory.New(fory.WithMaxDepth(30)) - Protects against deeply nested, recursive structures or malicious data - Serialization fails with error when exceeded -### WithMaxContainerMemoryBytes +### WithMaxGraphMemoryBytes -Limit estimated container-owned memory accepted during one root deserialization: +Limit estimated shallow graph memory accepted during one root deserialization: ```go -f := fory.New(fory.WithMaxContainerMemoryBytes(256 * 1024 * 1024)) +f := fory.New(fory.WithMaxGraphMemoryBytes(256 * 1024 * 1024)) ``` The default `-1` selects an automatic limit. Byte-slice roots use: @@ -145,10 +145,9 @@ inputBytes * 8 + 64 KiB `DeserializeFromReader` and `DeserializeFromStream` use `128 MiB` because the full root length is unknown. The budget covers lower-bound slice backing -storage, map key/value storage, sets, and generated container reads. Empty -containers without backing storage normally do not consume the budget. Strings, -binary blobs, and primitive dense array -owners keep their byte-availability checks and are not charged to this budget. +storage, map key/value storage, sets, generated object reads, and materialized +struct field storage. Strings, binary blobs, and primitive dense array owners +keep their byte-availability checks and are not reserved against this budget. Set a positive value when a service needs a stricter or larger limit for trusted data. diff --git a/docs/guide/java/configuration.md b/docs/guide/java/configuration.md index ec4d7b1eba..e5bf5d461e 100644 --- a/docs/guide/java/configuration.md +++ b/docs/guide/java/configuration.md @@ -38,7 +38,7 @@ This page documents all configuration options available through `ForyBuilder`. | `registerGuavaTypes` | Whether to pre-register Guava types such as `RegularImmutableMap`/`RegularImmutableList`. These types are not public API, but seem pretty stable. | `true` | | `requireClassRegistration` | Disabling may allow unknown classes to be deserialized, potentially causing security risks. | `true` | | `maxDepth` | Set max depth for deserialization, when depth exceeds, an exception will be thrown. This can be used to refuse deserialization DDOS attack. | `50` | -| `maxContainerMemoryBytes` | Maximum estimated container-owned memory accepted during one root deserialization. `-1` derives an automatic limit from the input shape: known-length inputs use `inputBytes * 8 + 64 KiB`, and stream or unknown-length inputs use `128 MiB`. Positive values set an explicit byte limit. | `-1` | +| `maxGraphMemoryBytes` | Maximum estimated shallow graph memory accepted during one root deserialization. `-1` derives an automatic limit from the input shape: known-length inputs use `inputBytes * 8 + 64 KiB`, and stream or unknown-length inputs use `128 MiB`. Positive values set an explicit byte limit. | `-1` | | `maxTypeFields` | Maximum fields accepted in one received remote struct metadata body. | `512` | | `maxTypeMetaBytes` | Maximum encoded body bytes accepted for one received TypeDef or TypeMeta body, excluding the 8-byte header and any extended-size varint. | `4096` | | `maxSchemaVersionsPerType` | Maximum accepted remote metadata versions for one logical type. | `10` | @@ -91,7 +91,7 @@ Keep class registration enabled for production and any untrusted payload source: Fory fory = Fory.builder() .requireClassRegistration(true) .withMaxDepth(50) - .withMaxContainerMemoryBytes(-1) + .withMaxGraphMemoryBytes(-1) .build(); ``` @@ -99,10 +99,9 @@ Security-related options: - `requireClassRegistration(true)` restricts deserialization to registered classes. - `withMaxDepth(...)` rejects unexpectedly deep object graphs. -- `withMaxContainerMemoryBytes(...)` bounds estimated lower-bound container-owned storage during - one root deserialization. Empty containers without backing storage normally do not consume the - budget. Keep `-1` for the automatic input-shaped default, or set a positive byte limit when - trusted payloads need a larger or smaller limit. +- `withMaxGraphMemoryBytes(...)` bounds estimated shallow graph memory during one root + deserialization. Keep `-1` for the automatic input-shaped default, or set a positive byte limit + when trusted payloads need a larger or smaller limit. - `withMaxTypeFields(...)` and `withMaxTypeMetaBytes(...)` bound the field count and encoded body size of one received remote metadata body. - `withMaxSchemaVersionsPerType(...)` and diff --git a/docs/guide/javascript/configuration.md b/docs/guide/javascript/configuration.md index 4f732c040c..0db0f5d7ab 100644 --- a/docs/guide/javascript/configuration.md +++ b/docs/guide/javascript/configuration.md @@ -43,7 +43,7 @@ const fory = new Fory({ ref: true, compatible: true, maxDepth: 100, - maxContainerMemoryBytes: -1, + maxGraphMemoryBytes: -1, maxTypeFields: 512, maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, @@ -57,7 +57,7 @@ const fory = new Fory({ | `ref` | `false` | Enable reference tracking for shared or circular object graphs | | `compatible` | `true` | Allow field additions/removals without breaking existing messages | | `maxDepth` | `50` | Maximum nesting depth. Must be `>= 2`. Increase for deeply nested structures | -| `maxContainerMemoryBytes` | `-1` | Maximum estimated container-owned memory accepted during one root deserialization | +| `maxGraphMemoryBytes` | `-1` | Maximum estimated shallow graph memory accepted during one root deserialization | | `maxTypeFields` | `512` | Maximum fields accepted in one received remote struct metadata body | | `maxTypeMetaBytes` | `4096` | Maximum encoded body bytes accepted for one received TypeMeta body | | `maxSchemaVersionsPerType` | `10` | Maximum accepted remote metadata versions for one logical type | @@ -94,28 +94,27 @@ to that struct. For cross-language payloads, set `compatible: false` only after verifying that every language uses the same schema, or when native types are generated from Fory schema IDL. See [Schema Evolution](schema-evolution.md). -## Container Memory Budget +## Graph Memory Budget -`maxContainerMemoryBytes` limits estimated lower-bound container-owned storage -accepted during one root deserialization. The budget covers array, set, object -array, and map reference slots; it is not an exact JavaScript heap limit. Empty -containers without backing storage normally do not consume the budget. The -default `-1` derives an automatic limit from the input bytes. JavaScript -deserializes from `Uint8Array` roots, so the automatic limit is -`inputBytes * 8 + 64 KiB`. +`maxGraphMemoryBytes` limits estimated shallow graph memory accepted during one +root deserialization. The budget covers materialized arrays, sets, object +arrays, maps, structs, and objects; it is not an exact JavaScript heap limit. +The default `-1` derives an automatic limit from the input bytes. JavaScript +deserializes from `Uint8Array` roots, so the automatic limit is `inputBytes \* 8 + +- 64 KiB`. Use a positive byte value to set an explicit lower or higher limit: ```ts const fory = new Fory({ - maxContainerMemoryBytes: 32 * 1024 * 1024, + maxGraphMemoryBytes: 32 * 1024 * 1024, }); ``` String, binary, and dedicated dense primitive array payloads keep their normal -byte-size checks and do not consume this container budget. Raise the limit only -for trusted workloads that legitimately contain very compact, container-heavy -graphs. +byte-size checks and do not consume this graph budget. Raise the limit only for +trusted workloads that legitimately contain very compact object graphs. ## Optional HPS String Path @@ -135,7 +134,7 @@ Security-related configuration: - Register only the expected schemas before deserializing untrusted payloads. - Set `maxDepth` for the maximum nesting depth your service accepts. -- Set `maxContainerMemoryBytes` for the maximum container memory your service +- Set `maxGraphMemoryBytes` for the maximum graph memory your service accepts from one root payload. - Keep `maxTypeFields` and `maxTypeMetaBytes` at their defaults unless the data is not malicious and a trusted peer sends larger remote metadata. diff --git a/docs/guide/python/configuration.md b/docs/guide/python/configuration.md index 5922a204b9..3acd0b12e8 100644 --- a/docs/guide/python/configuration.md +++ b/docs/guide/python/configuration.md @@ -40,7 +40,7 @@ class Fory: max_type_meta_bytes: int = 4096, max_schema_versions_per_type: int = 10, max_average_schema_versions_per_type: int = 3, - max_container_memory_bytes: int = -1, + max_graph_memory_bytes: int = -1, policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, @@ -71,7 +71,7 @@ class ThreadSafeFory: | `max_type_meta_bytes` | `int` | `4096` | Maximum encoded body bytes accepted for one received TypeDef body, excluding the 8-byte header and any extended-size varint. | | `max_schema_versions_per_type` | `int` | `10` | Maximum accepted remote metadata versions for one logical type. | | `max_average_schema_versions_per_type` | `int` | `3` | Average accepted remote metadata versions across accepted remote types. The effective global floor is `8192` schemas. | -| `max_container_memory_bytes` | `int` | `-1` | Maximum estimated container-owned memory for one root deserialization. `-1` selects the automatic limit. | +| `max_graph_memory_bytes` | `int` | `-1` | Maximum estimated shallow graph memory for one root deserialization. `-1` selects the automatic limit. | | `policy` | `DeserializationPolicy \| None` | `None` | Deserialization policy used for security checks. Strongly recommended when `strict=False`. | | `field_nullable` | `bool` | `False` | Treat dataclass fields as nullable by default. | | `meta_compressor` | `Any` | `None` | Optional metadata compressor used for compatible-mode metadata encoding. | @@ -199,7 +199,7 @@ fory = pyfory.Fory( max_type_meta_bytes=4096, max_schema_versions_per_type=10, max_average_schema_versions_per_type=3, - max_container_memory_bytes=-1, + max_graph_memory_bytes=-1, ) fory.register(UserModel, name="example.User") @@ -225,11 +225,11 @@ Received remote metadata is also limited: - `max_type_meta_bytes` limits the encoded body bytes accepted for one received TypeDef body. - `max_schema_versions_per_type` limits accepted remote metadata versions for one logical type. - `max_average_schema_versions_per_type` limits the average across accepted remote types. -- `max_container_memory_bytes` limits estimated lower-bound list, tuple, set, dict, and - object-array storage created during one root deserialization. Empty containers without backing - storage normally do not consume the budget. The default `-1` uses `input_bytes * 8 + 64 KiB` for - known-length inputs and `128 MiB` for stream inputs. Set a positive byte value for trusted - payloads that legitimately contain larger container graphs. +- `max_graph_memory_bytes` limits estimated shallow graph memory created during one root + deserialization, including materialized lists, tuples, sets, dicts, object arrays, structs, and + Python objects. The default `-1` uses `input_bytes * 8 + 64 KiB` for known-length inputs and + `128 MiB` for stream inputs. Set a positive byte value for trusted payloads that legitimately + contain larger object graphs. These limits do not change `strict`, `policy`, dynamic loading, unknown-class handling, or schema-evolution semantics. @@ -286,7 +286,7 @@ unchanged. - Register all expected application types before deserialization. - Use `DeserializationPolicy` when `strict=False` is necessary. - Keep `max_depth` low enough to reject unexpectedly deep payloads. -- Keep `max_container_memory_bytes=-1` unless a trusted workload needs a higher explicit limit. +- Keep `max_graph_memory_bytes=-1` unless a trusted workload needs a higher explicit limit. - Do not treat xlang/native mode choice as a security control. ## Related Topics diff --git a/docs/guide/rust/configuration.md b/docs/guide/rust/configuration.md index f47295ce70..249af3ee71 100644 --- a/docs/guide/rust/configuration.md +++ b/docs/guide/rust/configuration.md @@ -110,15 +110,15 @@ let fory = Fory::builder() - `max_average_schema_versions_per_type` defaults to `3` and limits the average across accepted remote types. The effective global floor is `8192` schemas. -### Container Memory Budget +### Graph Memory Budget -`max_container_memory_bytes(...)` limits the estimated lower-bound container-owned storage accepted -during one root read. The budget covers `Vec`/collection element storage and map key/value storage; -it is not an exact process heap limit. Empty containers without backing storage normally do not -consume the budget. The default is `-1`, which selects an automatic limit based on the input size: +`max_graph_memory_bytes(...)` limits estimated shallow graph memory accepted during one root read. +The budget covers `Vec`/collection element storage, map key/value storage, and materialized struct +or object field storage; it is not an exact process heap limit. The default is `-1`, which selects +an automatic limit based on the input size: ```rust -let fory = Fory::builder().max_container_memory_bytes(-1).build(); +let fory = Fory::builder().max_graph_memory_bytes(-1).build(); ``` For byte-slice and `Reader` roots, the automatic limit is: @@ -131,7 +131,7 @@ Set a positive byte value when trusted payloads need a larger or smaller limit: ```rust let fory = Fory::builder() - .max_container_memory_bytes(256 * 1024 * 1024) + .max_graph_memory_bytes(256 * 1024 * 1024) .build(); ``` @@ -160,9 +160,9 @@ let fory = Fory::builder().xlang(false).compatible(false).build(); // Custom depth limit let fory = Fory::builder().max_dyn_depth(10).build(); -// Custom container memory budget +// Custom graph memory budget let fory = Fory::builder() - .max_container_memory_bytes(256 * 1024 * 1024) + .max_graph_memory_bytes(256 * 1024 * 1024) .build(); // Combined configuration @@ -179,7 +179,7 @@ let fory = Fory::builder() | `compatible(bool)` | Enable schema evolution | `true` | | `xlang(bool)` | Use xlang mode | `true` | | `max_dyn_depth(u32)` | Maximum nesting depth for dynamic types | `5` | -| `max_container_memory_bytes(i64)` | Estimated container memory per root read | `-1` | +| `max_graph_memory_bytes(i64)` | Estimated graph memory per root read | `-1` | | `max_type_fields(usize)` | Max fields in one received struct metadata body | `512` | | `max_type_meta_bytes(usize)` | Max encoded bytes in one received metadata body | `4096` | | `max_schema_versions_per_type(usize)` | Max remote metadata versions for one logical type | `10` | @@ -200,8 +200,8 @@ Security-related configuration: - Register application structs and trait-object implementations before deserializing untrusted payloads. - Use `max_dyn_depth(...)` to reject unexpectedly deep dynamic object graphs. -- Keep `max_container_memory_bytes(-1)` for the default input-shaped container budget, or set a - positive byte limit for trusted workloads with larger legitimate containers. +- Keep `max_graph_memory_bytes(-1)` for the default input-shaped graph budget, or set a positive + byte limit for trusted workloads with larger legitimate object graphs. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer concrete typed fields over `dyn Any` or broad trait-object fields for untrusted input. diff --git a/docs/guide/swift/configuration.md b/docs/guide/swift/configuration.md index 1301050edf..9086b4cf3c 100644 --- a/docs/guide/swift/configuration.md +++ b/docs/guide/swift/configuration.md @@ -31,7 +31,7 @@ public struct Config { public let compatible: Bool public let checkClassVersion: Bool public let maxDepth: Int - public let maxContainerMemoryBytes: Int64 + public let maxGraphMemoryBytes: Int64 public let maxTypeFields: Int public let maxTypeMetaBytes: Int public let maxSchemaVersionsPerType: Int @@ -93,11 +93,10 @@ let fory = Fory(compatible: false, checkClassVersion: true) `maxDepth` bounds decoded payload nesting depth. -`maxContainerMemoryBytes` bounds the estimated lower-bound container-owned storage accepted during -one root deserialization. Swift roots are currently `Data` or `ByteBuffer`, so auto uses the root -input byte length times `8`, plus `64 KiB`. Empty containers without backing storage normally do -not consume the budget. Use `-1` for the default automatic limit; a positive value overrides it. -`0` and negative values other than `-1` are rejected. +`maxGraphMemoryBytes` bounds estimated shallow graph memory accepted during one root +deserialization. Swift roots are currently `Data` or `ByteBuffer`, so auto uses the root input byte +length times `8`, plus `64 KiB`. Use `-1` for the default automatic limit; a positive value +overrides it. `0` and negative values other than `-1` are rejected. Compatible-mode remote metadata is also limited: @@ -112,7 +111,7 @@ Compatible-mode remote metadata is also limited: ```swift let fory = Fory( maxDepth: 5, - maxContainerMemoryBytes: -1, + maxGraphMemoryBytes: -1, maxTypeFields: 512, maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, @@ -149,7 +148,6 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkClassVersion` with `compatible: false` for intentional same-schema payloads. - Set `maxDepth` for the largest nesting depth your service accepts. -- Set `maxContainerMemoryBytes` to cap estimated lower-bound list, set, array, and map storage - during one root deserialization. +- Set `maxGraphMemoryBytes` to cap estimated shallow graph memory during one root deserialization. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. diff --git a/docs/security/deserialization.md b/docs/security/deserialization.md index 4bbf4a8cc2..618d7e0bc5 100644 --- a/docs/security/deserialization.md +++ b/docs/security/deserialization.md @@ -205,55 +205,46 @@ validation can cause a no-progress loop, unbounded resource growth, retained state, or success across a Fory policy boundary. Protocol-allowed chunk segmentation is normal input and is not a security issue by itself. -## Container Memory Budget - -Runtimes should enforce a root-deserialization budget for estimated -container-owned memory. This is cumulative accounting for containers created by -one root read; it is not exact heap measurement and it is not a raw element-slot -limit. - -The public configuration should be named around `maxContainerMemoryBytes`. -`-1` means automatic input-shaped budgeting. Positive user configuration always -wins. For known-length root input, the automatic budget is -`inputBytes * 8 + 64 KiB`. For true stream or otherwise unknown-length root -input, the automatic budget is fixed at `128 MiB`. Stream budgeting should not -depend on dynamic bytes-read accounting. - -Container budget accounting should: - -- happen in root-operation read state, with cleanup owned by the root - deserialization `finally`; -- keep read context/read state limited to raw byte reservation and generic - counted-byte arithmetic; collection/map/array storage formulas belong in the - concrete serializer or generated serializer owner; +## Root Graph Memory Budget + +Runtimes should enforce a root-deserialization budget for estimated shallow memory created by one +materialized graph. This is cumulative accounting for graph owners created by one root read; it is +not exact heap measurement and it is not a raw element-slot limit. + +The public configuration is `maxGraphMemoryBytes`. `-1` means automatic input-shaped budgeting. +Positive user configuration always wins. For known-length root input, the automatic budget is +`inputBytes * 8 + 64 KiB`. For true stream or otherwise unknown-length root input, the automatic +budget is fixed at `128 MiB`. Stream budgeting should not depend on dynamic bytes-read accounting. + +Graph budget accounting should: + +- happen in root-operation read state, with cleanup owned by the root deserialization `finally`; +- keep read context/read state limited to raw byte reservation and generic counted-byte arithmetic; + collection, map, array, struct, and object storage formulas belong in the concrete serializer or + generated serializer owner; - reject arithmetic overflow before comparing budget or allocating; -- estimate lower-bound owner storage: reference-backed containers and - object/reference arrays charge reference slots, inline/value containers charge - element storage, reference-backed maps charge two references per entry, and - inline/value maps charge key plus value storage; -- treat fixed/header cost as zero by default, charging it only when the owner - path creates an independently allocated container/control entity that is not - already covered by parent inline/value storage and the charged size is a - documented conservative lower bound; -- preserve existing byte-availability checks before backing allocation or - capacity reservation; -- skip dedicated string, binary, primitive array, and primitive dense-array - owner paths. - -Each runtime must inspect the concrete container path before choosing formulas. -Reference-backed containers should charge reference storage, using a 4-byte -reference slot when the actual reference slot size is not cheap or reliable to -query. Inline/value containers such as a value-type vector or list must charge -the inline element storage instead of treating those elements as references. -General inline-value containers must not be skipped just because dedicated -primitive dense arrays are skipped. - -Runtimes should not guess object headers, array headers, allocator headers, -debug-mode fields, hash buckets, tree links, hash-chain links, node headers, -map-entry objects, spare blocks, or runtime table layouts unless the owner path -has a cheap, stable, explicit lower-bound storage signal and documents the -formula. C++ STL node, allocator, and debug-mode overheads should not be guessed -when only value storage is reliably known. +- estimate lower-bound shallow owner storage: independently materialized collections, maps, sets, + and reference arrays reserve nonzero shallow self cost plus backing/reference/inline storage, and + struct/record/POJO/tuple, compatible, generated, and dynamic object owners reserve a nonzero + shallow self cost plus shallow field storage; +- use a 4-byte reference slot when the actual reference slot size is not cheap or reliable to query, + and use primitive/value field widths for inline storage; +- preserve existing byte-availability checks before backing allocation or capacity reservation; +- skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive + array, and primitive dense-array leaf owners. + +Each runtime must inspect the concrete owner path before choosing formulas. Reserve self storage +exactly once at the owner that stores or allocates the value. Reference-backed paths reserve parent +owner self cost plus reference storage, while each referenced heap owner reserves its own shallow +self cost when materialized. Inline/value paths reserve inline element, field, or root storage in the +holder/allocation owner; nested value serializers must not charge their own self storage again. +Parents must not recursively include child object, collection, map, string, binary, or primitive +dense-array contents; the child owner reserves its own shallow memory when it is materialized. + +Runtimes should not guess object headers, array headers, allocator headers, debug-mode fields, hash +buckets, tree links, hash-chain links, node headers, map-entry objects, spare blocks, or runtime +table layouts unless the owner path has a cheap, stable, explicit lower-bound storage signal and +documents the formula. C++ STL node, allocator, and debug-mode overheads should not be guessed. ## Skip Semantics diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index f960461708..3e5076d0db 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -398,36 +398,42 @@ duplicate keys, element value semantics, and protocol strictness remain owned by the container/map serializer and should be validated only when they protect a real owner invariant. -Container readers should also charge a root-operation estimated container memory -budget before allocation or size hinting. The budget belongs to `ReadContext` or -the equivalent root read state, not to serializers and not to ambient -thread-local state. Positive `maxContainerMemoryBytes` configuration wins; auto -configuration uses `inputBytes * 8 + 64 KiB` for known-length root input and -fixed `128 MiB` for true stream or unknown-length root input. Do not add dynamic -stream bytes-read accounting for this budget. +Materializing readers should also reserve a root-operation estimated graph +memory budget before allocation or size hinting. The budget belongs to +`ReadContext` or the equivalent root read state, not to serializers and not to +ambient thread-local state. Positive `maxGraphMemoryBytes` configuration wins; +auto configuration uses `inputBytes * 8 + 64 KiB` for known-length root input +and fixed `128 MiB` for true stream or unknown-length root input. Do not add +dynamic stream bytes-read accounting for this budget. Read context or equivalent read state owns only raw byte accounting and generic counted-byte arithmetic, such as reserving `bytes` or `count * elementBytes` -with overflow checks. It must not expose collection/map/array semantic -reservation APIs. Concrete serializers and generated serializer owners compute -the storage constants and formulas for the container path they allocate. - -The budget estimates lower-bound container-owned storage, not exact heap bytes. -Reference-backed containers and object/reference arrays charge reference slots; -inline/value containers charge element storage; reference-backed maps charge two -references per entry; and inline/value maps charge key plus value storage. -Fixed/header cost defaults to zero and is charged only when the owner path -creates an independently allocated container/control entity, that entity is not -already covered by parent inline/value storage, and the charged size is a -documented conservative lower bound. Empty containers with no dynamic backing -normally charge zero. Skip dedicated string, binary, primitive array, and -primitive dense-array owners, but do not skip general inline-value containers -such as vectors or lists of value objects. If reference slot size is not cheap -or reliable to query, use a 4-byte reference slot. Native runtimes may use -conservative lower-bound estimates instead of guessing non-portable container, -allocator, table, node, entry, or debug-layout details. Reject arithmetic -overflow before budget comparison or allocation, and keep the existing -`checkReadableBytes` proof before backing allocation or capacity reservation. +with overflow checks. It must not expose collection, map, array, struct, or +object semantic reservation APIs. Concrete serializers and generated serializer +owners compute the storage constants and formulas for the owner path they +allocate. + +The budget estimates lower-bound shallow memory for materialized graph owners, +not exact heap bytes. Reserve self storage exactly once at the owner that stores +or allocates the value. Reference-backed containers, maps, sets, and +object/reference arrays reserve nonzero owner self cost plus reference slots; +each referenced heap owner then reserves its own shallow self cost when +materialized. Inline/value containers reserve element storage; inline/value maps +reserve key plus value storage; root/product/box owners reserve value self +storage; and nested value serializers reserve only additional dynamic storage +they allocate. Struct/record/POJO/tuple, compatible, generated, and dynamic +object owners reserve a nonzero shallow self cost plus shallow field storage. +Parents must not recursively include child object, collection, map, string, +binary, or primitive dense-array contents. Skip enum/union as separate owners and +skip dedicated string, binary, primitive scalar, primitive array, and primitive +dense-array leaf owners, but do not skip general inline-value containers such as +vectors or lists of value objects. If reference slot size is not cheap or +reliable to query, use a 4-byte reference slot. Native runtimes may use +conservative lower-bound estimates instead of guessing non-portable object, +container, allocator, table, node, entry, or debug-layout details. Reject +arithmetic overflow before budget comparison or allocation, and keep the +existing `checkReadableBytes` proof before backing +allocation or capacity reservation. For TypeDef or TypeMeta bodies, first prove that the encoded metadata body bytes are readable through the byte owner. Field-list allocation should happen after diff --git a/go/fory/README.md b/go/fory/README.md index e9c633ad8c..c3f7f9f8ed 100644 --- a/go/fory/README.md +++ b/go/fory/README.md @@ -93,15 +93,15 @@ f := fory.New(fory.WithXlang(false), fory.WithCompatible(false)) // Set maximum nesting depth f := fory.New(fory.WithMaxDepth(20)) -// Set maximum estimated container memory for one root read -f := fory.New(fory.WithMaxContainerMemoryBytes(256 * 1024 * 1024)) +// Set maximum estimated graph memory for one root read +f := fory.New(fory.WithMaxGraphMemoryBytes(256 * 1024 * 1024)) // Combine multiple options f := fory.New( fory.WithXlang(true), fory.WithTrackRef(true), fory.WithMaxDepth(20), - fory.WithMaxContainerMemoryBytes(-1), + fory.WithMaxGraphMemoryBytes(-1), ) ``` diff --git a/go/fory/array.go b/go/fory/array.go index c0d58034aa..4a7921cba7 100644 --- a/go/fory/array.go +++ b/go/fory/array.go @@ -318,7 +318,7 @@ func (s arrayDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType func (s arrayDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { // Create a temp slice to read into, then copy back to array sliceType := reflect.SliceOf(value.Type().Elem()) - if !ctx.ReserveCountedContainerMemory(value.Len(), int64(value.Type().Elem().Size())) { + if !ctx.ReserveCountedGraphMemory(value.Len(), int64(value.Type().Elem().Size())) { return } tempSlice := reflect.MakeSlice(sliceType, value.Len(), value.Len()) diff --git a/go/fory/codegen/decoder.go b/go/fory/codegen/decoder.go index f2fe674df8..16e53f8a7f 100644 --- a/go/fory/codegen/decoder.go +++ b/go/fory/codegen/decoder.go @@ -69,7 +69,13 @@ func generateReadInterface(buf *bytes.Buffer, s *StructInfo) error { fmt.Fprintf(buf, "\tvar v *%s\n", s.Name) fmt.Fprintf(buf, "\tif value.Kind() == reflect.Ptr {\n") fmt.Fprintf(buf, "\t\tif value.IsNil() {\n") - fmt.Fprintf(buf, "\t\t\t// For pointer types, allocate using value.Type().Elem()\n") + fmt.Fprintf(buf, "\t\t\tgraphBytes := int64(unsafe.Sizeof(%s{}))\n", s.Name) + fmt.Fprintf(buf, "\t\t\tif graphBytes == 0 {\n") + fmt.Fprintf(buf, "\t\t\t\tgraphBytes = 1\n") + fmt.Fprintf(buf, "\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveGraphMemory(graphBytes) {\n") + fmt.Fprintf(buf, "\t\t\t\treturn\n") + fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tvalue.Set(reflect.New(value.Type().Elem()))\n") fmt.Fprintf(buf, "\t\t}\n") fmt.Fprintf(buf, "\t\tv = value.Interface().(*%s)\n", s.Name) @@ -172,7 +178,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveCountedContainerMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveCountedGraphMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") @@ -203,7 +209,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveCountedContainerMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveCountedGraphMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") @@ -510,7 +516,7 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveCountedContainerMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveCountedGraphMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") @@ -531,7 +537,7 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveCountedContainerMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveCountedGraphMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") @@ -560,7 +566,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) - fmt.Fprintf(buf, "%sif !ctx.ReserveCountedContainerMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "%sif !ctx.ReserveCountedGraphMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) @@ -586,7 +592,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) - fmt.Fprintf(buf, "%sif !ctx.ReserveCountedContainerMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) + fmt.Fprintf(buf, "%sif !ctx.ReserveCountedGraphMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) @@ -855,7 +861,7 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveCountedContainerMemory(mapLen, %s + %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) + fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveCountedGraphMemory(mapLen, %s + %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif mapLen == 0 {\n") @@ -876,7 +882,7 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveCountedContainerMemory(mapLen, %s + %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) + fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveCountedGraphMemory(mapLen, %s + %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif mapLen == 0 {\n") @@ -914,7 +920,7 @@ func generateMapReadInlineNoNull(buf *bytes.Buffer, mapType *types.Map, fieldAcc fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) - fmt.Fprintf(buf, "%sif !ctx.ReserveCountedContainerMemory(mapLen, %s + %s) {\n", indent, unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) + fmt.Fprintf(buf, "%sif !ctx.ReserveCountedGraphMemory(mapLen, %s + %s) {\n", indent, unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif mapLen == 0 {\n", indent) diff --git a/go/fory/codegen/generator.go b/go/fory/codegen/generator.go index 30fcbd7fb3..c054ea462b 100644 --- a/go/fory/codegen/generator.go +++ b/go/fory/codegen/generator.go @@ -34,7 +34,7 @@ import ( var logger = log.New(os.Stdout, "", 0) -func typeNeedsContainerReservation(t types.Type) bool { +func typeNeedsGraphReservation(t types.Type) bool { if _, ok := t.(*types.Slice); ok { return true } @@ -295,9 +295,9 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil // Determine which imports are needed needsTime := false - needsReflect := false + needsReflect := len(structs) > 0 needsOptional := false - needsUnsafe := false + needsUnsafe := len(structs) > 0 for _, s := range structs { for _, field := range s.Fields { @@ -307,15 +307,13 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil } if field.IsOptional { needsOptional = true - if field.OptionalElem != nil && typeNeedsContainerReservation(field.OptionalElem) { + if field.OptionalElem != nil && typeNeedsGraphReservation(field.OptionalElem) { needsUnsafe = true } } - if typeNeedsContainerReservation(field.Type) { + if typeNeedsGraphReservation(field.Type) { needsUnsafe = true } - // We need reflect for the interface compatibility methods - needsReflect = true } } @@ -570,9 +568,9 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { // Determine which imports are needed needsTime := false - needsReflect := false + needsReflect := len(structs) > 0 needsOptional := false - needsUnsafe := false + needsUnsafe := len(structs) > 0 for _, s := range structs { for _, field := range s.Fields { @@ -582,15 +580,13 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { } if field.IsOptional { needsOptional = true - if field.OptionalElem != nil && typeNeedsContainerReservation(field.OptionalElem) { + if field.OptionalElem != nil && typeNeedsGraphReservation(field.OptionalElem) { needsUnsafe = true } } - if typeNeedsContainerReservation(field.Type) { + if typeNeedsGraphReservation(field.Type) { needsUnsafe = true } - // We need reflect for the interface compatibility methods - needsReflect = true } } diff --git a/go/fory/container_memory_budget_test.go b/go/fory/container_memory_budget_test.go deleted file mode 100644 index 00281f1e13..0000000000 --- a/go/fory/container_memory_budget_test.go +++ /dev/null @@ -1,227 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package fory - -import ( - "bytes" - "strings" - "testing" - - "github.com/stretchr/testify/require" -) - -type budgetItem struct { - A int32 -} - -type budgetSiblings struct { - A []string - B []string -} - -func TestContainerMemoryBudgetConfig(t *testing.T) { - require.Equal(t, int64(-1), New().config.MaxContainerMemoryBytes) - require.Equal(t, int64(123), New(WithMaxContainerMemoryBytes(123)).config.MaxContainerMemoryBytes) - require.Panics(t, func() { New(WithMaxContainerMemoryBytes(0)) }) - require.Panics(t, func() { New(WithMaxContainerMemoryBytes(-2)) }) -} - -func TestContainerMemoryBudgetAutoLimits(t *testing.T) { - ctx := NewReadContext(false) - ctx.initContainerMemoryBudget(10, false) - require.False(t, ctx.HasError()) - require.Equal(t, int64(10)*knownRootBudgetMultiplier+knownRootBudgetSlackBytes, ctx.containerMemoryLimitBytes) - require.True(t, ctx.ReserveContainerMemory(ctx.containerMemoryLimitBytes)) - require.False(t, ctx.ReserveContainerMemory(1)) - require.Contains(t, ctx.CheckError().Error(), "maxContainerMemoryBytes") - - ctx = NewReadContext(false) - ctx.initContainerMemoryBudget(10, true) - require.False(t, ctx.HasError()) - require.Equal(t, streamRootBudgetBytes, ctx.containerMemoryLimitBytes) - require.True(t, ctx.ReserveContainerMemory(streamRootBudgetBytes)) - require.False(t, ctx.ReserveContainerMemory(1)) - require.Contains(t, ctx.CheckError().Error(), "maxContainerMemoryBytes") - - ctx = NewReadContext(false) - ctx.maxContainerMemoryBytes = 77 - ctx.initContainerMemoryBudget(10, true) - require.False(t, ctx.HasError()) - require.Equal(t, int64(77), ctx.containerMemoryLimitBytes) -} - -func TestContainerMemoryBudgetKnownVsStreamRoot(t *testing.T) { - writer := New(WithCompatible(false)) - values := make([]any, 12000) - for i := range values { - values[i] = []any{} - } - data, err := writer.Serialize(values) - require.NoError(t, err) - - var fromBytes []any - err = New(WithCompatible(false)).Deserialize(data, &fromBytes) - require.Error(t, err) - require.Contains(t, err.Error(), "maxContainerMemoryBytes") - - var fromStream []any - err = New(WithCompatible(false)).DeserializeFromReader(bytes.NewReader(data), &fromStream) - require.NoError(t, err) - require.Len(t, fromStream, len(values)) -} - -func TestContainerMemoryBudgetBufferRoots(t *testing.T) { - writer := New(WithCompatible(false)) - value := []string{"a", "b"} - data, err := writer.Serialize(value) - require.NoError(t, err) - - reader := New(WithCompatible(false)) - var fromCallback []string - err = reader.DeserializeWithCallbackBuffers(NewByteBuffer(data), &fromCallback, nil) - require.NoError(t, err) - require.Equal(t, value, fromCallback) - - var fromBuffer []string - err = reader.DeserializeFrom(NewByteBuffer(data), &fromBuffer) - require.NoError(t, err) - require.Equal(t, value, fromBuffer) -} - -func TestContainerMemoryBudgetExplicitOverride(t *testing.T) { - writer := New(WithCompatible(false)) - values := make([]any, 12000) - for i := range values { - values[i] = []any{} - } - data, err := writer.Serialize(values) - require.NoError(t, err) - - var out []any - err = New(WithCompatible(false), WithMaxContainerMemoryBytes(4*1024*1024)).Deserialize(data, &out) - require.NoError(t, err) - require.Len(t, out, len(values)) -} - -func TestContainerMemoryBudgetEmptyAndCumulative(t *testing.T) { - data, err := New(WithCompatible(false)).Serialize([]any{}) - require.NoError(t, err) - var empty []any - err = New(WithCompatible(false), WithMaxContainerMemoryBytes(1)).Deserialize(data, &empty) - require.NoError(t, err) - require.Empty(t, empty) - - writer := New(WithCompatible(false)) - require.NoError(t, writer.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) - data, err = writer.Serialize(&budgetSiblings{A: []string{"a"}, B: []string{"b"}}) - require.NoError(t, err) - reader := New(WithCompatible(false), WithMaxContainerMemoryBytes(stringElementBytes)) - require.NoError(t, reader.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) - var out budgetSiblings - err = reader.Deserialize(data, &out) - require.Error(t, err) - require.Contains(t, err.Error(), "maxContainerMemoryBytes") - reader = New(WithCompatible(false), WithMaxContainerMemoryBytes(2*stringElementBytes)) - require.NoError(t, reader.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) - require.NoError(t, reader.Deserialize(data, &out)) - require.Equal(t, []string{"a"}, out.A) - require.Equal(t, []string{"b"}, out.B) -} - -func TestContainerMemoryBudgetMapAndOverflow(t *testing.T) { - data, err := New().Serialize(map[string]string{"k": "v"}) - require.NoError(t, err) - var out map[string]string - oneEntryBudget := containerSizeOf[string]() + containerSizeOf[string]() - err = New(WithMaxContainerMemoryBytes(oneEntryBudget-1)).Deserialize(data, &out) - require.Error(t, err) - require.Contains(t, err.Error(), "maxContainerMemoryBytes") - - ctx := NewReadContext(false) - ctx.initContainerMemoryBudget(0, true) - require.False(t, ctx.ReserveCountedContainerMemory(MaxInt, MaxInt64)) - require.Contains(t, ctx.CheckError().Error(), "overflows") -} - -func TestContainerMemoryBudgetSlicesAndInlineValues(t *testing.T) { - data, err := New().Serialize([]string{"a"}) - require.NoError(t, err) - var stringsOut []string - err = New(WithMaxContainerMemoryBytes(containerSizeOf[string]()-1)).Deserialize(data, &stringsOut) - require.Error(t, err) - require.Contains(t, err.Error(), "maxContainerMemoryBytes") - - writer := New() - require.NoError(t, writer.RegisterStructByName(budgetItem{}, "test.BudgetItem")) - data, err = writer.Serialize([]budgetItem{{A: 1}}) - require.NoError(t, err) - reader := New(WithMaxContainerMemoryBytes(containerSizeOf[budgetItem]() - 1)) - require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) - var items []budgetItem - err = reader.Deserialize(data, &items) - require.Error(t, err) - require.Contains(t, err.Error(), "maxContainerMemoryBytes") -} - -func TestContainerMemoryBudgetSkipsDenseOwners(t *testing.T) { - f := New(WithMaxContainerMemoryBytes(1)) - - stringData, err := New().Serialize(strings.Repeat("x", 128)) - require.NoError(t, err) - var s string - require.NoError(t, f.Deserialize(stringData, &s)) - require.Len(t, s, 128) - - bytesData, err := New().Serialize([]byte{1, 2, 3, 4}) - require.NoError(t, err) - var b []byte - require.NoError(t, f.Deserialize(bytesData, &b)) - require.Equal(t, []byte{1, 2, 3, 4}, b) - - intsData, err := New().Serialize([]int32{1, 2, 3, 4}) - require.NoError(t, err) - var ints []int32 - require.NoError(t, f.Deserialize(intsData, &ints)) - require.Equal(t, []int32{1, 2, 3, 4}, ints) -} - -func TestContainerMemoryBudgetPreservesByteChecks(t *testing.T) { - buf := NewByteBuffer(nil) - buf.WriteByte_(XLangFlag) - buf.WriteInt8(NotNullValueFlag) - buf.WriteUint8(uint8(LIST)) - buf.WriteLength(1024) - buf.WriteInt8(int8(CollectionIsSameType)) - buf.WriteUint8(uint8(STRING)) - - var stringsOut []string - err := New(WithMaxContainerMemoryBytes(8*1024*1024)).Deserialize(buf.Bytes(), &stringsOut) - require.Error(t, err) - require.Contains(t, err.Error(), "buffer out of bound") - - buf = NewByteBuffer(nil) - buf.WriteByte_(XLangFlag) - buf.WriteInt8(NotNullValueFlag) - buf.WriteUint8(uint8(INT32_ARRAY)) - buf.WriteLength(4096) - - var ints []int32 - err = New(WithMaxContainerMemoryBytes(8*1024*1024)).Deserialize(buf.Bytes(), &ints) - require.Error(t, err) - require.Contains(t, err.Error(), "buffer out of bound") -} diff --git a/go/fory/field_serializer.go b/go/fory/field_serializer.go index b4cc586c75..0a3125ff21 100644 --- a/go/fory/field_serializer.go +++ b/go/fory/field_serializer.go @@ -74,7 +74,7 @@ func newDeclaredSliceSerializer(type_ reflect.Type, elemSerializer Serializer, r elemSerializer: elemSerializer, referencable: referencable, elemBytes: elemBytes, - maxLength: maxContainerCount(elemBytes), + maxLength: maxGraphCount(elemBytes), }, nil } diff --git a/go/fory/fory.go b/go/fory/fory.go index 3b6360aece..26811dcfaf 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -69,7 +69,7 @@ type Config struct { MaxDepth int IsXlang bool Compatible bool // Schema evolution compatibility mode - MaxContainerMemoryBytes int64 + MaxGraphMemoryBytes int64 MaxTypeFields int MaxTypeMetaBytes int MaxSchemaVersionsPerType int @@ -83,7 +83,7 @@ func defaultConfig() Config { MaxDepth: 20, IsXlang: true, MaxTypeFields: 512, - MaxContainerMemoryBytes: -1, + MaxGraphMemoryBytes: -1, MaxTypeMetaBytes: 4096, MaxSchemaVersionsPerType: 10, MaxAverageSchemaVersionsPerType: 3, @@ -112,14 +112,14 @@ func WithMaxDepth(depth int) Option { } } -// WithMaxContainerMemoryBytes sets the maximum estimated container-owned memory accepted during one root deserialization. +// WithMaxGraphMemoryBytes sets the maximum estimated graph memory accepted during one root deserialization. // Use -1 for the automatic input-shaped limit. -func WithMaxContainerMemoryBytes(size int64) Option { +func WithMaxGraphMemoryBytes(size int64) Option { if size != -1 && size <= 0 { - panic("MaxContainerMemoryBytes must be positive or -1 for auto") + panic("MaxGraphMemoryBytes must be positive or -1 for auto") } return func(f *Fory) { - f.config.MaxContainerMemoryBytes = size + f.config.MaxGraphMemoryBytes = size } } @@ -231,7 +231,7 @@ func New(opts ...Option) *Fory { f.writeCtx.xlang = f.config.IsXlang f.readCtx = NewReadContext(f.config.TrackRef) - f.readCtx.maxContainerMemoryBytes = f.config.MaxContainerMemoryBytes + f.readCtx.maxGraphMemoryBytes = f.config.MaxGraphMemoryBytes f.readCtx.typeResolver = f.typeResolver f.readCtx.refResolver = f.refResolver f.readCtx.compatible = f.config.Compatible @@ -570,7 +570,7 @@ func (f *Fory) Serialize(value any) ([]byte, error) { func (f *Fory) Deserialize(data []byte, v any) error { defer f.resetReadState() f.readCtx.SetData(data) - f.readCtx.initContainerMemoryBudget(len(data), false) + f.readCtx.initGraphMemoryBudget(len(data), false) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -582,6 +582,9 @@ func (f *Fory) Deserialize(data []byte, v any) error { // Deserialize the value - TypeMeta is read inline using streaming protocol target := reflect.ValueOf(v).Elem() + if err := f.reserveRootGraphOwner(target); err != nil { + return err + } f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { return f.readCtx.TakeError() @@ -666,7 +669,7 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { // Temporarily swap buffer origBuffer := f.readCtx.buffer f.readCtx.buffer = buf - f.readCtx.initContainerMemoryBudget(buf.readableBytes(), false) + f.readCtx.initGraphMemoryBudget(buf.readableBytes(), false) if f.readCtx.HasError() { f.readCtx.buffer = origBuffer return f.readCtx.TakeError() @@ -680,6 +683,10 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { // Deserialize the value - TypeMeta is read inline using streaming protocol target := reflect.ValueOf(v).Elem() + if err := f.reserveRootGraphOwner(target); err != nil { + f.readCtx.buffer = origBuffer + return err + } f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { f.readCtx.buffer = origBuffer @@ -766,7 +773,7 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers f.readCtx.buffer = nil f.readCtx.outOfBandBuffers = nil }() - f.readCtx.initContainerMemoryBudget(buffer.readableBytes(), false) + f.readCtx.initGraphMemoryBudget(buffer.readableBytes(), false) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -794,7 +801,11 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers } // Deserialize the value - TypeMeta is read inline using streaming protocol - f.readCtx.ReadValue(rv.Elem(), RefModeTracking, true) + target := rv.Elem() + if err := f.reserveRootGraphOwner(target); err != nil { + return err + } + f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -1043,7 +1054,7 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { // Reuse context, reset and set new data f.readCtx.Reset() f.readCtx.SetData(data) - f.readCtx.initContainerMemoryBudget(len(data), false) + f.readCtx.initGraphMemoryBudget(len(data), false) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -1187,6 +1198,9 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { // Slow path: use serializer-based deserialization targetVal := reflect.ValueOf(target).Elem() targetType := targetVal.Type() + if err := f.reserveRootGraphOwner(targetVal); err != nil { + return err + } // Get serializer for the target type serializer, err := f.typeResolver.getSerializerByType(targetType, false) @@ -1199,3 +1213,17 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { return f.readCtx.CheckError() } } + +func (f *Fory) reserveRootGraphOwner(target reflect.Value) error { + if !target.IsValid() || target.Kind() != reflect.Struct { + return nil + } + targetType := target.Type() + if targetType == dateReflectType || targetType == timeReflectType { + return nil + } + if !reserveStructGraph(f.readCtx, targetType) { + return f.readCtx.TakeError() + } + return nil +} diff --git a/go/fory/graph_memory_budget_test.go b/go/fory/graph_memory_budget_test.go new file mode 100644 index 0000000000..7c266d72e2 --- /dev/null +++ b/go/fory/graph_memory_budget_test.go @@ -0,0 +1,295 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "bytes" + "reflect" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type budgetItem struct { + A int32 +} + +type budgetSiblings struct { + A []string + B []string +} + +func graphOwnerSizeOf[T any]() int64 { + bytes := graphSizeOf[T]() + if bytes == 0 { + return 1 + } + return bytes +} + +func TestGraphMemoryBudgetConfig(t *testing.T) { + require.Equal(t, int64(-1), New().config.MaxGraphMemoryBytes) + require.Equal(t, int64(123), New(WithMaxGraphMemoryBytes(123)).config.MaxGraphMemoryBytes) + require.Panics(t, func() { New(WithMaxGraphMemoryBytes(0)) }) + require.Panics(t, func() { New(WithMaxGraphMemoryBytes(-2)) }) +} + +func TestGraphMemoryBudgetAutoLimits(t *testing.T) { + ctx := NewReadContext(false) + ctx.initGraphMemoryBudget(10, false) + require.False(t, ctx.HasError()) + require.Equal(t, int64(10)*knownRootBudgetMultiplier+knownRootBudgetSlackBytes, ctx.graphMemoryLimitBytes) + require.True(t, ctx.ReserveGraphMemory(ctx.graphMemoryLimitBytes)) + require.False(t, ctx.ReserveGraphMemory(1)) + require.Contains(t, ctx.CheckError().Error(), "maxGraphMemoryBytes") + + ctx = NewReadContext(false) + ctx.initGraphMemoryBudget(10, true) + require.False(t, ctx.HasError()) + require.Equal(t, streamRootBudgetBytes, ctx.graphMemoryLimitBytes) + require.True(t, ctx.ReserveGraphMemory(streamRootBudgetBytes)) + require.False(t, ctx.ReserveGraphMemory(1)) + require.Contains(t, ctx.CheckError().Error(), "maxGraphMemoryBytes") + + ctx = NewReadContext(false) + ctx.maxGraphMemoryBytes = 77 + ctx.initGraphMemoryBudget(10, true) + require.False(t, ctx.HasError()) + require.Equal(t, int64(77), ctx.graphMemoryLimitBytes) +} + +func TestGraphMemoryBudgetKnownVsStreamRoot(t *testing.T) { + writer := New(WithCompatible(false)) + values := make([]any, 12000) + for i := range values { + values[i] = []any{} + } + data, err := writer.Serialize(values) + require.NoError(t, err) + + var fromBytes []any + err = New(WithCompatible(false)).Deserialize(data, &fromBytes) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + + var fromStream []any + err = New(WithCompatible(false)).DeserializeFromReader(bytes.NewReader(data), &fromStream) + require.NoError(t, err) + require.Len(t, fromStream, len(values)) +} + +func TestGraphMemoryBudgetBufferRoots(t *testing.T) { + writer := New(WithCompatible(false)) + value := []string{"a", "b"} + data, err := writer.Serialize(value) + require.NoError(t, err) + + reader := New(WithCompatible(false)) + var fromCallback []string + err = reader.DeserializeWithCallbackBuffers(NewByteBuffer(data), &fromCallback, nil) + require.NoError(t, err) + require.Equal(t, value, fromCallback) + + var fromBuffer []string + err = reader.DeserializeFrom(NewByteBuffer(data), &fromBuffer) + require.NoError(t, err) + require.Equal(t, value, fromBuffer) +} + +func TestGraphMemoryBudgetExplicitOverride(t *testing.T) { + writer := New(WithCompatible(false)) + values := make([]any, 12000) + for i := range values { + values[i] = []any{} + } + data, err := writer.Serialize(values) + require.NoError(t, err) + + var out []any + err = New(WithCompatible(false), WithMaxGraphMemoryBytes(4*1024*1024)).Deserialize(data, &out) + require.NoError(t, err) + require.Len(t, out, len(values)) +} + +func TestGraphMemoryBudgetEmptyAndCumulative(t *testing.T) { + data, err := New(WithCompatible(false)).Serialize([]any{}) + require.NoError(t, err) + var empty []any + err = New(WithCompatible(false), WithMaxGraphMemoryBytes(1)).Deserialize(data, &empty) + require.NoError(t, err) + require.Empty(t, empty) + + writer := New(WithCompatible(false)) + require.NoError(t, writer.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) + data, err = writer.Serialize(&budgetSiblings{A: []string{"a"}, B: []string{"b"}}) + require.NoError(t, err) + reader := New(WithCompatible(false), WithMaxGraphMemoryBytes(stringElementBytes)) + require.NoError(t, reader.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) + var out budgetSiblings + err = reader.Deserialize(data, &out) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + required := structGraphBytes(reflect.TypeOf(budgetSiblings{})) + 2*stringElementBytes + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required)) + require.NoError(t, reader.RegisterStructByName(budgetSiblings{}, "test.BudgetSiblings")) + require.NoError(t, reader.Deserialize(data, &out)) + require.Equal(t, []string{"a"}, out.A) + require.Equal(t, []string{"b"}, out.B) +} + +func TestGraphMemoryBudgetMapAndOverflow(t *testing.T) { + data, err := New().Serialize(map[string]string{"k": "v"}) + require.NoError(t, err) + var out map[string]string + oneEntryBudget := graphSizeOf[string]() + graphSizeOf[string]() + err = New(WithMaxGraphMemoryBytes(oneEntryBudget-1)).Deserialize(data, &out) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + + ctx := NewReadContext(false) + ctx.initGraphMemoryBudget(0, true) + require.False(t, ctx.ReserveCountedGraphMemory(MaxInt, MaxInt64)) + require.Contains(t, ctx.CheckError().Error(), "overflows") +} + +func TestGraphMemoryBudgetSlicesAndInlineValues(t *testing.T) { + data, err := New().Serialize([]string{"a"}) + require.NoError(t, err) + var stringsOut []string + err = New(WithMaxGraphMemoryBytes(graphSizeOf[string]()-1)).Deserialize(data, &stringsOut) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + + writer := New() + require.NoError(t, writer.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + data, err = writer.Serialize([]budgetItem{{A: 1}}) + require.NoError(t, err) + reader := New(WithMaxGraphMemoryBytes(graphSizeOf[budgetItem]() - 1)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + var items []budgetItem + err = reader.Deserialize(data, &items) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + reader = New(WithMaxGraphMemoryBytes(graphSizeOf[budgetItem]())) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + require.NoError(t, reader.Deserialize(data, &items)) + require.Equal(t, []budgetItem{{A: 1}}, items) +} + +func TestGraphMemoryBudgetStructOwners(t *testing.T) { + writer := New(WithCompatible(false)) + require.NoError(t, writer.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + data, err := writer.Serialize(&budgetItem{A: 7}) + require.NoError(t, err) + + required := graphOwnerSizeOf[budgetItem]() + reader := New(WithCompatible(false), WithMaxGraphMemoryBytes(required-1)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + var out *budgetItem + err = reader.Deserialize(data, &out) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + require.NoError(t, reader.Deserialize(data, &out)) + require.Equal(t, int32(7), out.A) + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required-1)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + var outValue budgetItem + err = reader.Deserialize(data, &outValue) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + require.NoError(t, reader.Deserialize(data, &outValue)) + require.Equal(t, int32(7), outValue.A) + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required-1)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + err = reader.DeserializeFromReader(bytes.NewReader(data), &outValue) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + require.NoError(t, reader.DeserializeFromReader(bytes.NewReader(data), &outValue)) + require.Equal(t, int32(7), outValue.A) + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required-1)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + err = reader.DeserializeFromStream(NewInputStream(bytes.NewReader(data)), &outValue) + require.Error(t, err) + require.Contains(t, err.Error(), "maxGraphMemoryBytes") + + reader = New(WithCompatible(false), WithMaxGraphMemoryBytes(required)) + require.NoError(t, reader.RegisterStructByName(budgetItem{}, "test.BudgetItem")) + require.NoError(t, reader.DeserializeFromStream(NewInputStream(bytes.NewReader(data)), &outValue)) + require.Equal(t, int32(7), outValue.A) +} + +func TestGraphMemoryBudgetSkipsDenseOwners(t *testing.T) { + f := New(WithMaxGraphMemoryBytes(1)) + + stringData, err := New().Serialize(strings.Repeat("x", 128)) + require.NoError(t, err) + var s string + require.NoError(t, f.Deserialize(stringData, &s)) + require.Len(t, s, 128) + + bytesData, err := New().Serialize([]byte{1, 2, 3, 4}) + require.NoError(t, err) + var b []byte + require.NoError(t, f.Deserialize(bytesData, &b)) + require.Equal(t, []byte{1, 2, 3, 4}, b) + + intsData, err := New().Serialize([]int32{1, 2, 3, 4}) + require.NoError(t, err) + var ints []int32 + require.NoError(t, f.Deserialize(intsData, &ints)) + require.Equal(t, []int32{1, 2, 3, 4}, ints) +} + +func TestGraphMemoryBudgetPreservesByteChecks(t *testing.T) { + buf := NewByteBuffer(nil) + buf.WriteByte_(XLangFlag) + buf.WriteInt8(NotNullValueFlag) + buf.WriteUint8(uint8(LIST)) + buf.WriteLength(1024) + buf.WriteInt8(int8(CollectionIsSameType)) + buf.WriteUint8(uint8(STRING)) + + var stringsOut []string + err := New(WithMaxGraphMemoryBytes(8*1024*1024)).Deserialize(buf.Bytes(), &stringsOut) + require.Error(t, err) + require.Contains(t, err.Error(), "buffer out of bound") + + buf = NewByteBuffer(nil) + buf.WriteByte_(XLangFlag) + buf.WriteInt8(NotNullValueFlag) + buf.WriteUint8(uint8(INT32_ARRAY)) + buf.WriteLength(4096) + + var ints []int32 + err = New(WithMaxGraphMemoryBytes(8*1024*1024)).Deserialize(buf.Bytes(), &ints) + require.Error(t, err) + require.Contains(t, err.Error(), "buffer out of bound") +} diff --git a/go/fory/map.go b/go/fory/map.go index d3a661edeb..9d68706838 100644 --- a/go/fory/map.go +++ b/go/fory/map.go @@ -307,10 +307,10 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { valueBytes := int64(mapType.Elem().Size()) elemBytes := keyBytes + valueBytes if elemBytes < keyBytes { - ctx.setContainerMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + ctx.setGraphMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) return } - if !ctx.ReserveCountedContainerMemory(size, elemBytes) { + if !ctx.ReserveCountedGraphMemory(size, elemBytes) { return } if size == 0 { diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go index e52b88b5b1..d0ed63a632 100644 --- a/go/fory/map_primitive.go +++ b/go/fory/map_primitive.go @@ -27,23 +27,23 @@ import ( var ( stringStringMapElemBytes = stringElementBytes + stringElementBytes - stringStringMapMaxLength = maxContainerCount(stringStringMapElemBytes) - stringInt64MapElemBytes = stringElementBytes + containerSizeOf[int64]() - stringInt64MapMaxLength = maxContainerCount(stringInt64MapElemBytes) - stringInt32MapElemBytes = stringElementBytes + containerSizeOf[int32]() - stringInt32MapMaxLength = maxContainerCount(stringInt32MapElemBytes) - stringIntMapElemBytes = stringElementBytes + containerSizeOf[int]() - stringIntMapMaxLength = maxContainerCount(stringIntMapElemBytes) - stringFloat64MapElemBytes = stringElementBytes + containerSizeOf[float64]() - stringFloat64MapMaxLength = maxContainerCount(stringFloat64MapElemBytes) - stringBoolMapElemBytes = stringElementBytes + containerSizeOf[bool]() - stringBoolMapMaxLength = maxContainerCount(stringBoolMapElemBytes) - int32Int32MapElemBytes = containerSizeOf[int32]() + containerSizeOf[int32]() - int32Int32MapMaxLength = maxContainerCount(int32Int32MapElemBytes) - int64Int64MapElemBytes = containerSizeOf[int64]() + containerSizeOf[int64]() - int64Int64MapMaxLength = maxContainerCount(int64Int64MapElemBytes) - intIntMapElemBytes = containerSizeOf[int]() + containerSizeOf[int]() - intIntMapMaxLength = maxContainerCount(intIntMapElemBytes) + stringStringMapMaxLength = maxGraphCount(stringStringMapElemBytes) + stringInt64MapElemBytes = stringElementBytes + graphSizeOf[int64]() + stringInt64MapMaxLength = maxGraphCount(stringInt64MapElemBytes) + stringInt32MapElemBytes = stringElementBytes + graphSizeOf[int32]() + stringInt32MapMaxLength = maxGraphCount(stringInt32MapElemBytes) + stringIntMapElemBytes = stringElementBytes + graphSizeOf[int]() + stringIntMapMaxLength = maxGraphCount(stringIntMapElemBytes) + stringFloat64MapElemBytes = stringElementBytes + graphSizeOf[float64]() + stringFloat64MapMaxLength = maxGraphCount(stringFloat64MapElemBytes) + stringBoolMapElemBytes = stringElementBytes + graphSizeOf[bool]() + stringBoolMapMaxLength = maxGraphCount(stringBoolMapElemBytes) + int32Int32MapElemBytes = graphSizeOf[int32]() + graphSizeOf[int32]() + int32Int32MapMaxLength = maxGraphCount(int32Int32MapElemBytes) + int64Int64MapElemBytes = graphSizeOf[int64]() + graphSizeOf[int64]() + int64Int64MapMaxLength = maxGraphCount(int64Int64MapElemBytes) + intIntMapElemBytes = graphSizeOf[int]() + graphSizeOf[int]() + intIntMapMaxLength = maxGraphCount(intIntMapElemBytes) ) // writeMapStringString writes map[string]string using chunk protocol @@ -94,7 +94,7 @@ func readTypedMapSize(ctx *ReadContext, elemBytes int64, maxLength int64) (int, if ctx.HasError() { return 0, false } - if !ctx.reserveCountedContainerMemory(size, elemBytes, maxLength) { + if !ctx.reserveCountedGraphMemory(size, elemBytes, maxLength) { return 0, false } if size == 0 { diff --git a/go/fory/pointer.go b/go/fory/pointer.go index b30a15f53d..5180deb627 100644 --- a/go/fory/pointer.go +++ b/go/fory/pointer.go @@ -140,6 +140,9 @@ func (s *ptrToValueSerializer) ReadData(ctx *ReadContext, value reflect.Value) { var newVal reflect.Value if value.IsNil() { // Allocate new value + if !reserveStructGraph(ctx, value.Type().Elem()) { + return + } newVal = reflect.New(value.Type().Elem()) value.Set(newVal) } else { @@ -195,6 +198,9 @@ func (s *ptrToValueSerializer) Read(ctx *ReadContext, refMode RefMode, readType if structSer, ok := typeInfo.Serializer.(*structSerializer); ok && len(structSer.fieldDefs) > 0 { // Allocate the pointer value if needed if value.IsNil() { + if !reserveStructGraph(ctx, value.Type().Elem()) { + return + } value.Set(reflect.New(value.Type().Elem())) } ctx.RefResolver().Reference(value) diff --git a/go/fory/reader.go b/go/fory/reader.go index 2222d98690..8f7c582bfc 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -29,24 +29,24 @@ import ( // ReadContext holds all state needed during deserialization. type ReadContext struct { - buffer *ByteBuffer - refReader *RefReader - trackRef bool // Cached flag to avoid indirection - xlang bool // Cross-language serialization mode - rootHeader byte - compatible bool // Schema evolution compatibility mode - typeResolver *TypeResolver // For complex type deserialization - refResolver *RefResolver // For reference tracking in native-mode paths - outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization - outOfBandIndex int // Current index into out-of-band buffers - depth int // Current nesting depth for cycle detection - maxDepth int // Maximum allowed nesting depth - err Error // Accumulated error state for deferred checking - lastTypePtr uintptr - lastTypeInfo *TypeInfo - maxContainerMemoryBytes int64 - containerMemoryLimitBytes int64 - remainingContainerMemoryBytes int64 + buffer *ByteBuffer + refReader *RefReader + trackRef bool // Cached flag to avoid indirection + xlang bool // Cross-language serialization mode + rootHeader byte + compatible bool // Schema evolution compatibility mode + typeResolver *TypeResolver // For complex type deserialization + refResolver *RefResolver // For reference tracking in native-mode paths + outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization + outOfBandIndex int // Current index into out-of-band buffers + depth int // Current nesting depth for cycle detection + maxDepth int // Maximum allowed nesting depth + err Error // Accumulated error state for deferred checking + lastTypePtr uintptr + lastTypeInfo *TypeInfo + maxGraphMemoryBytes int64 + graphMemoryLimitBytes int64 + remainingGraphMemoryBytes int64 } const ( @@ -56,21 +56,40 @@ const ( ) var referenceSlotBytes = int64(unsafe.Sizeof(uintptr(0))) -var stringElementBytes = containerSizeOf[string]() -var stringMaxLength = maxContainerCount(stringElementBytes) +var stringElementBytes = graphSizeOf[string]() +var stringMaxLength = maxGraphCount(stringElementBytes) -func containerSizeOf[T any]() int64 { +func graphSizeOf[T any]() int64 { var v T return int64(unsafe.Sizeof(v)) } -func maxContainerCount(elemBytes int64) int64 { +func maxGraphCount(elemBytes int64) int64 { if elemBytes == 0 { return MaxInt64 } return MaxInt64 / elemBytes } +func structGraphBytes(type_ reflect.Type) int64 { + if type_.Kind() != reflect.Struct { + return 0 + } + bytes := int64(type_.Size()) + if bytes == 0 { + return 1 + } + return bytes +} + +func reserveStructGraph(ctx *ReadContext, type_ reflect.Type) bool { + bytes := structGraphBytes(type_) + if bytes == 0 { + return true + } + return ctx.ReserveGraphMemory(bytes) +} + // IsXlang returns whether cross-language serialization mode is enabled func (c *ReadContext) IsXlang() bool { return c.xlang @@ -79,11 +98,13 @@ func (c *ReadContext) IsXlang() bool { // NewReadContext creates a new read context func NewReadContext(trackRef bool) *ReadContext { return &ReadContext{ - buffer: NewByteBuffer(nil), - refReader: NewRefReader(trackRef), - trackRef: trackRef, - maxDepth: 128, // Default maximum nesting depth - maxContainerMemoryBytes: -1, + buffer: NewByteBuffer(nil), + refReader: NewRefReader(trackRef), + trackRef: trackRef, + maxDepth: 128, // Default maximum nesting depth + maxGraphMemoryBytes: -1, + graphMemoryLimitBytes: MaxInt64, + remainingGraphMemoryBytes: MaxInt64, } } @@ -93,7 +114,7 @@ func (c *ReadContext) Reset() { c.outOfBandBuffers = nil c.outOfBandIndex = 0 c.err = Error{} // Clear error state - // Container budget state is overwritten by each root read before deserialization. + // Graph budget state is overwritten by each root read before deserialization. // Avoid extra reset stores on the successful root hot path. if c.refResolver != nil { c.refResolver.resetRead() @@ -103,86 +124,86 @@ func (c *ReadContext) Reset() { } } -func (c *ReadContext) initContainerMemoryBudget(rootInputBytes int, unknownLengthInput bool) { - limit := c.maxContainerMemoryBytes +func (c *ReadContext) initGraphMemoryBudget(rootInputBytes int, unknownLengthInput bool) { + limit := c.maxGraphMemoryBytes if limit <= 0 { if unknownLengthInput { limit = streamRootBudgetBytes } else { if rootInputBytes < 0 { - c.setContainerMemoryError("root input size must be non-negative: %d", rootInputBytes) + c.setGraphMemoryError("root input size must be non-negative: %d", rootInputBytes) return } if int64(rootInputBytes) > (MaxInt64-knownRootBudgetSlackBytes)/knownRootBudgetMultiplier { - c.setContainerMemoryError("root input size %d overflows automatic container memory budget", rootInputBytes) + c.setGraphMemoryError("root input size %d overflows automatic graph memory budget", rootInputBytes) return } limit = int64(rootInputBytes)*knownRootBudgetMultiplier + knownRootBudgetSlackBytes } } - c.containerMemoryLimitBytes = limit - c.remainingContainerMemoryBytes = limit + c.graphMemoryLimitBytes = limit + c.remainingGraphMemoryBytes = limit } -// ReserveCountedContainerMemory reserves length * elementBytes estimated container bytes. -func (c *ReadContext) ReserveCountedContainerMemory(length int, elemBytes int64) bool { +// ReserveCountedGraphMemory reserves length * elementBytes estimated graph bytes. +func (c *ReadContext) ReserveCountedGraphMemory(length int, elemBytes int64) bool { if length < 0 { - c.setContainerMemoryError("negative container length: %d", length) + c.setGraphMemoryError("negative graph element count: %d", length) return false } if elemBytes < 0 { - c.setContainerMemoryError("negative container element size: %d", elemBytes) + c.setGraphMemoryError("negative graph element size: %d", elemBytes) return false } if length == 0 { return true } - return c.reserveCountedContainerMemory(length, elemBytes, maxContainerCount(elemBytes)) + return c.reserveCountedGraphMemory(length, elemBytes, maxGraphCount(elemBytes)) } -func (c *ReadContext) reserveCountedContainerMemory(length int, elemBytes int64, maxLength int64) bool { +func (c *ReadContext) reserveCountedGraphMemory(length int, elemBytes int64, maxLength int64) bool { if length == 0 { return true } if int64(length) > maxLength { - c.setContainerMemoryError("container memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) + c.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) return false } bytes := int64(length) * elemBytes - remaining := c.remainingContainerMemoryBytes + remaining := c.remainingGraphMemoryBytes if bytes > remaining { - c.setContainerMemoryExceeded(bytes, remaining) + c.setGraphMemoryExceeded(bytes, remaining) return false } - c.remainingContainerMemoryBytes = remaining - bytes + c.remainingGraphMemoryBytes = remaining - bytes return true } -// ReserveContainerMemory reserves raw estimated container-owned bytes. -func (c *ReadContext) ReserveContainerMemory(bytes int64) bool { +// ReserveGraphMemory reserves raw estimated graph-owner bytes. +func (c *ReadContext) ReserveGraphMemory(bytes int64) bool { if bytes < 0 { - c.setContainerMemoryError("estimated container memory must be non-negative, got %d bytes", bytes) + c.setGraphMemoryError("estimated graph memory must be non-negative, got %d bytes", bytes) return false } - remaining := c.remainingContainerMemoryBytes + remaining := c.remainingGraphMemoryBytes if bytes > remaining { - c.setContainerMemoryExceeded(bytes, remaining) + c.setGraphMemoryExceeded(bytes, remaining) return false } - c.remainingContainerMemoryBytes = remaining - bytes + c.remainingGraphMemoryBytes = remaining - bytes return true } //go:noinline -func (c *ReadContext) setContainerMemoryError(format string, args ...any) { +func (c *ReadContext) setGraphMemoryError(format string, args ...any) { c.SetError(DeserializationErrorf(format, args...)) } //go:noinline -func (c *ReadContext) setContainerMemoryExceeded(bytes int64, remaining int64) { +func (c *ReadContext) setGraphMemoryExceeded(bytes int64, remaining int64) { c.SetError(DeserializationErrorf( - "estimated container memory request %d bytes exceeds maxContainerMemoryBytes remaining budget %d bytes out of effective limit %d bytes", - bytes, remaining, c.containerMemoryLimitBytes)) + "estimated graph memory request %d bytes exceeds maxGraphMemoryBytes remaining budget %d bytes out of effective limit %d bytes", + bytes, remaining, c.graphMemoryLimitBytes)) } // SetData sets new input data (for buffer reuse) @@ -656,7 +677,7 @@ func (c *ReadContext) readStringSliceData() []string { if c.HasError() { return nil } - if !c.reserveCountedContainerMemory(length, stringElementBytes, stringMaxLength) { + if !c.reserveCountedGraphMemory(length, stringElementBytes, stringMaxLength) { return nil } if length == 0 { @@ -940,9 +961,15 @@ func (c *ReadContext) ReadValue(value reflect.Value, refMode RefMode, readType b } else if isNamedStruct { // For named struct types, create a pointer to support circular references // Create *A instead of A + if !reserveStructGraph(c, actualType) { + return + } newValue = reflect.New(actualType) valueToSet = newValue } else { + if !reserveStructGraph(c, actualType) { + return + } newValue = reflect.New(actualType).Elem() valueToSet = newValue } @@ -1075,6 +1102,9 @@ func (c *ReadContext) ReadStruct(value reflect.Value) { var readTarget reflect.Value if isPtr { if value.IsNil() { + if !reserveStructGraph(c, structType) { + return + } value.Set(reflect.New(structType)) } readTarget = value.Elem() diff --git a/go/fory/set.go b/go/fory/set.go index b3899c3ef8..437e679938 100644 --- a/go/fory/set.go +++ b/go/fory/set.go @@ -322,10 +322,10 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { valueBytes := int64(type_.Elem().Size()) elemBytes := keyBytes + valueBytes if elemBytes < keyBytes { - ctx.setContainerMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + ctx.setGraphMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) return } - if !ctx.ReserveCountedContainerMemory(length, elemBytes) { + if !ctx.ReserveCountedGraphMemory(length, elemBytes) { return } // Initialize empty set if length is 0 @@ -370,10 +370,10 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { valueBytes := int64(type_.Elem().Size()) elemBytes := keyBytes + valueBytes if elemBytes < keyBytes { - ctx.setContainerMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + ctx.setGraphMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) return } - if !ctx.ReserveCountedContainerMemory(length, elemBytes) { + if !ctx.ReserveCountedGraphMemory(length, elemBytes) { return } diff --git a/go/fory/slice.go b/go/fory/slice.go index ad1bc38fc6..7975feb432 100644 --- a/go/fory/slice.go +++ b/go/fory/slice.go @@ -152,7 +152,7 @@ func newSliceSerializer(type_ reflect.Type, elemSerializer Serializer, xlang boo elemSerializer: elemSerializer, referencable: isRefType(elem, xlang), elemBytes: elemBytes, - maxLength: maxContainerCount(elemBytes), + maxLength: maxGraphCount(elemBytes), }, nil } @@ -319,7 +319,7 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } return } - if !isArrayType && !ctx.reserveCountedContainerMemory(length, s.elemBytes, s.maxLength) { + if !isArrayType && !ctx.reserveCountedGraphMemory(length, s.elemBytes, s.maxLength) { return } diff --git a/go/fory/slice_dyn.go b/go/fory/slice_dyn.go index 11e3eb05c4..b65721f9f8 100644 --- a/go/fory/slice_dyn.go +++ b/go/fory/slice_dyn.go @@ -41,11 +41,11 @@ type sliceDynSerializer struct { func newSliceDynSerializer(elemType reflect.Type) (*sliceDynSerializer, error) { // Nil element type is allowed for fully dynamic slices (e.g., []any) if elemType == nil { - elemBytes := containerSizeOf[any]() + elemBytes := graphSizeOf[any]() return &sliceDynSerializer{ isInterfaceElem: true, elemBytes: elemBytes, - maxLength: maxContainerCount(elemBytes), + maxLength: maxGraphCount(elemBytes), }, nil } // Validate element type is interface or pointer to interface @@ -61,7 +61,7 @@ func newSliceDynSerializer(elemType reflect.Type) (*sliceDynSerializer, error) { isInterfaceElem: isInterface, isPointerElem: isPointerToInterface, elemBytes: elemBytes, - maxLength: maxContainerCount(elemBytes), + maxLength: maxGraphCount(elemBytes), }, nil } @@ -287,7 +287,7 @@ func (s *sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, exp value.Set(reflect.MakeSlice(sliceType, 0, 0)) return } - if !allocatedByCaller && !ctx.reserveCountedContainerMemory(length, s.elemBytes, s.maxLength) { + if !allocatedByCaller && !ctx.reserveCountedGraphMemory(length, s.elemBytes, s.maxLength) { return } diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index b809467c01..3e2fc5e2d8 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -652,7 +652,7 @@ func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { return } ptr := (*[]string)(value.Addr().UnsafePointer()) - if !ctx.reserveCountedContainerMemory(length, stringElementBytes, stringMaxLength) { + if !ctx.reserveCountedGraphMemory(length, stringElementBytes, stringMaxLength) { return } if length == 0 { diff --git a/go/fory/slice_primitive_list.go b/go/fory/slice_primitive_list.go index f76f0b66b4..efda751dde 100644 --- a/go/fory/slice_primitive_list.go +++ b/go/fory/slice_primitive_list.go @@ -35,7 +35,7 @@ func newPrimitiveList(type_ reflect.Type, elemTypeID TypeId, elemType reflect.Ty type_: type_, elemTypeID: elemTypeID, elemBytes: elemBytes, - maxLength: maxContainerCount(elemBytes), + maxLength: maxGraphCount(elemBytes), } } @@ -179,7 +179,7 @@ func (s primitiveListSerializer) ReadData(ctx *ReadContext, value reflect.Value) if ctx.HasError() { return } - if !ctx.reserveCountedContainerMemory(length, s.elemBytes, s.maxLength) { + if !ctx.reserveCountedGraphMemory(length, s.elemBytes, s.maxLength) { return } if length == 0 { @@ -243,7 +243,7 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val } if length == 0 { if value.Kind() == reflect.Slice { - if !ctx.reserveCountedContainerMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { + if !ctx.reserveCountedGraphMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { return } value.Set(reflect.MakeSlice(value.Type(), 0, 0)) @@ -284,7 +284,7 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val return } if value.Kind() == reflect.Slice { - if !ctx.reserveCountedContainerMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { + if !ctx.reserveCountedGraphMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { return } temp := reflect.New(value.Type()).Elem() diff --git a/go/fory/stream.go b/go/fory/stream.go index 45111695e5..40cb642ae8 100644 --- a/go/fory/stream.go +++ b/go/fory/stream.go @@ -96,7 +96,7 @@ func (is *InputStream) Shrink() { func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = is.buffer - f.readCtx.initContainerMemoryBudget(0, true) + f.readCtx.initGraphMemoryBudget(0, true) if f.readCtx.HasError() { err := f.readCtx.TakeError() f.readCtx.buffer = origBuffer @@ -114,6 +114,9 @@ func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { } target := reflect.ValueOf(v).Elem() + if err := f.reserveRootGraphOwner(target); err != nil { + return err + } f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { return f.readCtx.TakeError() @@ -130,7 +133,7 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { defer f.resetReadState() // Always reset to enforce stateless semantics. f.readCtx.buffer.ResetWithReader(r, 0) - f.readCtx.initContainerMemoryBudget(0, true) + f.readCtx.initGraphMemoryBudget(0, true) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -141,6 +144,9 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { } target := reflect.ValueOf(v).Elem() + if err := f.reserveRootGraphOwner(target); err != nil { + return err + } f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { return f.readCtx.TakeError() diff --git a/go/fory/tests/structs_fory_gen.go b/go/fory/tests/structs_fory_gen.go index a3dae17b33..945f3f5f58 100644 --- a/go/fory/tests/structs_fory_gen.go +++ b/go/fory/tests/structs_fory_gen.go @@ -1,6 +1,6 @@ // Code generated by forygen. DO NOT EDIT. // source: structs.go -// generated at: 2026-06-26T15:00:42+08:00 +// generated at: 2026-06-30T22:09:02+08:00 package fory @@ -190,7 +190,7 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { + if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -221,7 +221,7 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { + if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -260,7 +260,13 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadData(ctx *fory.ReadContext, val var v *DynamicSliceDemo if value.Kind() == reflect.Ptr { if value.IsNil() { - // For pointer types, allocate using value.Type().Elem() + graphBytes := int64(unsafe.Sizeof(DynamicSliceDemo{})) + if graphBytes == 0 { + graphBytes = 1 + } + if !ctx.ReserveGraphMemory(graphBytes) { + return + } value.Set(reflect.New(value.Type().Elem())) } v = value.Interface().(*DynamicSliceDemo) @@ -672,7 +678,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(mapLen, int64(unsafe.Sizeof(*new(int)))+int64(unsafe.Sizeof(*new(int)))) { + if !ctx.ReserveCountedGraphMemory(mapLen, int64(unsafe.Sizeof(*new(int)))+int64(unsafe.Sizeof(*new(int)))) { return ctx.TakeError() } if mapLen == 0 { @@ -722,7 +728,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(mapLen, int64(unsafe.Sizeof(*new(int)))+int64(unsafe.Sizeof(*new(int)))) { + if !ctx.ReserveCountedGraphMemory(mapLen, int64(unsafe.Sizeof(*new(int)))+int64(unsafe.Sizeof(*new(int)))) { return ctx.TakeError() } if mapLen == 0 { @@ -774,7 +780,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(int)))) { + if !ctx.ReserveCountedGraphMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(int)))) { return ctx.TakeError() } if mapLen == 0 { @@ -824,7 +830,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(int)))) { + if !ctx.ReserveCountedGraphMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(int)))) { return ctx.TakeError() } if mapLen == 0 { @@ -876,7 +882,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(string)))) { + if !ctx.ReserveCountedGraphMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(string)))) { return ctx.TakeError() } if mapLen == 0 { @@ -926,7 +932,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(string)))) { + if !ctx.ReserveCountedGraphMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(string)))) { return ctx.TakeError() } if mapLen == 0 { @@ -984,7 +990,13 @@ func (g *MapDemo_ForyGenSerializer) ReadData(ctx *fory.ReadContext, value reflec var v *MapDemo if value.Kind() == reflect.Ptr { if value.IsNil() { - // For pointer types, allocate using value.Type().Elem() + graphBytes := int64(unsafe.Sizeof(MapDemo{})) + if graphBytes == 0 { + graphBytes = 1 + } + if !ctx.ReserveGraphMemory(graphBytes) { + return + } value.Set(reflect.New(value.Type().Elem())) } v = value.Interface().(*MapDemo) @@ -1287,7 +1299,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { + if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1329,7 +1341,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { + if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1373,7 +1385,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { + if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1415,7 +1427,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { + if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1459,7 +1471,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { + if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1501,7 +1513,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { + if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1545,7 +1557,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { + if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1595,7 +1607,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedContainerMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { + if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { return ctx.TakeError() } if sliceLen == 0 { @@ -1653,7 +1665,13 @@ func (g *SliceDemo_ForyGenSerializer) ReadData(ctx *fory.ReadContext, value refl var v *SliceDemo if value.Kind() == reflect.Ptr { if value.IsNil() { - // For pointer types, allocate using value.Type().Elem() + graphBytes := int64(unsafe.Sizeof(SliceDemo{})) + if graphBytes == 0 { + graphBytes = 1 + } + if !ctx.ReserveGraphMemory(graphBytes) { + return + } value.Set(reflect.New(value.Type().Elem())) } v = value.Interface().(*SliceDemo) @@ -1830,7 +1848,13 @@ func (g *ValidationDemo_ForyGenSerializer) ReadData(ctx *fory.ReadContext, value var v *ValidationDemo if value.Kind() == reflect.Ptr { if value.IsNil() { - // For pointer types, allocate using value.Type().Elem() + graphBytes := int64(unsafe.Sizeof(ValidationDemo{})) + if graphBytes == 0 { + graphBytes = 1 + } + if !ctx.ReserveGraphMemory(graphBytes) { + return + } value.Set(reflect.New(value.Type().Elem())) } v = value.Interface().(*ValidationDemo) diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java index 4d29797710..54adf66696 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java +++ b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java @@ -247,6 +247,7 @@ private void writeSchemaConsistentRead() { .append(struct.typeName) .append(" readSchemaConsistent(ReadContext readContext) {\n"); builder.append(" MemoryBuffer buffer = readContext.getBuffer();\n"); + builder.append(" reserveObjectGraphMemory(readContext);\n"); builder.append(" if (typeResolver.checkClassVersion()) {\n"); builder.append(" checkClassVersion(buffer.readInt32(), classVersionHash);\n"); builder.append(" }\n"); @@ -803,6 +804,7 @@ private void writeCompatibleRead() { builder.append(" if (sameSchemaCompatible) {\n"); builder.append(" return readSchemaConsistent(readContext);\n"); builder.append(" }\n"); + builder.append(" reserveObjectGraphMemory(readContext);\n"); if (struct.record) { for (SourceField field : struct.fields) { builder diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java index b467361ac0..2896d8b3c4 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java @@ -793,6 +793,7 @@ private Expression getWriterPos(Expression writerPos, long acc) { public Expression buildDecodeExpression() { Reference buffer = new Reference(BUFFER_NAME, bufferTypeRef, false); ListExpression expressions = new ListExpression(); + expressions.add(new Expression.Block("reserveObjectGraphMemory(" + READ_CONTEXT_NAME + ");")); if (typeResolver.checkClassVersion()) { expressions.add(checkClassVersion(buffer)); } diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java index 4a43269771..2feeea6daf 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java @@ -160,6 +160,7 @@ private String genObjectCompatibleRead() { ? "((" + ctx.type(beanClass) + ") " + beanCode.value() + ")" : beanCode.value().toString(); StringBuilder code = new StringBuilder(); + code.append("reserveObjectGraphMemory(").append(READ_CONTEXT_NAME).append(");\n"); if (StringUtils.isNotBlank(beanCode.code())) { code.append(beanCode.code()).append('\n'); } @@ -189,6 +190,7 @@ private String genObjectCompatibleRead() { private String genRecordCompatibleRead() { RecordComponent[] components = RecordUtils.getRecordComponents(beanClass); StringBuilder code = new StringBuilder(); + code.append("reserveObjectGraphMemory(").append(READ_CONTEXT_NAME).append(");\n"); for (int i = 0; i < components.length; i++) { Class componentType = components[i].getType(); code.append(recordLocalType(componentType)) diff --git a/java/fory-core/src/main/java/org/apache/fory/config/Config.java b/java/fory-core/src/main/java/org/apache/fory/config/Config.java index b96e4aa83d..323f4f7bd4 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/Config.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/Config.java @@ -68,7 +68,7 @@ public class Config implements Serializable { private final int maxTypeMetaBytes; private final int maxSchemaVersionsPerType; private final int maxAverageSchemaVersionsPerType; - private final long maxContainerMemoryBytes; + private final long maxGraphMemoryBytes; private final float mapRefLoadFactor; private final boolean forVirtualThread; @@ -115,7 +115,7 @@ public Config(ForyBuilder builder) { maxTypeMetaBytes = builder.maxTypeMetaBytes; maxSchemaVersionsPerType = builder.maxSchemaVersionsPerType; maxAverageSchemaVersionsPerType = builder.maxAverageSchemaVersionsPerType; - maxContainerMemoryBytes = builder.maxContainerMemoryBytes; + maxGraphMemoryBytes = builder.maxGraphMemoryBytes; mapRefLoadFactor = builder.mapRefLoadFactor; forVirtualThread = builder.forVirtualThread; } @@ -322,9 +322,9 @@ public int maxAverageSchemaVersionsPerType() { return maxAverageSchemaVersionsPerType; } - /** Returns the root-operation estimated container memory limit in bytes, or -1 for auto. */ - public long maxContainerMemoryBytes() { - return maxContainerMemoryBytes; + /** Returns the root-operation estimated graph memory limit in bytes, or -1 for auto. */ + public long maxGraphMemoryBytes() { + return maxGraphMemoryBytes; } /** Returns loadFactor of MacRef's writtenObjects. */ @@ -375,7 +375,7 @@ public boolean equals(Object o) { && maxTypeMetaBytes == config.maxTypeMetaBytes && maxSchemaVersionsPerType == config.maxSchemaVersionsPerType && maxAverageSchemaVersionsPerType == config.maxAverageSchemaVersionsPerType - && maxContainerMemoryBytes == config.maxContainerMemoryBytes + && maxGraphMemoryBytes == config.maxGraphMemoryBytes && Objects.equals(defaultJDKStreamSerializerType, config.defaultJDKStreamSerializerType) && longEncoding == config.longEncoding && forVirtualThread == config.forVirtualThread; @@ -411,7 +411,7 @@ public int hashCode() { maxTypeMetaBytes, maxSchemaVersionsPerType, maxAverageSchemaVersionsPerType, - maxContainerMemoryBytes, + maxGraphMemoryBytes, metaShareEnabled, scopedMetaShareEnabled, metaCompressor, diff --git a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java index 93d4943940..b5975502ee 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java @@ -103,7 +103,7 @@ public final class ForyBuilder { int maxTypeMetaBytes = 4096; int maxSchemaVersionsPerType = 10; int maxAverageSchemaVersionsPerType = 3; - long maxContainerMemoryBytes = -1; + long maxGraphMemoryBytes = -1; float mapRefLoadFactor = 0.51f; boolean forVirtualThread = false; TypeChecker typeChecker; @@ -573,18 +573,18 @@ public ForyBuilder withMaxAverageSchemaVersionsPerType(int maxAverageSchemaVersi } /** - * Sets the maximum estimated container-owned memory accepted during one root deserialization. + * Sets the maximum estimated graph memory accepted during one root deserialization. * *

The default is {@code -1}, which derives an automatic per-root budget from the input shape. * Positive values are explicit byte limits. Other values are invalid. */ - public ForyBuilder withMaxContainerMemoryBytes(long maxContainerMemoryBytes) { + public ForyBuilder withMaxGraphMemoryBytes(long maxGraphMemoryBytes) { Preconditions.checkArgument( - maxContainerMemoryBytes == -1 || maxContainerMemoryBytes > 0, - "maxContainerMemoryBytes must be positive or -1 for auto but got %s", - maxContainerMemoryBytes); - this.maxContainerMemoryBytes = maxContainerMemoryBytes; - recordAction(b -> b.withMaxContainerMemoryBytes(maxContainerMemoryBytes)); + maxGraphMemoryBytes == -1 || maxGraphMemoryBytes > 0, + "maxGraphMemoryBytes must be positive or -1 for auto but got %s", + maxGraphMemoryBytes); + this.maxGraphMemoryBytes = maxGraphMemoryBytes; + recordAction(b -> b.withMaxGraphMemoryBytes(maxGraphMemoryBytes)); return this; } diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index 53db4131a4..dcd39f1347 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -66,7 +66,7 @@ public final class ReadContext { private final boolean compressInt; private final Int64Encoding longEncoding; private final int maxDepth; - private final long maxContainerMemoryBytes; + private final long maxGraphMemoryBytes; private final boolean scopedMetaShareEnabled; private final boolean forVirtualThread; private final IdentityHashMap contextObjects = new IdentityHashMap<>(); @@ -75,8 +75,8 @@ public final class ReadContext { private MetaReadContext metaReadContext; private boolean peerOutOfBandEnabled; private int depth; - private long containerMemoryLimitBytes; - private long remainingContainerMemoryBytes; + private long graphMemoryLimitBytes; + private long remainingGraphMemoryBytes; /** * Creates read-side runtime state for one {@code Fory} instance. @@ -102,7 +102,7 @@ public ReadContext( compressInt = config.compressInt(); longEncoding = config.longEncoding(); maxDepth = config.maxDepth(); - maxContainerMemoryBytes = config.maxContainerMemoryBytes(); + maxGraphMemoryBytes = config.maxGraphMemoryBytes(); forVirtualThread = config.forVirtualThread(); scopedMetaShareEnabled = config.isScopedMetaShareEnabled(); if (scopedMetaShareEnabled) { @@ -123,11 +123,11 @@ public void prepare( this.buffer = buffer; this.peerOutOfBandEnabled = peerOutOfBandEnabled; this.outOfBandBuffers = outOfBandBuffers == null ? null : outOfBandBuffers.iterator(); - initContainerMemoryBudget(rootInputBytes, unknownLengthInput); + initGraphMemoryBudget(rootInputBytes, unknownLengthInput); } - private void initContainerMemoryBudget(int rootInputBytes, boolean unknownLengthInput) { - long limit = maxContainerMemoryBytes; + private void initGraphMemoryBudget(int rootInputBytes, boolean unknownLengthInput) { + long limit = maxGraphMemoryBytes; if (limit <= 0) { if (unknownLengthInput) { limit = STREAM_ROOT_BUDGET_BYTES; @@ -139,8 +139,8 @@ private void initContainerMemoryBudget(int rootInputBytes, boolean unknownLength limit = rootInputBytes * KNOWN_ROOT_BUDGET_MULTIPLIER + KNOWN_ROOT_BUDGET_SLACK_BYTES; } } - containerMemoryLimitBytes = limit; - remainingContainerMemoryBytes = limit; + graphMemoryLimitBytes = limit; + remainingGraphMemoryBytes = limit; } /** @@ -336,8 +336,8 @@ public void reset() { outOfBandBuffers = null; peerOutOfBandEnabled = false; depth = 0; - containerMemoryLimitBytes = 0; - remainingContainerMemoryBytes = 0; + graphMemoryLimitBytes = 0; + remainingGraphMemoryBytes = 0; } /** Returns the immutable runtime configuration for this context. */ @@ -345,31 +345,31 @@ public Config getConfig() { return config; } - public void reserveContainerMemory(long bytes) { + public void reserveGraphMemory(long bytes) { if (bytes < 0) { - throwNegativeContainerMemory(bytes); + throwNegativeGraphMemory(bytes); } - long remaining = remainingContainerMemoryBytes; + long remaining = remainingGraphMemoryBytes; if (bytes > remaining) { - throwContainerMemoryExceeded(bytes, remaining); + throwGraphMemoryExceeded(bytes, remaining); } - remainingContainerMemoryBytes = remaining - bytes; + remainingGraphMemoryBytes = remaining - bytes; } - private void throwNegativeContainerMemory(long bytes) { + private void throwNegativeGraphMemory(long bytes) { throw new InsecureException( - "Estimated container memory must be non-negative, but got " + bytes + " bytes."); + "Estimated graph memory must be non-negative, but got " + bytes + " bytes."); } - private void throwContainerMemoryExceeded(long bytes, long remaining) { + private void throwGraphMemoryExceeded(long bytes, long remaining) { throw new InsecureException( - "Estimated container memory request " + "Estimated graph memory request " + bytes - + " bytes exceeds maxContainerMemoryBytes remaining budget " + + " bytes exceeds maxGraphMemoryBytes remaining budget " + remaining + " bytes out of effective limit " - + containerMemoryLimitBytes - + " bytes. If the data is trusted, increase ForyBuilder#withMaxContainerMemoryBytes."); + + graphMemoryLimitBytes + + " bytes. If the data is trusted, increase ForyBuilder#withMaxGraphMemoryBytes."); } /** Returns the generics stack shared by the owning runtime. */ diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java index 5312540d41..49d91ab657 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java @@ -69,6 +69,8 @@ public abstract class AbstractObjectSerializer extends Serializer { private static final Logger LOG = LoggerFactory.getLogger(AbstractObjectSerializer.class); private static final Object SELF_REFERENCE = new Object(); + private static final int OBJECT_SELF_BYTES = 1; + private static final int REFERENCE_BYTES = 4; // Constructor-bound objects reserve a ref id before constructor arguments are read, but the // object cannot be referenced semantically until the constructor returns. Generated constructor // serializers call the tracker before reading ref-tracking constructor-phase fields so nested @@ -79,6 +81,7 @@ public abstract class AbstractObjectSerializer extends Serializer { protected final TypeResolver typeResolver; protected final boolean isRecord; protected final ObjectInstantiator objectInstantiator; + private final long objectGraphMemoryBytes; private SerializationFieldInfo[] fieldInfos; private RecordInfo copyRecordInfo; @@ -88,6 +91,7 @@ protected AbstractObjectSerializer() { this.typeResolver = null; this.isRecord = false; this.objectInstantiator = null; + this.objectGraphMemoryBytes = 0; } public AbstractObjectSerializer(TypeResolver typeResolver, Class type) { @@ -101,6 +105,41 @@ public AbstractObjectSerializer( this.typeResolver = typeResolver; this.isRecord = RecordUtils.isRecord(type); this.objectInstantiator = objectInstantiator; + this.objectGraphMemoryBytes = computeObjectGraphMemoryBytes(type); + } + + protected final void reserveObjectGraphMemory(ReadContext readContext) { + readContext.reserveGraphMemory(objectGraphMemoryBytes); + } + + static long computeObjectGraphMemoryBytes(Class type) { + // One byte is a stable nonzero self cost, not an attempt to model JVM object headers. + long bytes = OBJECT_SELF_BYTES; + for (Field field : ReflectionUtils.getFields(type, true)) { + if (!Modifier.isStatic(field.getModifiers())) { + bytes += fieldGraphMemoryBytes(field.getType()); + } + } + return bytes; + } + + private static int fieldGraphMemoryBytes(Class fieldType) { + if (!fieldType.isPrimitive()) { + return REFERENCE_BYTES; + } + if (fieldType == boolean.class || fieldType == byte.class) { + return 1; + } + if (fieldType == char.class || fieldType == short.class) { + return 2; + } + if (fieldType == int.class || fieldType == float.class) { + return 4; + } + if (fieldType == long.class || fieldType == double.class) { + return 8; + } + return 0; } static void writeField( diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java index 83c2758ea7..4be9ce969f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java @@ -46,6 +46,7 @@ */ public final class ArraySerializers { private static final int REFERENCE_BYTES = 4; + private static final int OBJECT_ARRAY_BYTES = 1; private ArraySerializers() {} @@ -61,7 +62,7 @@ private static int readObjectArraySize(ReadContext readContext) { if (numElements < 0) { throwInvalidObjectArraySize(numElements); } - readContext.reserveContainerMemory((long) numElements * REFERENCE_BYTES); + readContext.reserveGraphMemory(OBJECT_ARRAY_BYTES + (long) numElements * REFERENCE_BYTES); buffer.checkReadableBytes(numElements); return numElements; } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java index 49ae18fc71..4fe26cb010 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java @@ -61,6 +61,7 @@ final class CompatibleCollectionArrayReader { // This compatible reader may be reached during native-image analysis. Use the settled // reference-slot fallback instead of touching MemoryBuffer from class initialization. + private static final int COLLECTION_BYTES = 1; private static final int REFERENCE_BYTES = 4; static final int READ_LIST_TO_ARRAY = 1; @@ -983,7 +984,7 @@ private static List readNullableListBoxedElements( ReadContext readContext, int numElements, int arrayTypeId, int elementTypeId) { MemoryBuffer buffer = readContext.getBuffer(); int bodyBytes = minReadablePrimitiveListBytes(numElements, elementTypeId, true); - readContext.reserveContainerMemory((long) numElements * REFERENCE_BYTES); + readContext.reserveGraphMemory(COLLECTION_BYTES + (long) numElements * REFERENCE_BYTES); buffer.checkReadableBytes(bodyBytes); ArrayList values = new ArrayList<>(numElements); for (int i = 0; i < numElements; i++) { @@ -1183,7 +1184,7 @@ private static boolean canMaterializePrimitiveListTarget(Class targetType, in private static List materializeBoxedList( ReadContext readContext, Object array, int arrayTypeId) { int size = java.lang.reflect.Array.getLength(array); - readContext.reserveContainerMemory((long) size * REFERENCE_BYTES); + readContext.reserveGraphMemory(COLLECTION_BYTES + (long) size * REFERENCE_BYTES); ArrayList list = new ArrayList<>(size); switch (arrayTypeId) { case Types.BOOL_ARRAY: diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleLayerSerializerBase.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleLayerSerializerBase.java index e7040f16f4..68aeeaf94b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleLayerSerializerBase.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleLayerSerializerBase.java @@ -135,6 +135,7 @@ public Object[] readFieldValues(ReadContext readContext) { @Override public T read(ReadContext readContext) { checkLayerSerializerMeta(); + reserveObjectGraphMemory(readContext); T obj = newBean(); readContext.reference(obj); return readAndSetFields(readContext, obj); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleSerializer.java index b22843e05e..c7c024e451 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleSerializer.java @@ -237,6 +237,7 @@ private T newInstance() { @Override public T read(ReadContext readContext) { + reserveObjectGraphMemory(readContext); if (isRecord) { Object[] fieldValues = new Object[allFields.length]; if (hasCompatibleCollectionArrayRead) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java index b3b54e18be..d9d348966e 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java @@ -62,6 +62,7 @@ public static final class ExceptionSerializer extends Seria private final TypeResolver typeResolver; private final ObjectInstantiator objectInstantiator; private final Constructor messageConstructor; + private final long graphMemoryBytes; private volatile Serializer[] slotsSerializers; private volatile boolean rebuildSlotsSerializersAtRuntime; @@ -74,6 +75,7 @@ public ExceptionSerializer(TypeResolver typeResolver, Class type) { messageConstructor == null && MemoryUtils.JDK_LANG_FIELD_ACCESS ? createThrowableObjectInstantiator(typeResolver, type) : null; + graphMemoryBytes = AbstractObjectSerializer.computeObjectGraphMemoryBytes(type); slotsSerializers = buildSlotsSerializers(typeResolver, type); if (!MemoryUtils.JDK_LANG_FIELD_ACCESS && isJdkThrowable(type) @@ -117,6 +119,7 @@ public T read(ReadContext readContext) { return readAndroidThrowableWithoutDetailMessageField( readContext, stackTrace, slotsSerializers); } + readContext.reserveGraphMemory(graphMemoryBytes); T obj = newThrowableForRead(); readContext.reference(obj); Throwable cause = (Throwable) readContext.readRef(); @@ -157,6 +160,7 @@ private T readAndroidThrowableWithoutDetailMessageField( + " requires JDK internal field access. " + jdkFieldAccessMessage()); } + readContext.reserveGraphMemory(graphMemoryBytes); T obj = newThrowableWithMessage(detailMessage); readContext.reference(obj); if (stackTrace != null) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java index e656359d14..9e50544860 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java @@ -208,6 +208,7 @@ private void writeFieldByCodecCategory( @Override public T read(ReadContext readContext) { + reserveObjectGraphMemory(readContext); MemoryBuffer buffer = readContext.getBuffer(); if (isRecord) { Object[] fields = readFields(readContext); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java index 149e872862..4ea7532e61 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java @@ -269,6 +269,7 @@ public void write(WriteContext writeContext, Object value) { @Override public Object read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); + reserveObjectGraphMemory(readContext); Object obj = objectInstantiator.newInstance(); readContext.reference(obj); int numClasses = buffer.readInt16(); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/UnknownClassSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/UnknownClassSerializers.java index 1356d80a35..efd02dec7d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/UnknownClassSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/UnknownClassSerializers.java @@ -58,6 +58,11 @@ private ClassFieldsInfo(FieldGroups fieldGroups, int classVersionHash) { } public static final class UnknownStructSerializer extends Serializer { + private static final int UNKNOWN_STRUCT_SELF_BYTES = 1; + private static final int UNKNOWN_STRUCT_REFERENCE_BYTES = 4; + private static final int UNKNOWN_STRUCT_ENTRY_BYTES = + UNKNOWN_STRUCT_SELF_BYTES + 2 * UNKNOWN_STRUCT_REFERENCE_BYTES; + private static final int NONEXISTENT_META_SHARED_ID_SIZE = computeVarUInt32Size(ClassResolver.NONEXISTENT_META_SHARED_ID); private final Config config; @@ -246,11 +251,15 @@ private ClassFieldsInfo getClassFieldsInfo(TypeDef typeDef) { @Override public Object read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); + ClassFieldsInfo allFieldsInfo = getClassFieldsInfo(typeDef); + readContext.reserveGraphMemory( + UNKNOWN_STRUCT_SELF_BYTES + + 2L * UNKNOWN_STRUCT_REFERENCE_BYTES + + (long) allFieldsInfo.allFields.length * UNKNOWN_STRUCT_ENTRY_BYTES); UnknownClass.UnknownStruct obj = new UnknownClass.UnknownStruct(typeDef); readContext.reference(obj); List entries = new ArrayList<>(); // Protocol order: primitive, nullable primitive, then all non-primitives by field identifier. - ClassFieldsInfo allFieldsInfo = getClassFieldsInfo(typeDef); Generics generics = readContext.getGenerics(); for (SerializationFieldInfo fieldInfo : allFieldsInfo.allFields) { Object fieldValue = readFieldByCodecCategory(readContext, generics, fieldInfo, buffer); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java index 10c39dc49f..2796a9b2bf 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java @@ -46,6 +46,7 @@ */ @SuppressWarnings({"unchecked", "rawtypes"}) public abstract class CollectionLikeSerializer extends Serializer { + private static final int COLLECTION_BYTES = 1; private static final int REFERENCE_BYTES = 4; private MethodHandle constructor; @@ -565,7 +566,7 @@ protected void setNumElements(int numElements) { protected final int readCollectionSize(ReadContext readContext, MemoryBuffer buffer) { int numElements = buffer.readVarUInt32Small7(); checkCollectionSize(numElements); - readContext.reserveContainerMemory((long) numElements * REFERENCE_BYTES); + readContext.reserveGraphMemory(COLLECTION_BYTES + (long) numElements * REFERENCE_BYTES); buffer.checkReadableBytes(numElements); return numElements; } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java index 5fcbffd1e5..69118249a0 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java @@ -930,7 +930,7 @@ public ArrayBlockingQueue newCollection(ReadContext readContext) { setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); checkBoundedQueueCapacity(numElements, capacity); - readContext.reserveContainerMemory((long) (capacity - numElements) * REFERENCE_BYTES); + readContext.reserveGraphMemory((long) (capacity - numElements) * REFERENCE_BYTES); buffer.checkReadableBytes(capacity); ArrayBlockingQueue queue = new ArrayBlockingQueue<>(capacity); readContext.reference(queue); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java index 4c1fc65ff9..eab1479522 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java @@ -58,6 +58,7 @@ @SuppressWarnings({"unchecked", "rawtypes"}) public abstract class MapLikeSerializer extends Serializer { public static final int MAX_CHUNK_SIZE = 255; + private static final int MAP_BYTES = 1; private static final int REFERENCE_BYTES = 4; static final class MapTypeCache { @@ -971,7 +972,7 @@ protected final int readMapSize(ReadContext readContext, MemoryBuffer buffer) { if (numElements > Integer.MAX_VALUE / 2) { throwInvalidMapBodySize(numElements); } - readContext.reserveContainerMemory((long) numElements * 2 * REFERENCE_BYTES); + readContext.reserveGraphMemory(MAP_BYTES + (long) numElements * 2 * REFERENCE_BYTES); buffer.checkReadableBytes(numElements << 1); return numElements; } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java similarity index 69% rename from java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java rename to java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java index 0c5ecf3e0d..f14559c511 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ContainerMemoryBudgetTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java @@ -36,18 +36,19 @@ import org.apache.fory.memory.MemoryBuffer; import org.testng.annotations.Test; -public class ContainerMemoryBudgetTest extends ForyTestBase { +public class GraphMemoryBudgetTest extends ForyTestBase { private static final long KNOWN_ROOT_MULTIPLIER = 8L; private static final long KNOWN_ROOT_SLACK_BYTES = 64L * 1024; private static final long STREAM_ROOT_BYTES = 128L * 1024 * 1024; private static final int REFERENCE_BYTES = 4; + private static final int OBJECT_SELF_BYTES = 1; @Test public void testConfigValidation() { - assertEquals(newFory(-1).getConfig().maxContainerMemoryBytes(), -1); - assertEquals(newFory(123).getConfig().maxContainerMemoryBytes(), 123); - assertThrows(IllegalArgumentException.class, () -> builder().withMaxContainerMemoryBytes(0)); - assertThrows(IllegalArgumentException.class, () -> builder().withMaxContainerMemoryBytes(-2)); + assertEquals(newFory(-1).getConfig().maxGraphMemoryBytes(), -1); + assertEquals(newFory(123).getConfig().maxGraphMemoryBytes(), 123); + assertThrows(IllegalArgumentException.class, () -> builder().withMaxGraphMemoryBytes(0)); + assertThrows(IllegalArgumentException.class, () -> builder().withMaxGraphMemoryBytes(-2)); } @Test @@ -56,8 +57,8 @@ public void testKnownAutoBudget() { ReadContext readContext = prepareContext(fory, 17, false); try { long budget = knownAutoBytes(17); - readContext.reserveContainerMemory(budget); - assertThrows(InsecureException.class, () -> readContext.reserveContainerMemory(1)); + readContext.reserveGraphMemory(budget); + assertThrows(InsecureException.class, () -> readContext.reserveGraphMemory(1)); } finally { readContext.reset(); } @@ -68,8 +69,8 @@ public void testStreamAutoBudget() { Fory fory = newFory(-1); ReadContext readContext = prepareContext(fory, 17, true); try { - readContext.reserveContainerMemory(STREAM_ROOT_BYTES); - assertThrows(InsecureException.class, () -> readContext.reserveContainerMemory(1)); + readContext.reserveGraphMemory(STREAM_ROOT_BYTES); + assertThrows(InsecureException.class, () -> readContext.reserveGraphMemory(1)); } finally { readContext.reset(); } @@ -80,8 +81,8 @@ public void testExplicitBudgetWins() { Fory fory = newFory(7); ReadContext readContext = prepareContext(fory, 1024 * 1024, false); try { - readContext.reserveContainerMemory(7); - assertThrows(InsecureException.class, () -> readContext.reserveContainerMemory(1)); + readContext.reserveGraphMemory(7); + assertThrows(InsecureException.class, () -> readContext.reserveGraphMemory(1)); } finally { readContext.reset(); } @@ -91,7 +92,7 @@ public void testExplicitBudgetWins() { public void testNestedEmptyContainersUseParentStorage() { List value = emptyLists(1); byte[] bytes = newFory(-1).serialize(value); - long required = collectionBytes(1); + long required = collectionBytes(1) + collectionBytes(0); assertThrows(InsecureException.class, () -> newFory(required - 1).deserialize(bytes)); assertEquals(newFory(required).deserialize(bytes), value); @@ -104,7 +105,7 @@ public void testSiblingBudgetIsCumulative() { long firstChildOnly = collectionBytes(2) + collectionBytes(64); assertThrows(InsecureException.class, () -> newFory(firstChildOnly).deserialize(bytes)); - assertEquals(newFory(firstChildOnly + collectionBytes(64)).deserialize(bytes), value); + assertEquals(newFory(collectionBytes(2) + 2L * collectionBytes(64)).deserialize(bytes), value); } @Test @@ -112,7 +113,7 @@ public void testMapBudgetAndOverflow() { Fory fory = newFory(mapBytes(1) - 1); ReadContext readContext = prepareContext(fory, 8, false); try { - assertThrows(InsecureException.class, () -> readContext.reserveContainerMemory(mapBytes(1))); + assertThrows(InsecureException.class, () -> readContext.reserveGraphMemory(mapBytes(1))); } finally { readContext.reset(); } @@ -120,8 +121,8 @@ public void testMapBudgetAndOverflow() { Fory exactFory = newFory(mapBytes(1)); ReadContext exactContext = prepareContext(exactFory, 8, false); try { - exactContext.reserveContainerMemory(mapBytes(1)); - assertThrows(InsecureException.class, () -> exactContext.reserveContainerMemory(1)); + exactContext.reserveGraphMemory(mapBytes(1)); + assertThrows(InsecureException.class, () -> exactContext.reserveGraphMemory(1)); } finally { exactContext.reset(); } @@ -166,10 +167,38 @@ public void testObjectArrayBudget() { } } + @Test + public void testPojoGraphBudget() { + Pojo value = new Pojo(7, 9L, "child string is skipped as a leaf"); + byte[] bytes = newFory(-1).serialize(value); + long required = pojoBytes(); + + assertThrows(InsecureException.class, () -> newFory(required - 1, false).deserialize(bytes)); + assertEquals(newFory(required, false).deserialize(bytes), value); + + assertThrows(InsecureException.class, () -> newFory(required - 1, true).deserialize(bytes)); + assertEquals(newFory(required, true).deserialize(bytes), value); + } + + @Test + public void testNestedEmptyPojoGraphBudget() { + ArrayList value = new ArrayList<>(); + value.add(new EmptyPojo()); + value.add(new EmptyPojo()); + byte[] bytes = newFory(-1).serialize(value); + long required = collectionBytes(2) + 2L * emptyPojoBytes(); + + assertThrows(InsecureException.class, () -> newFory(required - 1).deserialize(bytes)); + List decoded = (List) newFory(required).deserialize(bytes); + assertEquals(decoded.size(), 2); + assertTrue(decoded.get(0) instanceof EmptyPojo); + assertTrue(decoded.get(1) instanceof EmptyPojo); + } + @Test public void testScalarOwnersSkipBudget() { Fory fory = newFory(1); - assertEquals(fory.deserialize(fory.serialize("container budget")), "container budget"); + assertEquals(fory.deserialize(fory.serialize("graph budget")), "graph budget"); byte[] bytes = new byte[] {1, 2, 3}; assertTrue(Arrays.equals((byte[]) fory.deserialize(fory.serialize(bytes)), bytes)); @@ -200,8 +229,12 @@ public void testTruncatedCollectionStillFails() { } } - private static Fory newFory(long maxContainerMemoryBytes) { - return builder().withMaxContainerMemoryBytes(maxContainerMemoryBytes).build(); + private static Fory newFory(long maxGraphMemoryBytes) { + return newFory(maxGraphMemoryBytes, true); + } + + private static Fory newFory(long maxGraphMemoryBytes, boolean codegen) { + return builder().withMaxGraphMemoryBytes(maxGraphMemoryBytes).withCodegen(codegen).build(); } private static ReadContext prepareContext( @@ -213,15 +246,23 @@ private static ReadContext prepareContext( } private static long collectionBytes(int numElements) { - return (long) numElements * REFERENCE_BYTES; + return OBJECT_SELF_BYTES + (long) numElements * REFERENCE_BYTES; } private static long mapBytes(int numElements) { - return (long) numElements * 2 * REFERENCE_BYTES; + return OBJECT_SELF_BYTES + (long) numElements * 2 * REFERENCE_BYTES; } private static long objectArrayBytes(int numElements) { - return (long) numElements * REFERENCE_BYTES; + return OBJECT_SELF_BYTES + (long) numElements * REFERENCE_BYTES; + } + + private static long emptyPojoBytes() { + return OBJECT_SELF_BYTES; + } + + private static long pojoBytes() { + return OBJECT_SELF_BYTES + 4 + 8 + REFERENCE_BYTES; } private static long knownAutoBytes(int inputBytes) { @@ -257,4 +298,36 @@ private static MemoryBuffer objectArraySizeBuffer(int numElements) { private static MemoryBuffer trimBuffer(MemoryBuffer buffer) { return MemoryBuffer.fromByteArray(buffer.getBytes(0, buffer.writerIndex())); } + + public static final class EmptyPojo {} + + public static final class Pojo { + public int intValue; + public long longValue; + public String name; + + public Pojo() {} + + Pojo(int intValue, long longValue, String name) { + this.intValue = intValue; + this.longValue = longValue; + this.name = name; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Pojo)) { + return false; + } + Pojo other = (Pojo) obj; + return intValue == other.intValue + && longValue == other.longValue + && java.util.Objects.equals(name, other.name); + } + + @Override + public int hashCode() { + return java.util.Objects.hash(intValue, longValue, name); + } + } } diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index dab5a3766e..3d2e4b9f2b 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -550,11 +550,11 @@ export class ReadContext { private _depth = 0; private _maxDepth: number; - private readonly maxContainerMemoryBytes: number; - private effectiveContainerMemoryBytes = 0; - private remainingContainerMemoryBytes = 0; - private remoteSchemaVersionsByType: Map | undefined - = undefined; + private readonly maxGraphMemoryBytes: number; + private effectiveGraphMemoryBytes = 0; + private remainingGraphMemoryBytes = 0; + private remoteSchemaVersionsByType: Map | undefined = + undefined; constructor( readonly typeResolver: TypeResolverLike, @@ -564,7 +564,7 @@ export class ReadContext { this.refReader = new RefReader(this.reader); this.metaStringReader = new MetaStringReader(); this._maxDepth = config.maxDepth ?? 50; - this.maxContainerMemoryBytes = config.maxContainerMemoryBytes; + this.maxGraphMemoryBytes = config.maxGraphMemoryBytes; } reset(bytes: Uint8Array) { @@ -573,35 +573,36 @@ export class ReadContext { this.metaStringReader.reset(); this.typeMeta = []; this._depth = 0; - this.effectiveContainerMemoryBytes = this.maxContainerMemoryBytes > 0 - ? this.maxContainerMemoryBytes - : bytes.byteLength * ReadContext.KNOWN_ROOT_BUDGET_MULTIPLIER - + ReadContext.KNOWN_ROOT_BUDGET_SLACK_BYTES; - this.remainingContainerMemoryBytes = this.effectiveContainerMemoryBytes; + this.effectiveGraphMemoryBytes = + this.maxGraphMemoryBytes > 0 + ? this.maxGraphMemoryBytes + : bytes.byteLength * ReadContext.KNOWN_ROOT_BUDGET_MULTIPLIER + + ReadContext.KNOWN_ROOT_BUDGET_SLACK_BYTES; + this.remainingGraphMemoryBytes = this.effectiveGraphMemoryBytes; } - reserveContainerMemory(bytes: number) { + reserveGraphMemory(bytes: number) { if (!Number.isSafeInteger(bytes) || bytes < 0) { - this.throwContainerMemoryOverflow(bytes); + this.throwGraphMemoryOverflow(bytes); } - const remaining = this.remainingContainerMemoryBytes - bytes; + const remaining = this.remainingGraphMemoryBytes - bytes; if (remaining < 0) { - this.throwContainerBudgetExceeded(bytes); + this.throwGraphBudgetExceeded(bytes); } - this.remainingContainerMemoryBytes = remaining; + this.remainingGraphMemoryBytes = remaining; } - private throwContainerMemoryOverflow(bytes: number): never { + private throwGraphMemoryOverflow(bytes: number): never { throw new Error( - `maxContainerMemoryBytes overflow: requested ${bytes} estimated container bytes`, + `maxGraphMemoryBytes overflow: requested ${bytes} estimated graph bytes`, ); } - private throwContainerBudgetExceeded(bytes: number): never { + private throwGraphBudgetExceeded(bytes: number): never { throw new Error( - `maxContainerMemoryBytes exceeded: requested ${bytes} estimated container bytes, ` - + `${this.remainingContainerMemoryBytes} remaining, effective limit ` - + `${this.effectiveContainerMemoryBytes}`, + `maxGraphMemoryBytes exceeded: requested ${bytes} estimated graph bytes, ` + + `${this.remainingGraphMemoryBytes} remaining, effective limit ` + + `${this.effectiveGraphMemoryBytes}`, ); } @@ -613,8 +614,8 @@ export class ReadContext { this._depth++; if (this._depth > this._maxDepth) { throw new Error( - `Deserialization depth limit exceeded: ${this._depth} > ${this._maxDepth}. ` - + "The data may be malicious, or increase maxDepth if needed.", + `Deserialization depth limit exceeded: ${this._depth} > ${this._maxDepth}. ` + + "The data may be malicious, or increase maxDepth if needed.", ); } } @@ -801,14 +802,14 @@ export class ReadContext { expectedTypeName: string, ) { if ( - typeMeta.getTypeId() !== expectedTypeId - || typeMeta.getNs() !== expectedNamespace - || typeMeta.getTypeName() !== expectedTypeName + typeMeta.getTypeId() !== expectedTypeId || + typeMeta.getNs() !== expectedNamespace || + typeMeta.getTypeName() !== expectedTypeName ) { throw new Error( - `TypeMeta mismatch: expected ${expectedNamespace}$${expectedTypeName} ` - + `type ${expectedTypeId}, got ${typeMeta.getNs()}$${typeMeta.getTypeName()} ` - + `type ${typeMeta.getTypeId()}`, + `TypeMeta mismatch: expected ${expectedNamespace}$${expectedTypeName} ` + + `type ${expectedTypeId}, got ${typeMeta.getNs()}$${typeMeta.getTypeName()} ` + + `type ${typeMeta.getTypeId()}`, ); } } @@ -856,8 +857,8 @@ export class ReadContext { } else { const localSerializer = original ?? this.serializerByTypeMeta(typeMeta); if ( - localSerializer === undefined - && !TypeId.structType(typeMeta.getTypeId()) + localSerializer === undefined && + !TypeId.structType(typeMeta.getTypeId()) ) { throw new Error( `can't find serializer for TypeMeta ${typeMeta.getNs()}$${typeMeta.getTypeName()}`, @@ -865,8 +866,8 @@ export class ReadContext { } const typeKey = this.checkRemoteTypeMetaLimit(typeMeta); if ( - localSerializer !== undefined - && TypeId.structType(typeMeta.getTypeId()) + localSerializer !== undefined && + TypeId.structType(typeMeta.getTypeId()) ) { const expectedHash = localHash ?? localSerializer.getHash(); if (expectedHash !== typeMeta.getHash()) { @@ -878,8 +879,8 @@ export class ReadContext { ); } } else if ( - localHash !== undefined - && localHash !== typeMeta.getHash() + localHash !== undefined && + localHash !== typeMeta.getHash() ) { this.ensureCompatibleReadSerializer( typeMeta, @@ -919,33 +920,33 @@ export class ReadContext { : typeMeta.getUserTypeId(); const versionsByType = this.remoteSchemaVersionsByType; const versionsForType = versionsByType?.get(typeKey) ?? 0; - const maxSchemaVersionsPerType - = this.typeResolver.config.maxSchemaVersionsPerType; + const maxSchemaVersionsPerType = + this.typeResolver.config.maxSchemaVersionsPerType; if (versionsForType >= maxSchemaVersionsPerType) { throw new Error( - `Remote schema version limit exceeded for type ${String(typeKey)}: ` - + `${versionsForType} >= ${maxSchemaVersionsPerType}. The data may ` - + "be malicious. If the data is not malicious, please increase " - + "maxSchemaVersionsPerType.", + `Remote schema version limit exceeded for type ${String(typeKey)}: ` + + `${versionsForType} >= ${maxSchemaVersionsPerType}. The data may ` + + "be malicious. If the data is not malicious, please increase " + + "maxSchemaVersionsPerType.", ); } - const acceptedTypeCount - = versionsForType === 0 + const acceptedTypeCount = + versionsForType === 0 ? (versionsByType?.size ?? 0) + 1 : versionsByType!.size; - const maxAverageSchemaVersionsPerType - = this.typeResolver.config.maxAverageSchemaVersionsPerType; + const maxAverageSchemaVersionsPerType = + this.typeResolver.config.maxAverageSchemaVersionsPerType; const globalLimit = Math.max( ReadContext.MIN_REMOTE_TYPE_META_LIMIT, acceptedTypeCount * maxAverageSchemaVersionsPerType, ); if (this.totalAcceptedSchemaVersions >= globalLimit) { throw new Error( - `Remote schema version limit exceeded: ${this.totalAcceptedSchemaVersions} ` - + `metadata versions for ${acceptedTypeCount} accepted remote types ` - + `exceeds the average limit ${maxAverageSchemaVersionsPerType}. ` - + "The data may be malicious. If the data is not malicious, please " - + "increase maxAverageSchemaVersionsPerType.", + `Remote schema version limit exceeded: ${this.totalAcceptedSchemaVersions} ` + + `metadata versions for ${acceptedTypeCount} accepted remote types ` + + `exceeds the average limit ${maxAverageSchemaVersionsPerType}. ` + + "The data may be malicious. If the data is not malicious, please " + + "increase maxAverageSchemaVersionsPerType.", ); } return typeKey; @@ -1044,16 +1045,16 @@ export class ReadContext { return false; } if ( - (remote.trackingRef === true) !== (local.trackingRef === true) - || (remote.nullable === true) !== (local.nullable === true) + (remote.trackingRef === true) !== (local.trackingRef === true) || + (remote.nullable === true) !== (local.nullable === true) ) { return false; } switch (remote.typeId) { case TypeId.MAP: return ( - this.fieldSchemasEqual(remote.options?.key, local.options?.key) - && this.fieldSchemasEqual(remote.options?.value, local.options?.value) + this.fieldSchemasEqual(remote.options?.key, local.options?.key) && + this.fieldSchemasEqual(remote.options?.value, local.options?.value) ); case TypeId.LIST: return this.fieldSchemasEqual( @@ -1084,24 +1085,24 @@ export class ReadContext { return compatible; } if ( - isCompatibleScalarType(fieldInfo.typeId) - && isCompatibleScalarType(fallbackTypeInfo.typeId) - && ((fieldInfo.trackingRef === true) - !== (fallbackTypeInfo.trackingRef === true) - || ((fieldInfo.trackingRef === true - || fallbackTypeInfo.trackingRef === true) - && (fieldInfo.typeId !== fallbackTypeInfo.typeId - || fieldInfo.nullable !== fallbackTypeInfo.nullable))) + isCompatibleScalarType(fieldInfo.typeId) && + isCompatibleScalarType(fallbackTypeInfo.typeId) && + ((fieldInfo.trackingRef === true) !== + (fallbackTypeInfo.trackingRef === true) || + ((fieldInfo.trackingRef === true || + fallbackTypeInfo.trackingRef === true) && + (fieldInfo.typeId !== fallbackTypeInfo.typeId || + fieldInfo.nullable !== fallbackTypeInfo.nullable))) ) { throw new Error( "unsupported compatible scalar tracking-ref schema mismatch", ); } if ( - isCompatibleScalarPair(fieldInfo.typeId, fallbackTypeInfo.typeId) - && fieldInfo.typeId !== fallbackTypeInfo.typeId - && (fieldInfo.trackingRef === true - || fallbackTypeInfo.trackingRef === true) + isCompatibleScalarPair(fieldInfo.typeId, fallbackTypeInfo.typeId) && + fieldInfo.typeId !== fallbackTypeInfo.typeId && + (fieldInfo.trackingRef === true || + fallbackTypeInfo.trackingRef === true) ) { throw new Error( "unsupported compatible scalar tracking-ref schema mismatch", @@ -1117,10 +1118,10 @@ export class ReadContext { throw new Error("unsupported compatible list/array schema mismatch"); } if ( - fieldInfo.typeId !== TypeId.UNKNOWN - && this.canonicalFieldTypeId(fallbackTypeInfo) !== TypeId.UNKNOWN - && this.canonicalTypeId(fieldInfo.typeId) - !== this.canonicalFieldTypeId(fallbackTypeInfo) + fieldInfo.typeId !== TypeId.UNKNOWN && + this.canonicalFieldTypeId(fallbackTypeInfo) !== TypeId.UNKNOWN && + this.canonicalTypeId(fieldInfo.typeId) !== + this.canonicalFieldTypeId(fallbackTypeInfo) ) { throw new Error("unsupported compatible field schema mismatch"); } @@ -1210,31 +1211,31 @@ export class ReadContext { return false; } if ( - this.schemaMatchTypeId(remote.typeId) - !== this.schemaMatchTypeId(this.typeResolver.computeTypeId(local)) + this.schemaMatchTypeId(remote.typeId) !== + this.schemaMatchTypeId(this.typeResolver.computeTypeId(local)) ) { return true; } const remoteTracksRef = remote.trackingRef === true; const localTracksRef = local.trackingRef === true; if ( - remoteTracksRef !== localTracksRef - || ((remoteTracksRef || localTracksRef) - && (remote.nullable === true) !== (local.nullable === true)) + remoteTracksRef !== localTracksRef || + ((remoteTracksRef || localTracksRef) && + (remote.nullable === true) !== (local.nullable === true)) ) { return true; } switch (remote.typeId) { case TypeId.MAP: return ( - local.options?.key === undefined - || local.options?.value === undefined - || this.hasNestedSchemaMismatch( + local.options?.key === undefined || + local.options?.value === undefined || + this.hasNestedSchemaMismatch( remote.options!.key!, local.options.key, false, - ) - || this.hasNestedSchemaMismatch( + ) || + this.hasNestedSchemaMismatch( remote.options!.value!, local.options.value, false, @@ -1242,8 +1243,8 @@ export class ReadContext { ); case TypeId.LIST: return ( - local.options?.inner === undefined - || this.hasNestedSchemaMismatch( + local.options?.inner === undefined || + this.hasNestedSchemaMismatch( remote.options!.inner!, local.options.inner, false, @@ -1251,8 +1252,8 @@ export class ReadContext { ); case TypeId.SET: return ( - local.options?.key === undefined - || this.hasNestedSchemaMismatch( + local.options?.key === undefined || + this.hasNestedSchemaMismatch( remote.options!.key!, local.options.key, false, @@ -1273,19 +1274,19 @@ export class ReadContext { ): TypeInfo | undefined { if (this.isByteSequenceRootPair(remote, local)) { if ( - (remote.nullable === true) !== (local.nullable === true) - || (remote.trackingRef === true) !== (local.trackingRef === true) + (remote.nullable === true) !== (local.nullable === true) || + (remote.trackingRef === true) !== (local.trackingRef === true) ) { return undefined; } return local.clone(); } if ( - this.isListArrayRootPair(remote, local) - && (remote.nullable === true - || local.nullable === true - || remote.trackingRef === true - || local.trackingRef === true) + this.isListArrayRootPair(remote, local) && + (remote.nullable === true || + local.nullable === true || + remote.trackingRef === true || + local.trackingRef === true) ) { return undefined; } @@ -1305,22 +1306,22 @@ export class ReadContext { } const remoteArrayElement = denseArrayElementTypeId(remote.typeId); if ( - remoteArrayElement !== undefined - && local.typeId === TypeId.LIST - && local.options?.inner - && compatibleArrayElementTypeId(local.options.inner.typeId) - === remoteArrayElement + remoteArrayElement !== undefined && + local.typeId === TypeId.LIST && + local.options?.inner && + compatibleArrayElementTypeId(local.options.inner.typeId) === + remoteArrayElement ) { return compatibleArrayToListTypeInfo(remoteArrayElement); } if ( - remote.trackingRef !== true - && local.trackingRef !== true - && !( - remote.typeId === local.typeId - && (remote.nullable === true) === (local.nullable === true) - ) - && isCompatibleScalarPair(remote.typeId, local.typeId) + remote.trackingRef !== true && + local.trackingRef !== true && + !( + remote.typeId === local.typeId && + (remote.nullable === true) === (local.nullable === true) + ) && + isCompatibleScalarPair(remote.typeId, local.typeId) ) { return markCompatibleScalarRead(local.clone(), { remoteTypeId: remote.typeId, @@ -1352,8 +1353,8 @@ export class ReadContext { remote.options!.key!, local.options?.key, false, - ) - || this.hasUnsupportedListArrayMismatch( + ) || + this.hasUnsupportedListArrayMismatch( remote.options!.value!, local.options?.value, false, @@ -1381,10 +1382,10 @@ export class ReadContext { local: TypeInfo, ): boolean { return ( - (remote.typeId === TypeId.LIST - && denseArrayElementTypeId(local.typeId) !== undefined) - || (denseArrayElementTypeId(remote.typeId) !== undefined - && local.typeId === TypeId.LIST) + (remote.typeId === TypeId.LIST && + denseArrayElementTypeId(local.typeId) !== undefined) || + (denseArrayElementTypeId(remote.typeId) !== undefined && + local.typeId === TypeId.LIST) ); } @@ -1393,9 +1394,9 @@ export class ReadContext { local: TypeInfo, ): boolean { return ( - (remote.typeId === TypeId.BINARY - && local.typeId === TypeId.UINT8_ARRAY) - || (remote.typeId === TypeId.UINT8_ARRAY && local.typeId === TypeId.BINARY) + (remote.typeId === TypeId.BINARY && + local.typeId === TypeId.UINT8_ARRAY) || + (remote.typeId === TypeId.UINT8_ARRAY && local.typeId === TypeId.BINARY) ); } diff --git a/javascript/packages/core/lib/fory.ts b/javascript/packages/core/lib/fory.ts index 979216a5b9..0e11fd6c2a 100644 --- a/javascript/packages/core/lib/fory.ts +++ b/javascript/packages/core/lib/fory.ts @@ -38,22 +38,30 @@ const DEFAULT_MAX_TYPE_FIELDS = 512 as const; const DEFAULT_MAX_TYPE_META_BYTES = 4096 as const; const DEFAULT_MAX_SCHEMA_VERSIONS_PER_TYPE = 10 as const; const DEFAULT_MAX_AVERAGE_SCHEMA_VERSIONS_PER_TYPE = 3 as const; -const DEFAULT_MAX_CONTAINER_MEMORY_BYTES = -1 as const; +const DEFAULT_MAX_GRAPH_MEMORY_BYTES = -1 as const; export default class Fory { readonly typeResolver: TypeResolver; readonly anySerializer: Serializer; readonly config: Config; readonly writeContext: WriteContext; readonly readContext: ReadContext; - private readonly rootSerializers = new WeakMap PlatformBuffer>(); + private readonly rootSerializers = new WeakMap< + Serializer, + (data: any) => PlatformBuffer + >(); - private readonly rootDeserializers = new WeakMap any>(); + private readonly rootDeserializers = new WeakMap< + Serializer, + (bytes: Uint8Array) => any + >(); constructor(config?: Partial) { this.config = this.initConfig(config); const maxDepth = this.config.maxDepth ?? DEFAULT_DEPTH_LIMIT; if (!Number.isInteger(maxDepth) || maxDepth < MIN_DEPTH_LIMIT) { - throw new Error(`maxDepth must be an integer >= ${MIN_DEPTH_LIMIT} but got ${maxDepth}`); + throw new Error( + `maxDepth must be an integer >= ${MIN_DEPTH_LIMIT} but got ${maxDepth}`, + ); } this.typeResolver = new TypeResolver(this.config); this.writeContext = new WriteContext(this.typeResolver, this.config); @@ -66,21 +74,30 @@ export default class Fory { private initConfig(config: Partial | undefined) { const maxTypeFields = config?.maxTypeFields ?? DEFAULT_MAX_TYPE_FIELDS; if (!Number.isInteger(maxTypeFields) || maxTypeFields <= 0) { - throw new Error(`maxTypeFields must be a positive integer but got ${maxTypeFields}`); + throw new Error( + `maxTypeFields must be a positive integer but got ${maxTypeFields}`, + ); } - const maxTypeMetaBytes = config?.maxTypeMetaBytes ?? DEFAULT_MAX_TYPE_META_BYTES; + const maxTypeMetaBytes = + config?.maxTypeMetaBytes ?? DEFAULT_MAX_TYPE_META_BYTES; if (!Number.isInteger(maxTypeMetaBytes) || maxTypeMetaBytes <= 0) { - throw new Error(`maxTypeMetaBytes must be a positive integer but got ${maxTypeMetaBytes}`); + throw new Error( + `maxTypeMetaBytes must be a positive integer but got ${maxTypeMetaBytes}`, + ); } const maxSchemaVersionsPerType = config?.maxSchemaVersionsPerType ?? DEFAULT_MAX_SCHEMA_VERSIONS_PER_TYPE; - if (!Number.isInteger(maxSchemaVersionsPerType) || maxSchemaVersionsPerType <= 0) { + if ( + !Number.isInteger(maxSchemaVersionsPerType) || + maxSchemaVersionsPerType <= 0 + ) { throw new Error( `maxSchemaVersionsPerType must be a positive integer but got ${maxSchemaVersionsPerType}`, ); } const maxAverageSchemaVersionsPerType = - config?.maxAverageSchemaVersionsPerType ?? DEFAULT_MAX_AVERAGE_SCHEMA_VERSIONS_PER_TYPE; + config?.maxAverageSchemaVersionsPerType ?? + DEFAULT_MAX_AVERAGE_SCHEMA_VERSIONS_PER_TYPE; if ( !Number.isInteger(maxAverageSchemaVersionsPerType) || maxAverageSchemaVersionsPerType <= 0 @@ -89,21 +106,21 @@ export default class Fory { `maxAverageSchemaVersionsPerType must be a positive integer but got ${maxAverageSchemaVersionsPerType}`, ); } - const maxContainerMemoryBytes - = config?.maxContainerMemoryBytes ?? DEFAULT_MAX_CONTAINER_MEMORY_BYTES; + const maxGraphMemoryBytes = + config?.maxGraphMemoryBytes ?? DEFAULT_MAX_GRAPH_MEMORY_BYTES; if ( - !Number.isSafeInteger(maxContainerMemoryBytes) - || (maxContainerMemoryBytes !== -1 && maxContainerMemoryBytes <= 0) + !Number.isSafeInteger(maxGraphMemoryBytes) || + (maxGraphMemoryBytes !== -1 && maxGraphMemoryBytes <= 0) ) { throw new Error( - `maxContainerMemoryBytes must be -1 or a positive safe integer but got ${maxContainerMemoryBytes}`, + `maxGraphMemoryBytes must be -1 or a positive safe integer but got ${maxGraphMemoryBytes}`, ); } return { ref: Boolean(config?.ref), useSliceString: Boolean(config?.useSliceString), maxDepth: config?.maxDepth, - maxContainerMemoryBytes, + maxGraphMemoryBytes, maxTypeFields, maxTypeMetaBytes, maxSchemaVersionsPerType, @@ -139,8 +156,9 @@ export default class Fory { register(constructor: any, customSerializer?: CustomSerializer) { let serializer: Serializer; if (constructor.prototype?.[ForyTypeInfoSymbol]) { - const typeInfo: TypeInfo = (constructor.prototype[ForyTypeInfoSymbol] as WithForyClsInfo) - .structTypeInfo; + const typeInfo: TypeInfo = ( + constructor.prototype[ForyTypeInfoSymbol] as WithForyClsInfo + ).structTypeInfo; typeInfo.freeze(); serializer = new Gen(this.typeResolver, { creator: constructor, @@ -162,7 +180,10 @@ export default class Fory { }; } - deserialize(bytes: Uint8Array, serializer: Serializer = this.anySerializer): T | null { + deserialize( + bytes: Uint8Array, + serializer: Serializer = this.anySerializer, + ): T | null { this.readContext.reset(bytes); const reader = this.readContext.reader; const bitmap = reader.readUint8(); @@ -173,9 +194,12 @@ export default class Fory { } private throwInvalidRootHeader(bitmap: number): never { - const knownFlags = ConfigFlags.isCrossLanguageFlag | ConfigFlags.isOutOfBandFlag; + const knownFlags = + ConfigFlags.isCrossLanguageFlag | ConfigFlags.isOutOfBandFlag; if ((bitmap & ~knownFlags) !== 0) { - throw new Error(`unsupported root header bitmap 0x${bitmap.toString(16)}`); + throw new Error( + `unsupported root header bitmap 0x${bitmap.toString(16)}`, + ); } if ((bitmap & ConfigFlags.isCrossLanguageFlag) === 0) { throw new Error("support crosslanguage mode only"); diff --git a/javascript/packages/core/lib/gen/collection.ts b/javascript/packages/core/lib/gen/collection.ts index f2c26be573..bca49665b7 100644 --- a/javascript/packages/core/lib/gen/collection.ts +++ b/javascript/packages/core/lib/gen/collection.ts @@ -27,6 +27,7 @@ import { AnyHelper } from "./any"; import type { ReadContext, WriteContext } from "../context"; const REFERENCE_BYTES = 4; +const COLLECTION_BYTES = 1; export type CompatibleCollectionArrayReadAction = { target: "array" | "list"; @@ -66,7 +67,10 @@ export const CollectionFlags = { SAME_TYPE: 0b1000, }; -function compatibleArrayCollectionExpr(elementTypeId: number, len: string): string { +function compatibleArrayCollectionExpr( + elementTypeId: number, + len: string, +): string { switch (elementTypeId) { case TypeId.BOOL: return `new external.BoolArray(${len})`; @@ -173,7 +177,11 @@ class CollectionAnySerializer { trackingRef = current.needToWriteRef(); } if (isSame) { - if (serializer !== null && serializer !== undefined && current !== serializer) { + if ( + serializer !== null && + serializer !== undefined && + current !== serializer + ) { isSame = false; } else { serializer = current; @@ -205,7 +213,8 @@ class CollectionAnySerializer { if (size === 0) { return; } - const { serializer, isSame, includeNone, trackingRef } = this.writeElementsHeader(value); + const { serializer, isSame, includeNone, trackingRef } = + this.writeElementsHeader(value); if (isSame) { serializer!.writeTypeInfo(value); if (trackingRef) { @@ -231,7 +240,8 @@ class CollectionAnySerializer { } else { if (trackingRef) { for (const item of value) { - const serializer = this.writeContext.typeResolver.getSerializerByData(item); + const serializer = + this.writeContext.typeResolver.getSerializerByData(item); serializer?.writeRef(item); } } else if (includeNone) { @@ -239,14 +249,16 @@ class CollectionAnySerializer { if (item === null || item === undefined) { this.writeContext.writer.writeInt8(RefFlags.NullFlag); } else { - const serializer = this.writeContext.typeResolver.getSerializerByData(item); + const serializer = + this.writeContext.typeResolver.getSerializerByData(item); this.writeContext.writer.writeInt8(RefFlags.NotNullValueFlag); serializer!.writeNoRef(item); } } } else { for (const item of value) { - const serializer = this.writeContext.typeResolver.getSerializerByData(item); + const serializer = + this.writeContext.typeResolver.getSerializerByData(item); serializer!.writeNoRef(item); } } @@ -260,7 +272,9 @@ class CollectionAnySerializer { ): any { void fromRef; const len = this.readContext.reader.readVarUint32Small7(); - this.readContext.reserveContainerMemory(len * REFERENCE_BYTES); + this.readContext.reserveGraphMemory( + COLLECTION_BYTES + len * REFERENCE_BYTES, + ); if (len === 0) { return createCollection(len); } @@ -285,7 +299,11 @@ class CollectionAnySerializer { const refId = this.readContext.reader.readVarUInt32(); accessor(result, i, this.readContext.getReadRef(refId)); } else if (refFlag === RefFlags.RefValueFlag) { - accessor(result, i, this.readSerializerWithDepth(serializer!, true)); + accessor( + result, + i, + this.readSerializerWithDepth(serializer!, true), + ); } else { accessor(result, i, null); } @@ -296,7 +314,11 @@ class CollectionAnySerializer { if (flag === RefFlags.NullFlag) { accessor(result, i, null); } else { - accessor(result, i, this.readSerializerWithDepth(serializer!, false)); + accessor( + result, + i, + this.readSerializerWithDepth(serializer!, false), + ); } } } else { @@ -317,13 +339,21 @@ class CollectionAnySerializer { accessor(result, i, null); } else { const itemSerializer = AnyHelper.detectSerializer(this.readContext); - accessor(result, i, this.readSerializerWithDepth(itemSerializer!, false)); + accessor( + result, + i, + this.readSerializerWithDepth(itemSerializer!, false), + ); } } } else { for (let i = 0; i < len; i++) { const itemSerializer = AnyHelper.detectSerializer(this.readContext); - accessor(result, i, this.readSerializerWithDepth(itemSerializer!, false)); + accessor( + result, + i, + this.readSerializerWithDepth(itemSerializer!, false), + ); } } } @@ -339,7 +369,11 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera super(typeInfo, builder, scope); this.typeInfo = typeInfo; const inner = this.genericTypeDescriptin()!; - this.innerGenerator = CodegenRegistry.newGeneratorByTypeInfo(inner, this.builder, this.scope); + this.innerGenerator = CodegenRegistry.newGeneratorByTypeInfo( + inner, + this.builder, + this.scope, + ); } abstract genericTypeDescriptin(): TypeInfo | undefined; @@ -431,7 +465,10 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera `; } - readSpecificType(accessor: (expr: string) => string, refState: string): string { + readSpecificType( + accessor: (expr: string) => string, + refState: string, + ): string { const result = this.scope.uniqueName("result"); const len = this.scope.uniqueName("len"); const flags = this.scope.uniqueName("flags"); @@ -440,18 +477,27 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera const elemSerializer = this.scope.uniqueName("elemSerializer"); const anyHelper = this.builder.getExternal(AnyHelper.name); const readContextName = this.builder.getReadContextName(); - const useDeclaredStructElementReader = TypeId.structType(this.innerGenerator.getTypeId()!); - const compatibleReadAction = getCompatibleCollectionArrayReadAction(this.typeInfo); + const useDeclaredStructElementReader = TypeId.structType( + this.innerGenerator.getTypeId()!, + ); + const compatibleReadAction = getCompatibleCollectionArrayReadAction( + this.typeInfo, + ); const compatibleListToArray = compatibleReadAction?.target === "array"; const newCollection = compatibleListToArray ? compatibleArrayCollectionExpr(compatibleReadAction!.elementTypeId, len) : this.newCollection(len); const reserveMemory = compatibleListToArray - ? `${readContextName}.reserveContainerMemory(${len} * ${compatibleArrayElementBytes(compatibleReadAction!.elementTypeId)});` - : `${readContextName}.reserveContainerMemory(${len} * ${REFERENCE_BYTES});`; + ? `${readContextName}.reserveGraphMemory(${COLLECTION_BYTES} + ${len} * ${compatibleArrayElementBytes(compatibleReadAction!.elementTypeId)});` + : `${readContextName}.reserveGraphMemory(${COLLECTION_BYTES} + ${len} * ${REFERENCE_BYTES});`; const putAccessor = (item: string, index: string) => compatibleListToArray - ? compatibleArrayPutAccessor(compatibleReadAction!.elementTypeId, result, item, index) + ? compatibleArrayPutAccessor( + compatibleReadAction!.elementTypeId, + result, + item, + index, + ) : this.putAccessor(result, item, index); const rejectCompatiblePayload = compatibleListToArray ? ` @@ -468,7 +514,10 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera const innerReader = useDeclaredStructElementReader ? this.innerGenerator.readEmbed() : this.innerGenerator; - const readInnerElement = (assignStmt: (x: any) => string, refState: string) => { + const readInnerElement = ( + assignStmt: (x: any) => string, + refState: string, + ) => { return innerIsLeaf ? this.innerGenerator.read(assignStmt, refState) : innerReader.readWithDepth(assignStmt, refState); diff --git a/javascript/packages/core/lib/gen/ext.ts b/javascript/packages/core/lib/gen/ext.ts index 5d8af562a4..928e848275 100644 --- a/javascript/packages/core/lib/gen/ext.ts +++ b/javascript/packages/core/lib/gen/ext.ts @@ -25,6 +25,9 @@ import { CodegenRegistry } from "./router"; import { BaseSerializerGenerator } from "./serializer"; import { TypeMeta } from "../meta/TypeMeta"; +const OBJECT_BYTES = 1; +const REFERENCE_BYTES = 4; + class ExtSerializerGenerator extends BaseSerializerGenerator { typeInfo: TypeInfo; typeMeta: TypeMeta; @@ -41,6 +44,13 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { this.ownTypeInfoExpr = `${this.serializerExpr}.getTypeInfo()`; } + private objectGraphBytes(): number { + return ( + OBJECT_BYTES + + Object.keys(this.typeInfo.options?.props ?? {}).length * REFERENCE_BYTES + ); + } + write(accessor: string): string { return ` ${this.builder.getOptions("customSerializer")}.write(${this.builder.getWriteContextName()}, ${accessor}) @@ -50,6 +60,7 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { read(accessor: (expr: string) => string, refState: string): string { const result = this.scope.uniqueName("result"); return ` + ${this.builder.getReadContextName()}.reserveGraphMemory(${this.objectGraphBytes()}); ${ this.typeInfo.options!.withConstructor ? ` @@ -165,7 +176,9 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { let writeUserTypeIdStmt = ""; switch (internalTypeId) { case TypeId.EXT: - writeUserTypeIdStmt = this.builder.writer.writeVarUint32Small7(this.typeInfo.userTypeId); + writeUserTypeIdStmt = this.builder.writer.writeVarUint32Small7( + this.typeInfo.userTypeId, + ); break; case TypeId.NAMED_EXT: if (!this.builder.resolver.isCompatible()) { @@ -191,7 +204,10 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { "typeInfoBytes", `new Uint8Array([${TypeMeta.fromTypeInfo(this.typeInfo, this.builder.resolver).toBytes().join(",")}])`, ); - typeMeta = this.builder.typeMetaResolver.writeTypeMeta(this.builder.getTypeInfo(), bytes); + typeMeta = this.builder.typeMetaResolver.writeTypeMeta( + this.builder.getTypeInfo(), + bytes, + ); } break; default: diff --git a/javascript/packages/core/lib/gen/map.ts b/javascript/packages/core/lib/gen/map.ts index 020b3d51c1..ee402c3882 100644 --- a/javascript/packages/core/lib/gen/map.ts +++ b/javascript/packages/core/lib/gen/map.ts @@ -27,6 +27,7 @@ import { AnyHelper } from "./any"; import { ReadContext, WriteContext } from "../context"; const REFERENCE_BYTES = 4; +const MAP_BYTES = 1; const MapFlags = { /** Whether track elements ref. */ @@ -96,7 +97,11 @@ class MapChunkWriter { return flag; } - private writeHead(keyInfo: ElementInfo, valueInfo: ElementInfo, withOutSize = false) { + private writeHead( + keyInfo: ElementInfo, + valueInfo: ElementInfo, + withOutSize = false, + ) { // KV header const header = this.getHead(keyInfo, valueInfo); // chunkSize default 0 | KV header @@ -144,7 +149,10 @@ class MapChunkWriter { endChunk() { if (this.chunkOffset > 0) { - this.writeContext.writer.setUint8Position(this.chunkOffset, this.chunkSize); + this.writeContext.writer.setUint8Position( + this.chunkOffset, + this.chunkSize, + ); this.chunkSize = 0; } } @@ -201,7 +209,11 @@ class MapAnySerializer { : this.writeContext.typeResolver.getSerializerByData(v); const header = mapChunkWriter.next( - new ElementInfo(keySerializer || null, k == null, keySerializer?.needToWriteRef() || false), + new ElementInfo( + keySerializer || null, + k == null, + keySerializer?.needToWriteRef() || false, + ), new ElementInfo( valueSerializer || null, v == null, @@ -211,7 +223,10 @@ class MapAnySerializer { const keyHeader = header & 0b111; const valueHeader = header >> 3; if (mapChunkWriter.isFirst()) { - if (!(keyHeader & MapFlags.HAS_NULL) && !(valueHeader & MapFlags.HAS_NULL)) { + if ( + !(keyHeader & MapFlags.HAS_NULL) && + !(valueHeader & MapFlags.HAS_NULL) + ) { if (!(keyHeader & MapFlags.DECL_ELEMENT_TYPE)) { keySerializer?.writeTypeInfo(null); } @@ -221,7 +236,8 @@ class MapAnySerializer { } } - const includeNone = keyHeader & MapFlags.HAS_NULL || valueHeader & MapFlags.HAS_NULL; + const includeNone = + keyHeader & MapFlags.HAS_NULL || valueHeader & MapFlags.HAS_NULL; if (!this.writeFlag(keyHeader, k)) { if (!includeNone) { keySerializer!.write(k); @@ -253,28 +269,41 @@ class MapAnySerializer { return null; } if (!trackingRef) { - serializer = serializer == null ? AnyHelper.detectSerializer(this.readContext) : serializer; + serializer = + serializer == null + ? AnyHelper.detectSerializer(this.readContext) + : serializer; return this.readSerializerWithDepth(serializer!, false); } const flag = this.readContext.reader.readInt8(); switch (flag) { case RefFlags.RefValueFlag: - serializer = serializer == null ? AnyHelper.detectSerializer(this.readContext) : serializer; + serializer = + serializer == null + ? AnyHelper.detectSerializer(this.readContext) + : serializer; return this.readSerializerWithDepth(serializer!, true); case RefFlags.RefFlag: - return this.readContext.getReadRef(this.readContext.reader.readVarUInt32()); + return this.readContext.getReadRef( + this.readContext.reader.readVarUInt32(), + ); case RefFlags.NullFlag: return null; case RefFlags.NotNullValueFlag: - serializer = serializer == null ? AnyHelper.detectSerializer(this.readContext) : serializer; + serializer = + serializer == null + ? AnyHelper.detectSerializer(this.readContext) + : serializer; return this.readSerializerWithDepth(serializer!, false); } } read(fromRef: boolean): any { let count = this.readContext.reader.readVarUint32Small7(); - this.readContext.reserveContainerMemory(count * 2 * REFERENCE_BYTES); + this.readContext.reserveGraphMemory( + MAP_BYTES + count * 2 * REFERENCE_BYTES, + ); const result = new Map(); if (fromRef) { this.readContext.reference(result); @@ -292,7 +321,10 @@ class MapAnySerializer { let keySerializer = this.keySerializer; let valueSerializer = this.valueSerializer; - if (!(keyHeader & MapFlags.HAS_NULL) && !(valueHeader & MapFlags.HAS_NULL)) { + if ( + !(keyHeader & MapFlags.HAS_NULL) && + !(valueHeader & MapFlags.HAS_NULL) + ) { if (!(keyHeader & MapFlags.DECL_ELEMENT_TYPE)) { keySerializer = AnyHelper.detectSerializer(this.readContext); } @@ -347,9 +379,13 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { private writeSpecificType(accessor: string) { const k = this.scope.uniqueName("k"); const v = this.scope.uniqueName("v"); - let keyHeader = this.keyGenerator.needToWriteRef() ? MapFlags.TRACKING_REF : 0; + let keyHeader = this.keyGenerator.needToWriteRef() + ? MapFlags.TRACKING_REF + : 0; keyHeader |= MapFlags.DECL_ELEMENT_TYPE; - let valueHeader = this.valueGenerator.needToWriteRef() ? MapFlags.TRACKING_REF : 0; + let valueHeader = this.valueGenerator.needToWriteRef() + ? MapFlags.TRACKING_REF + : 0; valueHeader |= MapFlags.DECL_ELEMENT_TYPE; const lastKeyIsNull = this.scope.uniqueName("lastKeyIsNull"); const lastValueIsNull = this.scope.uniqueName("lastValueIsNull"); @@ -460,7 +496,10 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { }).write(${accessor})`; } - private readSpecificType(accessor: (expr: string) => string, refState: string) { + private readSpecificType( + accessor: (expr: string) => string, + refState: string, + ) { const count = this.scope.uniqueName("count"); const result = this.scope.uniqueName("result"); // Skip depth tracking for leaf key/value types. @@ -494,7 +533,7 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { return ` let ${count} = ${this.builder.reader.readVarUint32Small7()}; - ${readContextName}.reserveContainerMemory(${count} * 2 * ${REFERENCE_BYTES}); + ${readContextName}.reserveGraphMemory(${MAP_BYTES} + ${count} * 2 * ${REFERENCE_BYTES}); const ${result} = new Map(); if (${refState}) { ${this.builder.referenceResolver.reference(result)} diff --git a/javascript/packages/core/lib/gen/struct.ts b/javascript/packages/core/lib/gen/struct.ts index 69ecb5dc1b..bec05a6446 100644 --- a/javascript/packages/core/lib/gen/struct.ts +++ b/javascript/packages/core/lib/gen/struct.ts @@ -25,9 +25,15 @@ import { CodegenRegistry } from "./router"; import { BaseSerializerGenerator, SerializerGenerator } from "./serializer"; import { TypeMeta } from "../meta/TypeMeta"; import { getCompatibleCollectionArrayReadAction } from "./collection"; -import { CompatibleScalarConverter, getCompatibleScalarReadAction } from "../compatible/scalar"; +import { + CompatibleScalarConverter, + getCompatibleScalarReadAction, +} from "../compatible/scalar"; import { shouldSkipCompatibleRead } from "../compatible/field"; +const OBJECT_BYTES = 1; +const REFERENCE_BYTES = 4; + /** * Returns true when a field's read cannot recurse and needs no depth tracking. * Covers leaf scalars, typed arrays, and collections/maps whose elements are all leaf types. @@ -44,7 +50,12 @@ function isDepthFreeField(typeInfo: TypeInfo): boolean { if (id === TypeId.MAP) { const key = typeInfo.options?.key; const value = typeInfo.options?.value; - return !!key && !!value && TypeId.isLeafTypeId(key.typeId) && TypeId.isLeafTypeId(value.typeId); + return ( + !!key && + !!value && + TypeId.isLeafTypeId(key.typeId) && + TypeId.isLeafTypeId(value.typeId) + ); } return false; } @@ -61,7 +72,10 @@ function compatibleReadTargetExpr(typeInfo: TypeInfo, expr: string): string { } } -const sortProps = (typeInfo: TypeInfo, typeResolver: CodecBuilder["resolver"]) => { +const sortProps = ( + typeInfo: TypeInfo, + typeResolver: CodecBuilder["resolver"], +) => { const props = typeInfo.options!.props; if (typeInfo.options!.preserveFieldOrder) { return ( @@ -102,7 +116,10 @@ function toRefMode(trackingRef?: boolean, nullable?: boolean) { } } -function isDirectVarInt32Field(typeInfo: TypeInfo, typeResolver: CodecBuilder["resolver"]) { +function isDirectVarInt32Field( + typeInfo: TypeInfo, + typeResolver: CodecBuilder["resolver"], +) { return varInt32ObjectReadKind(typeInfo, typeResolver) === "number"; } @@ -129,7 +146,10 @@ function varInt32ObjectReadKind( return typeInfo.typeId === TypeId.VARINT32 ? "number" : null; } -function directNumericFieldReadExpr(typeInfo: TypeInfo, builder: CodecBuilder): string | null { +function directNumericFieldReadExpr( + typeInfo: TypeInfo, + builder: CodecBuilder, +): string | null { if ( toRefMode(typeInfo.trackingRef, typeInfo.nullable) !== RefMode.NONE || !builder.resolver.isMonomorphic(typeInfo, typeInfo.dynamic) || @@ -185,7 +205,11 @@ function compatibleScalarFieldReadExpr( builder: CodecBuilder, ): string | null { const converter = builder.getExternal(CompatibleScalarConverter.name); - const remoteRead = compatibleScalarRemoteReadExpr(remoteTypeId, builder, converter); + const remoteRead = compatibleScalarRemoteReadExpr( + remoteTypeId, + builder, + converter, + ); if (remoteRead === null) { return null; } @@ -209,12 +233,22 @@ function compatibleScalarFieldReadExpr( case TypeId.UINT16: case TypeId.UINT32: case TypeId.UINT64: - return scalarToIntegerExpr(remoteCanonical, localCanonical, remoteRead, converter); + return scalarToIntegerExpr( + remoteCanonical, + localCanonical, + remoteRead, + converter, + ); case TypeId.FLOAT16: case TypeId.BFLOAT16: case TypeId.FLOAT32: case TypeId.FLOAT64: - return scalarToFloatExpr(remoteCanonical, localCanonical, remoteRead, converter); + return scalarToFloatExpr( + remoteCanonical, + localCanonical, + remoteRead, + converter, + ); default: return null; } @@ -291,7 +325,11 @@ function compatibleScalarRemoteReadExpr( } } -function scalarToBoolExpr(remoteTypeId: number, value: string, converter: string): string | null { +function scalarToBoolExpr( + remoteTypeId: number, + value: string, + converter: string, +): string | null { switch (remoteTypeId) { case TypeId.BOOL: return value; @@ -309,7 +347,11 @@ function scalarToBoolExpr(remoteTypeId: number, value: string, converter: string } } -function scalarToStringExpr(remoteTypeId: number, value: string, converter: string): string | null { +function scalarToStringExpr( + remoteTypeId: number, + value: string, + converter: string, +): string | null { switch (remoteTypeId) { case TypeId.BOOL: return `(${value} ? "true" : "false")`; @@ -503,7 +545,10 @@ function floatMethod(prefix: string, localTypeId: number): string | null { } } -function integerRangeFitsFloat(remoteTypeId: number, localTypeId: number): boolean { +function integerRangeFitsFloat( + remoteTypeId: number, + localTypeId: number, +): boolean { switch (localTypeId) { case TypeId.FLOAT16: case TypeId.BFLOAT16: @@ -561,6 +606,10 @@ class StructSerializerGenerator extends BaseSerializerGenerator { ); } + private objectGraphBytes(): number { + return OBJECT_BYTES + this.sortedProps.length * REFERENCE_BYTES; + } + readField( fieldName: string, fieldTypeInfo: TypeInfo, @@ -704,24 +753,32 @@ class StructSerializerGenerator extends BaseSerializerGenerator { } write(accessor: string): string { - if (!this.typeInfo.options?.props || Object.keys(this.typeInfo.options.props).length === 0) { + if ( + !this.typeInfo.options?.props || + Object.keys(this.typeInfo.options.props).length === 0 + ) { const hash = this.typeMeta.computeStructHash(); return `${!this.builder.resolver.isCompatible() ? this.builder.writer.writeInt32(hash) : ""}`; } const hash = this.typeMeta.computeStructHash(); const fieldWrites: string[] = []; - for (let i = 0; i < this.sortedProps.length;) { + for (let i = 0; i < this.sortedProps.length; ) { const current = this.sortedProps[i]; if (isDirectVarInt32Field(current.typeInfo, this.builder.resolver)) { let end = i + 1; while ( end < this.sortedProps.length && - isDirectVarInt32Field(this.sortedProps[end].typeInfo, this.builder.resolver) + isDirectVarInt32Field( + this.sortedProps[end].typeInfo, + this.builder.resolver, + ) ) { end++; } if (end - i > 1) { - fieldWrites.push(this.writeVarInt32Run(accessor, this.sortedProps.slice(i, end))); + fieldWrites.push( + this.writeVarInt32Run(accessor, this.sortedProps.slice(i, end)), + ); i = end; continue; } @@ -730,10 +787,19 @@ class StructSerializerGenerator extends BaseSerializerGenerator { if (!InnerGeneratorClass) { throw new Error(`${current.typeInfo.typeId} generator not exists`); } - const innerGenerator = new InnerGeneratorClass(current.typeInfo, this.builder, this.scope); + const innerGenerator = new InnerGeneratorClass( + current.typeInfo, + this.builder, + this.scope, + ); const fieldAccessor = `${accessor}${CodecBuilder.safePropAccessor(current.key)}`; fieldWrites.push( - this.writeField(current.key, current.typeInfo, fieldAccessor, innerGenerator.writeEmbed()), + this.writeField( + current.key, + current.typeInfo, + fieldAccessor, + innerGenerator.writeEmbed(), + ), ); i++; } @@ -743,7 +809,10 @@ class StructSerializerGenerator extends BaseSerializerGenerator { `; } - private writeVarInt32Run(accessor: string, fields: { key: string; typeInfo: TypeInfo }[]) { + private writeVarInt32Run( + accessor: string, + fields: { key: string; typeInfo: TypeInfo }[], + ) { const cursor = this.scope.uniqueName("cursor"); const buffer = this.scope.uniqueName("buffer"); const dataView = this.scope.uniqueName("dataView"); @@ -809,14 +878,35 @@ class StructSerializerGenerator extends BaseSerializerGenerator { read(accessor: (expr: string) => string, refState: string): string { const result = this.scope.uniqueName("result"); - if (!this.typeInfo.options?.props || Object.keys(this.typeInfo.options.props).length === 0) { + const hash = this.typeMeta.computeStructHash(); + if ( + !this.typeInfo.options?.props || + Object.keys(this.typeInfo.options.props).length === 0 + ) { return ` - let ${result} = ${this.serializerExpr}.read(${refState}); + ${ + !this.builder.resolver.isCompatible() + ? ` + if(${this.builder.reader.readInt32()} !== ${hash}) { + throw new Error("Read class version is not consistent with ${hash} ") + } + ` + : "" + } + ${this.builder.getReadContextName()}.reserveGraphMemory(${OBJECT_BYTES}); + ${ + this.typeInfo.options?.withConstructor + ? `const ${result} = new ${this.builder.getOptions("creator")}();` + : `const ${result} = {};` + } + ${this.maybeReference(result, refState)} ${accessor(result)}; `; } - const hash = this.typeMeta.computeStructHash(); - const directNumericObjectRead = this.readDirectNumericObject(accessor, refState); + const directNumericObjectRead = this.readDirectNumericObject( + accessor, + refState, + ); if (directNumericObjectRead !== null) { return ` ${ @@ -841,6 +931,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { ` : "" } + ${this.builder.getReadContextName()}.reserveGraphMemory(${this.objectGraphBytes()}); ${ this.typeInfo.options!.withConstructor ? ` @@ -850,7 +941,9 @@ class StructSerializerGenerator extends BaseSerializerGenerator { const ${result} = { ${this.sortedProps .map(({ key }) => { - if (shouldSkipCompatibleRead(this.typeInfo.options!.props![key])) { + if ( + shouldSkipCompatibleRead(this.typeInfo.options!.props![key]) + ) { return ""; } return `${CodecBuilder.safePropName(key)}: null`; @@ -867,7 +960,11 @@ class StructSerializerGenerator extends BaseSerializerGenerator { if (!InnerGeneratorClass) { throw new Error(`${typeInfo.typeId} generator not exists`); } - const innerGenerator = new InnerGeneratorClass(typeInfo, this.builder, this.scope); + const innerGenerator = new InnerGeneratorClass( + typeInfo, + this.builder, + this.scope, + ); return ` ${this.readField(key, typeInfo, (expr) => `${result}${CodecBuilder.safePropAccessor(key)} = ${expr}`, innerGenerator.readEmbed())} `; @@ -881,11 +978,17 @@ class StructSerializerGenerator extends BaseSerializerGenerator { accessor: (expr: string) => string, refState: string, ): string | null { - const varInt32ObjectRead = this.readDirectVarInt32Object(accessor, refState); + const varInt32ObjectRead = this.readDirectVarInt32Object( + accessor, + refState, + ); if (varInt32ObjectRead !== null) { return varInt32ObjectRead; } - if (this.typeInfo.options!.withConstructor || this.sortedProps.length === 0) { + if ( + this.typeInfo.options!.withConstructor || + this.sortedProps.length === 0 + ) { return null; } const fields: Array<{ key: string; expr: string }> = []; @@ -911,6 +1014,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { } const result = this.scope.uniqueName("result"); return ` + ${this.builder.getReadContextName()}.reserveGraphMemory(${this.objectGraphBytes()}); const ${result} = { ${fields.map(({ key, expr }) => `${CodecBuilder.safePropName(key)}: ${expr}`).join(",\n")} }; @@ -923,7 +1027,10 @@ class StructSerializerGenerator extends BaseSerializerGenerator { accessor: (expr: string) => string, refState: string, ): string | null { - if (this.typeInfo.options!.withConstructor || this.sortedProps.length === 0) { + if ( + this.typeInfo.options!.withConstructor || + this.sortedProps.length === 0 + ) { return null; } const fields = []; @@ -1003,6 +1110,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { let ${value}; ${reads} ${this.builder.reader.readSetCursor(cursor)} + ${this.builder.getReadContextName()}.reserveGraphMemory(${this.objectGraphBytes()}); const ${result} = { ${fields.map(({ key, local }) => `${CodecBuilder.safePropName(key)}: ${local}`).join(",\n")} }; @@ -1012,7 +1120,10 @@ class StructSerializerGenerator extends BaseSerializerGenerator { } readWithDepth(assignStmt: (v: string) => string, refState: string): string { - if (!this.typeInfo.options?.props || Object.keys(this.typeInfo.options.props).length === 0) { + if ( + !this.typeInfo.options?.props || + Object.keys(this.typeInfo.options.props).length === 0 + ) { const result = this.scope.uniqueName("result"); return ` ${this.builder.getReadContextName()}.incReadDepth(); @@ -1026,7 +1137,10 @@ class StructSerializerGenerator extends BaseSerializerGenerator { readNoRef(assignStmt: (v: string) => string, refState: string): string { const result = this.scope.uniqueName("result"); - if (!this.typeInfo.options?.props || Object.keys(this.typeInfo.options.props).length === 0) { + if ( + !this.typeInfo.options?.props || + Object.keys(this.typeInfo.options.props).length === 0 + ) { return this.readTypeInfoThen( (changedSerializer) => ` ${assignStmt(`${changedSerializer}.read(${refState})`)}; @@ -1143,8 +1257,10 @@ class StructSerializerGenerator extends BaseSerializerGenerator { const canInlineCompatibleTypeInfo = internalTypeId === TypeId.COMPATIBLE_STRUCT || internalTypeId === TypeId.NAMED_COMPATIBLE_STRUCT || - (internalTypeId === TypeId.NAMED_STRUCT && builder.resolver.isCompatible()); - const canUseHeaderCacheFastPath = canInlineCompatibleTypeInfo && serializer?._initialized; + (internalTypeId === TypeId.NAMED_STRUCT && + builder.resolver.isCompatible()); + const canUseHeaderCacheFastPath = + canInlineCompatibleTypeInfo && serializer?._initialized; const inlineCompatibleTypeInfo = ( onMetaChanged: (changedSerializer: string) => string, onMetaUnchanged: () => string, @@ -1181,7 +1297,8 @@ class StructSerializerGenerator extends BaseSerializerGenerator { const result = scope.uniqueName("result"); return ` ${inlineCompatibleTypeInfo( - (changedSerializer) => `${accessor(`${changedSerializer}.read(${refState})`)};`, + (changedSerializer) => + `${accessor(`${changedSerializer}.read(${refState})`)};`, () => ` ${builder.getReadContextName()}.incReadDepth(); let ${result} = ${hoisted}.read(${refState}); @@ -1283,13 +1400,18 @@ class StructSerializerGenerator extends BaseSerializerGenerator { let writeUserTypeIdStmt = ""; switch (internalTypeId) { case TypeId.STRUCT: - writeUserTypeIdStmt = this.builder.writer.writeVarUint32Small7(this.typeInfo.userTypeId); + writeUserTypeIdStmt = this.builder.writer.writeVarUint32Small7( + this.typeInfo.userTypeId, + ); break; case TypeId.NAMED_COMPATIBLE_STRUCT: case TypeId.COMPATIBLE_STRUCT: { const bytes = this.typeMetaBytesExpr(); - typeMeta = this.builder.typeMetaResolver.writeTypeMeta(this.builder.getTypeInfo(), bytes); + typeMeta = this.builder.typeMetaResolver.writeTypeMeta( + this.builder.getTypeInfo(), + bytes, + ); } break; case TypeId.NAMED_STRUCT: @@ -1337,17 +1459,25 @@ class StructSerializerGenerator extends BaseSerializerGenerator { let fixedSize = 8; if (options!.props) { Object.values(options!.props).forEach((x) => { - const propGenerator = new (CodegenRegistry.get(x.typeId)!)(x, this.builder, this.scope); + const propGenerator = new (CodegenRegistry.get(x.typeId)!)( + x, + this.builder, + this.scope, + ); fixedSize += propGenerator.getFixedSize(); }); } else { - fixedSize += this.builder.resolver.getSerializerByName(typeInfo.named!)!.fixedSize; + fixedSize += this.builder.resolver.getSerializerByName( + typeInfo.named!, + )!.fixedSize; } return fixedSize; } getHash(): string { - return TypeMeta.fromTypeInfo(this.typeInfo, this.builder.resolver).getHash().toString(); + return TypeMeta.fromTypeInfo(this.typeInfo, this.builder.resolver) + .getHash() + .toString(); } getTypeMetaBytes(): string { @@ -1365,4 +1495,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { CodegenRegistry.register(TypeId.STRUCT, StructSerializerGenerator); CodegenRegistry.register(TypeId.NAMED_STRUCT, StructSerializerGenerator); CodegenRegistry.register(TypeId.COMPATIBLE_STRUCT, StructSerializerGenerator); -CodegenRegistry.register(TypeId.NAMED_COMPATIBLE_STRUCT, StructSerializerGenerator); +CodegenRegistry.register( + TypeId.NAMED_COMPATIBLE_STRUCT, + StructSerializerGenerator, +); diff --git a/javascript/packages/core/lib/type.ts b/javascript/packages/core/lib/type.ts index 6acf24bc96..eb65815987 100644 --- a/javascript/packages/core/lib/type.ts +++ b/javascript/packages/core/lib/type.ts @@ -291,7 +291,7 @@ export interface Config { ref: boolean; useSliceString: boolean; maxDepth?: number; - maxContainerMemoryBytes: number; + maxGraphMemoryBytes: number; maxTypeFields: number; maxTypeMetaBytes: number; maxSchemaVersionsPerType: number; diff --git a/javascript/test/containerMemoryBudget.test.ts b/javascript/test/containerMemoryBudget.test.ts deleted file mode 100644 index 499d4750b6..0000000000 --- a/javascript/test/containerMemoryBudget.test.ts +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -import Fory, { Type } from '../packages/core/index'; -import { describe, expect, test } from '@jest/globals'; - -const KNOWN_SLACK_BYTES = 64 * 1024; - -function serializeAny(value: unknown) { - return new Fory({ compatible: false, ref: true }).serialize(value); -} - -function deserializeAny(bytes: Uint8Array, maxContainerMemoryBytes: number) { - return new Fory({ - compatible: false, - ref: true, - maxContainerMemoryBytes, - }).deserialize(bytes); -} - -describe('container memory budget', () => { - test('uses known length auto budget', () => { - const inputBytes = 17; - const fory = new Fory({ compatible: false }); - const budget = inputBytes * 8 + KNOWN_SLACK_BYTES; - - fory.readContext.reset(new Uint8Array(inputBytes)); - expect(() => fory.readContext.reserveContainerMemory(budget)).not.toThrow(); - expect(() => fory.readContext.reserveContainerMemory(1)).toThrow( - /maxContainerMemoryBytes/, - ); - }); - - test('validates explicit config', () => { - expect(() => new Fory({ maxContainerMemoryBytes: 0 })).toThrow( - /maxContainerMemoryBytes/, - ); - expect(() => new Fory({ maxContainerMemoryBytes: -2 })).toThrow( - /maxContainerMemoryBytes/, - ); - - const fory = new Fory({ maxContainerMemoryBytes: 24 }); - fory.readContext.reset(new Uint8Array(1)); - expect(() => fory.readContext.reserveContainerMemory(0)).not.toThrow(); - expect(() => fory.readContext.reserveContainerMemory(24)).not.toThrow(); - expect(() => fory.readContext.reserveContainerMemory(1)).toThrow( - /maxContainerMemoryBytes/, - ); - }); - - test('uses parent storage for nested empty containers', () => { - const typeInfo = Type.struct('budget.nested.empty', { - values: Type.list(Type.list(Type.int32({ encoding: 'fixed' }))).setId(1), - }); - const writer = new Fory({ compatible: false, ref: true }); - const bytes = writer.register(typeInfo).serialize({ values: [[]] }); - const passingReader = new Fory({ - compatible: false, - ref: true, - maxContainerMemoryBytes: 4, - }).register(typeInfo); - const failingReader = new Fory({ - compatible: false, - ref: true, - maxContainerMemoryBytes: 3, - }).register(typeInfo); - - expect(() => failingReader.deserialize(bytes)).toThrow( - /maxContainerMemoryBytes/, - ); - expect(passingReader.deserialize(bytes)).toEqual({ values: [[]] }); - }); - - test('reserves sibling containers cumulatively', () => { - const typeInfo = Type.struct('budget.sibling.empty', { - values: Type.list(Type.list(Type.int32({ encoding: 'fixed' }))).setId(1), - }); - const writer = new Fory({ compatible: false, ref: true }); - const bytes = writer.register(typeInfo).serialize({ - values: [[], [], []], - }); - const passingReader = new Fory({ - compatible: false, - ref: true, - maxContainerMemoryBytes: 12, - }).register(typeInfo); - const failingReader = new Fory({ - compatible: false, - ref: true, - maxContainerMemoryBytes: 11, - }).register(typeInfo); - - expect(() => failingReader.deserialize(bytes)).toThrow( - /maxContainerMemoryBytes/, - ); - expect(passingReader.deserialize(bytes)).toEqual({ - values: [[], [], []], - }); - }); - - test('reserves map entries', () => { - const bytes = serializeAny(new Map([[1, 2]])); - - expect(() => deserializeAny(bytes, 7)).toThrow(/maxContainerMemoryBytes/); - expect(deserializeAny(bytes, 8)).toEqual(new Map([[1, 2]])); - }); - - test('reserves generated containers', () => { - const typeInfo = Type.struct('budget.generated', { - list: Type.list(Type.int32({ encoding: 'fixed' })).setId(1), - set: Type.set(Type.string()).setId(2), - map: Type.map(Type.string(), Type.int32({ encoding: 'fixed' })).setId(3), - }); - const writer = new Fory({ compatible: false, ref: true }); - const bytes = writer.register(typeInfo).serialize({ - list: [1], - set: new Set(['a']), - map: new Map([['k', 1]]), - }); - const passingReader = new Fory({ - compatible: false, - ref: true, - maxContainerMemoryBytes: 16, - }).register(typeInfo); - const failingReader = new Fory({ - compatible: false, - ref: true, - maxContainerMemoryBytes: 15, - }).register(typeInfo); - - expect(() => failingReader.deserialize(bytes)).toThrow( - /maxContainerMemoryBytes/, - ); - expect(passingReader.deserialize(bytes)).toEqual({ - list: [1], - set: new Set(['a']), - map: new Map([['k', 1]]), - }); - }); - - test('reserves compatible typed arrays', () => { - const writerType = Type.struct(9010, { - values: Type.list(Type.int32({ encoding: 'fixed' })).setId(1), - }); - const readerType = Type.struct(9010, { - values: Type.int32Array().setId(1), - }); - const writer = new Fory({ compatible: true }); - const bytes = writer.register(writerType).serialize({ values: [1, 2, 3] }); - const passingReader = new Fory({ - compatible: true, - maxContainerMemoryBytes: 12, - }).register(readerType); - const failingReader = new Fory({ - compatible: true, - maxContainerMemoryBytes: 11, - }).register(readerType); - - expect(() => failingReader.deserialize(bytes)).toThrow( - /maxContainerMemoryBytes/, - ); - expect(Array.from(passingReader.deserialize(bytes).values)).toEqual([ - 1, - 2, - 3, - ]); - }); - - test('skips scalar dense owners', () => { - const typeInfo = Type.struct('budget.skipped', { - text: Type.string().setId(1), - binary: Type.binary().setId(2), - values: Type.int32Array().setId(3), - }); - const writer = new Fory({ compatible: false, ref: true }); - const bytes = writer.register(typeInfo).serialize({ - text: 'hello', - binary: new Uint8Array([1, 2, 3]), - values: new Int32Array([1, 2, 3]), - }); - const reader = new Fory({ - compatible: false, - ref: true, - maxContainerMemoryBytes: 1, - }).register(typeInfo); - - expect(reader.deserialize(bytes)).toEqual({ - text: 'hello', - binary: new Uint8Array([1, 2, 3]), - values: new Int32Array([1, 2, 3]), - }); - }); - - test('keeps byte checks', () => { - const typeInfo = Type.struct('budget.bytecheck', { - values: Type.int32Array().setId(1), - }); - const writer = new Fory({ compatible: false, ref: true }); - const bytes = writer.register(typeInfo).serialize({ - values: new Int32Array([1, 2, 3]), - }); - const reader = new Fory({ - compatible: false, - ref: true, - maxContainerMemoryBytes: 1024 * 1024, - }).register(typeInfo); - - expect(() => reader.deserialize(bytes.slice(0, bytes.length - 1))).toThrow(); - }); -}); diff --git a/javascript/test/graphMemoryBudget.test.ts b/javascript/test/graphMemoryBudget.test.ts new file mode 100644 index 0000000000..d3bb3ce912 --- /dev/null +++ b/javascript/test/graphMemoryBudget.test.ts @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import Fory, { Type } from "../packages/core/index"; +import { describe, expect, test } from "@jest/globals"; + +const KNOWN_SLACK_BYTES = 64 * 1024; +const OBJECT_BYTES = 1; +const REFERENCE_BYTES = 4; + +const objectBytes = (fields: number) => OBJECT_BYTES + fields * REFERENCE_BYTES; +const listBytes = (count: number) => OBJECT_BYTES + count * REFERENCE_BYTES; +const mapBytes = (count: number) => OBJECT_BYTES + count * 2 * REFERENCE_BYTES; + +function serializeAny(value: unknown) { + return new Fory({ compatible: false, ref: true }).serialize(value); +} + +function deserializeAny(bytes: Uint8Array, maxGraphMemoryBytes: number) { + return new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes, + }).deserialize(bytes); +} + +describe("graph memory budget", () => { + test("uses known length auto budget", () => { + const inputBytes = 17; + const fory = new Fory({ compatible: false }); + const budget = inputBytes * 8 + KNOWN_SLACK_BYTES; + + fory.readContext.reset(new Uint8Array(inputBytes)); + expect(() => fory.readContext.reserveGraphMemory(budget)).not.toThrow(); + expect(() => fory.readContext.reserveGraphMemory(1)).toThrow( + /maxGraphMemoryBytes/, + ); + }); + + test("validates explicit config", () => { + expect(() => new Fory({ maxGraphMemoryBytes: 0 })).toThrow( + /maxGraphMemoryBytes/, + ); + expect(() => new Fory({ maxGraphMemoryBytes: -2 })).toThrow( + /maxGraphMemoryBytes/, + ); + + const fory = new Fory({ maxGraphMemoryBytes: 24 }); + fory.readContext.reset(new Uint8Array(1)); + expect(() => fory.readContext.reserveGraphMemory(0)).not.toThrow(); + expect(() => fory.readContext.reserveGraphMemory(24)).not.toThrow(); + expect(() => fory.readContext.reserveGraphMemory(1)).toThrow( + /maxGraphMemoryBytes/, + ); + }); + + test("uses parent storage for nested empty containers", () => { + const typeInfo = Type.struct("budget.nested.empty", { + values: Type.list(Type.list(Type.int32({ encoding: "fixed" }))).setId(1), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ values: [[]] }); + const passingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: objectBytes(1) + listBytes(1) + listBytes(0), + }).register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: objectBytes(1) + listBytes(1) + listBytes(0) - 1, + }).register(typeInfo); + + expect(() => failingReader.deserialize(bytes)).toThrow( + /maxGraphMemoryBytes/, + ); + expect(passingReader.deserialize(bytes)).toEqual({ values: [[]] }); + }); + + test("reserves sibling containers cumulatively", () => { + const typeInfo = Type.struct("budget.sibling.empty", { + values: Type.list(Type.list(Type.int32({ encoding: "fixed" }))).setId(1), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + values: [[], [], []], + }); + const passingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: objectBytes(1) + listBytes(3) + 3 * listBytes(0), + }).register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: objectBytes(1) + listBytes(3) + 3 * listBytes(0) - 1, + }).register(typeInfo); + + expect(() => failingReader.deserialize(bytes)).toThrow( + /maxGraphMemoryBytes/, + ); + expect(passingReader.deserialize(bytes)).toEqual({ + values: [[], [], []], + }); + }); + + test("reserves empty generated object owner", () => { + const childType = Type.struct("budget.empty.object", {}); + const typeInfo = Type.struct("budget.empty.parent", { + first: Type.struct("budget.empty.object").setId(1), + second: Type.struct("budget.empty.object").setId(2), + }); + const writer = new Fory({ compatible: false, ref: true }); + writer.register(childType); + const bytes = writer.register(typeInfo).serialize({ + first: {}, + second: {}, + }); + const required = objectBytes(2) + 2 * OBJECT_BYTES; + const passingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: required, + }); + passingReader.register(childType); + passingReader.register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: required - 1, + }); + failingReader.register(childType); + failingReader.register(typeInfo); + + expect(() => failingReader.deserialize(bytes)).toThrow( + /maxGraphMemoryBytes/, + ); + expect(passingReader.deserialize(bytes)).toEqual({ + first: {}, + second: {}, + }); + }); + + test("preserves constructor for empty generated object owner", () => { + class EmptyChild {} + Type.struct("budget.empty.ctor.child", {})(EmptyChild); + class EmptyParent { + child = new EmptyChild(); + } + Type.struct("budget.empty.ctor.parent", { + child: Type.struct("budget.empty.ctor.child").setId(1), + })(EmptyParent); + + const writer = new Fory({ compatible: false, ref: true }); + writer.register(EmptyChild); + const bytes = writer.register(EmptyParent).serialize(new EmptyParent()); + const required = objectBytes(1) + OBJECT_BYTES; + const passingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: required, + }); + passingReader.register(EmptyChild); + passingReader.register(EmptyParent); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: required - 1, + }); + failingReader.register(EmptyChild); + failingReader.register(EmptyParent); + + expect(() => failingReader.deserialize(bytes)).toThrow( + /maxGraphMemoryBytes/, + ); + const decoded = passingReader.deserialize(bytes); + expect(decoded).toBeInstanceOf(EmptyParent); + expect(decoded.child).toBeInstanceOf(EmptyChild); + }); + + test("reserves map entries", () => { + const bytes = serializeAny(new Map([[1, 2]])); + + expect(() => deserializeAny(bytes, mapBytes(1) - 1)).toThrow( + /maxGraphMemoryBytes/, + ); + expect(deserializeAny(bytes, mapBytes(1))).toEqual(new Map([[1, 2]])); + }); + + test("reserves generated containers", () => { + const typeInfo = Type.struct("budget.generated", { + list: Type.list(Type.int32({ encoding: "fixed" })).setId(1), + set: Type.set(Type.string()).setId(2), + map: Type.map(Type.string(), Type.int32({ encoding: "fixed" })).setId(3), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + list: [1], + set: new Set(["a"]), + map: new Map([["k", 1]]), + }); + const passingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: + objectBytes(3) + listBytes(1) + listBytes(1) + mapBytes(1), + }).register(typeInfo); + const failingReader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: + objectBytes(3) + listBytes(1) + listBytes(1) + mapBytes(1) - 1, + }).register(typeInfo); + + expect(() => failingReader.deserialize(bytes)).toThrow( + /maxGraphMemoryBytes/, + ); + expect(passingReader.deserialize(bytes)).toEqual({ + list: [1], + set: new Set(["a"]), + map: new Map([["k", 1]]), + }); + }); + + test("reserves compatible typed arrays", () => { + const writerType = Type.struct(9010, { + values: Type.list(Type.int32({ encoding: "fixed" })).setId(1), + }); + const readerType = Type.struct(9010, { + values: Type.int32Array().setId(1), + }); + const writer = new Fory({ compatible: true }); + const bytes = writer.register(writerType).serialize({ values: [1, 2, 3] }); + const passingReader = new Fory({ + compatible: true, + maxGraphMemoryBytes: objectBytes(1) + OBJECT_BYTES + 12, + }).register(readerType); + const failingReader = new Fory({ + compatible: true, + maxGraphMemoryBytes: objectBytes(1) + OBJECT_BYTES + 12 - 1, + }).register(readerType); + + expect(() => failingReader.deserialize(bytes)).toThrow( + /maxGraphMemoryBytes/, + ); + expect(Array.from(passingReader.deserialize(bytes).values)).toEqual([ + 1, 2, 3, + ]); + }); + + test("skips scalar dense owners", () => { + const typeInfo = Type.struct("budget.skipped", { + text: Type.string().setId(1), + binary: Type.binary().setId(2), + values: Type.int32Array().setId(3), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + text: "hello", + binary: new Uint8Array([1, 2, 3]), + values: new Int32Array([1, 2, 3]), + }); + const reader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: objectBytes(3), + }).register(typeInfo); + + expect(reader.deserialize(bytes)).toEqual({ + text: "hello", + binary: new Uint8Array([1, 2, 3]), + values: new Int32Array([1, 2, 3]), + }); + }); + + test("keeps byte checks", () => { + const typeInfo = Type.struct("budget.bytecheck", { + values: Type.int32Array().setId(1), + }); + const writer = new Fory({ compatible: false, ref: true }); + const bytes = writer.register(typeInfo).serialize({ + values: new Int32Array([1, 2, 3]), + }); + const reader = new Fory({ + compatible: false, + ref: true, + maxGraphMemoryBytes: 1024 * 1024, + }).register(typeInfo); + + expect(() => + reader.deserialize(bytes.slice(0, bytes.length - 1)), + ).toThrow(); + }); +}); diff --git a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt index 7517ca54ac..f93c17b860 100644 --- a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt +++ b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt @@ -391,6 +391,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru .append(" private fun readSchemaConstructor(readContext: ReadContext): ") .append(struct.typeName) .append(" {\n") + builder.append(" reserveObjectGraphMemory(readContext)\n") builder.append(" val fieldValues = arrayOfNulls(DESCRIPTORS.size)\n") builder.append(" val bufferedFields = newFieldBits(DESCRIPTORS.size)\n") builder.append(" beginConstructorRef(readContext)\n") @@ -654,6 +655,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru } private fun writeMutableReadBody() { + builder.append(" reserveObjectGraphMemory(readContext)\n") builder.append(" val value = ").append(struct.typeName).append("()\n") builder.append(" if (readContext.hasPreservedRefId()) {\n") builder.append(" readContext.reference(value)\n") @@ -700,6 +702,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru builder.append(" }\n\n") return } + builder.append(" reserveObjectGraphMemory(readContext)\n") writeCompatibleValueReadBody(" ", constructorRefs = false) builder.append(" }\n\n") } @@ -709,6 +712,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru .append(" private fun readCompatibleConstructor(readContext: ReadContext): ") .append(struct.typeName) .append(" {\n") + builder.append(" reserveObjectGraphMemory(readContext)\n") builder.append(" beginConstructorRef(readContext)\n") builder.append(" try {\n") writeCompatibleValueReadBody(" ", constructorRefs = true) @@ -829,6 +833,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru private fun writeMutableCompatibleReadBody() { writePresenceVars() + builder.append(" reserveObjectGraphMemory(readContext)\n") builder.append(" val value = ").append(struct.typeName).append("()\n") builder.append(" if (readContext.hasPreservedRefId()) {\n") builder.append(" readContext.reference(value)\n") diff --git a/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt b/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt index be2980313f..9b6498e21a 100644 --- a/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt +++ b/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt @@ -56,7 +56,7 @@ public class KotlinArrayDequeSerializer( } override fun newCollection(readContext: ReadContext): Collection { - val numElements = readCollectionSize(readContext) + val numElements = readCollectionSize(readContext, readContext.buffer) setNumElements(numElements) return ArrayDequeBuilder(ArrayDeque(numElements)) } diff --git a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt index 92749a2506..c66dc6582a 100644 --- a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt +++ b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializerTest.kt @@ -36,18 +36,18 @@ class CollectionSerializerTest { } @Test - fun testArrayDequeContainerMemoryBudget() { + fun testArrayDequeGraphMemoryBudget() { val writer: Fory = ForyKotlin.builder().withXlang(false).requireClassRegistration(true).build() val reader: Fory = ForyKotlin.builder() .withXlang(false) .requireClassRegistration(true) - .withMaxContainerMemoryBytes(23) + .withMaxGraphMemoryBytes(23) .build() try { reader.deserialize(writer.serialize(ArrayDeque(listOf(1, 2, 3, 4, 5, 6)))) - fail("Expected container memory budget failure") + fail("Expected graph memory budget failure") } catch (ignored: InsecureException) {} } diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index 0d6f4d0591..9a3e65f6de 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -124,7 +124,7 @@ class Fory: "strict", "buffer", "max_depth", - "max_container_memory_bytes", + "max_graph_memory_bytes", "field_nullable", "policy", ) @@ -140,7 +140,7 @@ def __init__( max_type_meta_bytes: int = 4096, max_schema_versions_per_type: int = 10, max_average_schema_versions_per_type: int = 3, - max_container_memory_bytes: int = -1, + max_graph_memory_bytes: int = -1, policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, @@ -185,7 +185,7 @@ def __init__( max_average_schema_versions_per_type: Average remote metadata versions allowed across accepted remote types. - max_container_memory_bytes: Maximum estimated container-owned memory per root + max_graph_memory_bytes: Maximum estimated graph memory per root deserialization. `-1` means auto; positive values are explicit byte limits. policy: Custom deserialization policy for security checks. When provided, @@ -219,12 +219,12 @@ def __init__( if not isinstance(max_average_schema_versions_per_type, int) or max_average_schema_versions_per_type <= 0: raise ValueError("max_average_schema_versions_per_type must be a positive integer") if ( - not isinstance(max_container_memory_bytes, int) - or (max_container_memory_bytes != -1 and max_container_memory_bytes <= 0) - or max_container_memory_bytes > (1 << 63) - 1 + not isinstance(max_graph_memory_bytes, int) + or (max_graph_memory_bytes != -1 and max_graph_memory_bytes <= 0) + or max_graph_memory_bytes > (1 << 63) - 1 ): - raise ValueError("max_container_memory_bytes must be -1 or a positive 63-bit integer") - self.max_container_memory_bytes = max_container_memory_bytes + raise ValueError("max_graph_memory_bytes must be -1 or a positive 63-bit integer") + self.max_graph_memory_bytes = max_graph_memory_bytes self.config = Config( xlang=xlang, track_ref=ref, @@ -237,7 +237,7 @@ def __init__( max_type_meta_bytes=max_type_meta_bytes, max_schema_versions_per_type=max_schema_versions_per_type, max_average_schema_versions_per_type=max_average_schema_versions_per_type, - max_container_memory_bytes=max_container_memory_bytes, + max_graph_memory_bytes=max_graph_memory_bytes, field_nullable=field_nullable, policy=self.policy, meta_compressor=meta_compressor, diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi index c4dd89a0b4..807805dcbb 100644 --- a/python/pyfory/collection.pxi +++ b/python/pyfory/collection.pxi @@ -42,6 +42,7 @@ cdef int8_t NULL_KEY_VALUE_DECL_TYPE_TRACKING_REF = KEY_HAS_NULL | VALUE_DECL_TY cdef int8_t NULL_VALUE_KEY_DECL_TYPE = VALUE_HAS_NULL | KEY_DECL_TYPE cdef int8_t NULL_VALUE_KEY_DECL_TYPE_TRACKING_REF = VALUE_HAS_NULL | KEY_DECL_TYPE | TRACKING_KEY_REF cdef int64_t _REFERENCE_BYTES = sizeof(PyObject*) +cdef int64_t _OWNER_BYTES = 1 ctypedef PyObject *PyObjectPtr cdef class ListSerializer @@ -467,22 +468,21 @@ cdef class ListSerializer(CollectionSerializer): cdef int8_t head_flag cdef int32_t ref_id cdef int64_t i - cdef int64_t container_bytes - cdef int64_t remaining_container_memory_bytes - if len_ == 0: - list_ = PyList_New(0) - return list_ + cdef int64_t graph_bytes + cdef int64_t remaining_graph_memory_bytes if len_ < 0: raise ValueError("Container element count is negative") + graph_bytes = _OWNER_BYTES + len_ * _REFERENCE_BYTES + remaining_graph_memory_bytes = read_context.remaining_graph_memory_bytes + if graph_bytes > remaining_graph_memory_bytes: + read_context.reserve_graph_memory_fast(graph_bytes) else: - container_bytes = len_ * sizeof(PyObject*) - remaining_container_memory_bytes = read_context.remaining_container_memory_bytes - if container_bytes > remaining_container_memory_bytes: - read_context.reserve_container_memory_fast(container_bytes) - else: - read_context.remaining_container_memory_bytes = ( - remaining_container_memory_bytes - container_bytes - ) + read_context.remaining_graph_memory_bytes = ( + remaining_graph_memory_bytes - graph_bytes + ) + if len_ == 0: + list_ = PyList_New(0) + return list_ read_context.check_readable_bytes(len_) collect_flag = buffer.read_int8() @@ -596,22 +596,21 @@ cdef class TupleSerializer(CollectionSerializer): cdef bint has_null cdef int8_t head_flag cdef int64_t i - cdef int64_t container_bytes - cdef int64_t remaining_container_memory_bytes - if len_ == 0: - tuple_ = PyTuple_New(0) - return tuple_ + cdef int64_t graph_bytes + cdef int64_t remaining_graph_memory_bytes if len_ < 0: raise ValueError("Container element count is negative") + graph_bytes = _OWNER_BYTES + len_ * _REFERENCE_BYTES + remaining_graph_memory_bytes = read_context.remaining_graph_memory_bytes + if graph_bytes > remaining_graph_memory_bytes: + read_context.reserve_graph_memory_fast(graph_bytes) else: - container_bytes = len_ * sizeof(PyObject*) - remaining_container_memory_bytes = read_context.remaining_container_memory_bytes - if container_bytes > remaining_container_memory_bytes: - read_context.reserve_container_memory_fast(container_bytes) - else: - read_context.remaining_container_memory_bytes = ( - remaining_container_memory_bytes - container_bytes - ) + read_context.remaining_graph_memory_bytes = ( + remaining_graph_memory_bytes - graph_bytes + ) + if len_ == 0: + tuple_ = PyTuple_New(0) + return tuple_ read_context.check_readable_bytes(len_) collect_flag = buffer.read_int8() @@ -726,25 +725,24 @@ cdef class SetSerializer(CollectionSerializer): cdef int8_t head_flag cdef int32_t ref_id cdef int64_t i - cdef int64_t container_bytes - cdef int64_t remaining_container_memory_bytes + cdef int64_t graph_bytes + cdef int64_t remaining_graph_memory_bytes len_ = buffer.read_var_uint32() + if len_ < 0: + raise ValueError("Container element count is negative") + graph_bytes = _OWNER_BYTES + len_ * _REFERENCE_BYTES + remaining_graph_memory_bytes = read_context.remaining_graph_memory_bytes + if graph_bytes > remaining_graph_memory_bytes: + read_context.reserve_graph_memory_fast(graph_bytes) + else: + read_context.remaining_graph_memory_bytes = ( + remaining_graph_memory_bytes - graph_bytes + ) if len_ == 0: instance = set() read_context.reference(instance) return instance - if len_ < 0: - raise ValueError("Container element count is negative") - else: - container_bytes = len_ * sizeof(PyObject*) - remaining_container_memory_bytes = read_context.remaining_container_memory_bytes - if container_bytes > remaining_container_memory_bytes: - read_context.reserve_container_memory_fast(container_bytes) - else: - read_context.remaining_container_memory_bytes = ( - remaining_container_memory_bytes - container_bytes - ) read_context.check_readable_bytes(len_) instance = set() read_context.reference(instance) @@ -1090,21 +1088,21 @@ cdef class MapSerializer(Serializer): cdef int32_t ref_id cdef dict map_ cdef int8_t chunk_header = 0 - cdef int64_t container_bytes - cdef int64_t remaining_container_memory_bytes + cdef int64_t graph_bytes + cdef int64_t remaining_graph_memory_bytes + if size < 0: + raise ValueError("Map entry count is negative") + graph_bytes = _OWNER_BYTES + size * (2 * _REFERENCE_BYTES) + remaining_graph_memory_bytes = read_context.remaining_graph_memory_bytes + if graph_bytes > remaining_graph_memory_bytes: + read_context.reserve_graph_memory_fast(graph_bytes) + else: + read_context.remaining_graph_memory_bytes = ( + remaining_graph_memory_bytes - graph_bytes + ) if size == 0: map_ = {} - elif size < 0: - raise ValueError("Map entry count is negative") else: - container_bytes = size * (2 * sizeof(PyObject*)) - remaining_container_memory_bytes = read_context.remaining_container_memory_bytes - if container_bytes > remaining_container_memory_bytes: - read_context.reserve_container_memory_fast(container_bytes) - else: - read_context.remaining_container_memory_bytes = ( - remaining_container_memory_bytes - container_bytes - ) read_context.check_readable_bytes(size) chunk_header = read_context.read_uint8() map_ = _PyDict_NewPresized(size) diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py index 938c3663e8..b0021d74d8 100644 --- a/python/pyfory/collection.py +++ b/python/pyfory/collection.py @@ -36,6 +36,7 @@ COLL_IS_DECL_ELEMENT_TYPE = 0b100 COLL_IS_SAME_TYPE = 0b1000 _REFERENCE_BYTES = struct.calcsize("P") +_OWNER_BYTES = 1 def _needs_element_type_info(type_id): @@ -179,7 +180,7 @@ def _write_different_types(self, write_context, value, collect_flag=0): def read(self, read_context): length = read_context.read_var_uint32() - read_context.reserve_container_memory(length * _REFERENCE_BYTES) + read_context.reserve_graph_memory(_OWNER_BYTES + length * _REFERENCE_BYTES) if length != 0: read_context.check_readable_bytes(length) collection_ = self.new_instance(read_context, self.type_) @@ -461,7 +462,7 @@ def write(self, write_context, obj): def read(self, read_context): size = read_context.read_var_uint32() - read_context.reserve_container_memory(size * 2 * _REFERENCE_BYTES) + read_context.reserve_graph_memory(_OWNER_BYTES + size * 2 * _REFERENCE_BYTES) if size != 0: read_context.check_readable_bytes(size) map_ = {} diff --git a/python/pyfory/context.pxi b/python/pyfory/context.pxi index ed45820ec1..10ac94bf45 100644 --- a/python/pyfory/context.pxi +++ b/python/pyfory/context.pxi @@ -33,7 +33,7 @@ cdef int32_t MAX_CACHED_META_STRING_LENGTH = 2048 cdef int64_t _KNOWN_ROOT_BUDGET_MULTIPLIER = 8 cdef int64_t _KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 cdef int64_t _STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 -cdef int64_t _MAX_CONTAINER_MEMORY_BYTES = 9223372036854775807 +cdef int64_t _MAX_GRAPH_MEMORY_BYTES = 9223372036854775807 cdef inline uint64_t _mix64(uint64_t x): @@ -750,9 +750,9 @@ cdef class ReadContext: cdef readonly bint field_nullable cdef readonly object policy cdef readonly int32_t max_depth - cdef public int64_t max_container_memory_bytes - cdef public int64_t container_memory_limit_bytes - cdef public int64_t remaining_container_memory_bytes + cdef public int64_t max_graph_memory_bytes + cdef public int64_t graph_memory_limit_bytes + cdef public int64_t remaining_graph_memory_bytes cdef readonly RefReader ref_reader cdef readonly MetaStringReader meta_string_reader cdef readonly MetaShareReadContext meta_share_context @@ -773,9 +773,9 @@ cdef class ReadContext: self.field_nullable = config.field_nullable self.policy = config.policy self.max_depth = config.max_depth - self.max_container_memory_bytes = config.max_container_memory_bytes - self.container_memory_limit_bytes = 0 - self.remaining_container_memory_bytes = 0 + self.max_graph_memory_bytes = config.max_graph_memory_bytes + self.graph_memory_limit_bytes = 0 + self.remaining_graph_memory_bytes = 0 self.ref_reader = RefReader(self.track_ref) self.meta_string_reader = MetaStringReader(self.type_resolver.shared_registry) self.meta_share_context = MetaShareReadContext() if config.scoped_meta_share_enabled else None @@ -796,23 +796,23 @@ cdef class ReadContext: int64_t root_input_bytes=-1, ): cdef int64_t limit - if self.max_container_memory_bytes > 0: - limit = self.max_container_memory_bytes + if self.max_graph_memory_bytes > 0: + limit = self.max_graph_memory_bytes elif buffer.has_input_stream(): limit = _STREAM_ROOT_BUDGET_BYTES else: if root_input_bytes < 0: root_input_bytes = buffer.size() - buffer.get_reader_index() - if root_input_bytes > (_MAX_CONTAINER_MEMORY_BYTES - _KNOWN_ROOT_BUDGET_SLACK_BYTES) // _KNOWN_ROOT_BUDGET_MULTIPLIER: - raise ValueError("max_container_memory_bytes auto budget overflow") + if root_input_bytes > (_MAX_GRAPH_MEMORY_BYTES - _KNOWN_ROOT_BUDGET_SLACK_BYTES) // _KNOWN_ROOT_BUDGET_MULTIPLIER: + raise ValueError("max_graph_memory_bytes auto budget overflow") limit = root_input_bytes * _KNOWN_ROOT_BUDGET_MULTIPLIER + _KNOWN_ROOT_BUDGET_SLACK_BYTES self.buffer = buffer self.c_buffer = buffer.c_buffer self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled - self.container_memory_limit_bytes = limit - self.remaining_container_memory_bytes = limit + self.graph_memory_limit_bytes = limit + self.remaining_graph_memory_bytes = limit self.depth = 0 cpdef inline reset(self): @@ -827,52 +827,52 @@ cdef class ReadContext: self.buffers = None self.unsupported_objects = None self.peer_out_of_band_enabled = False - self.container_memory_limit_bytes = 0 - self.remaining_container_memory_bytes = 0 + self.graph_memory_limit_bytes = 0 + self.remaining_graph_memory_bytes = 0 self.depth = 0 - cdef inline void reserve_container_memory_c(self, int64_t num_bytes): + cdef inline void reserve_graph_memory_c(self, int64_t num_bytes): cdef int64_t used if num_bytes < 0: - raise ValueError("Estimated container memory is negative") - if num_bytes > _MAX_CONTAINER_MEMORY_BYTES: - raise ValueError("Estimated container memory overflow") - if num_bytes > self.remaining_container_memory_bytes: - used = self.container_memory_limit_bytes - self.remaining_container_memory_bytes + raise ValueError("Estimated graph memory is negative") + if num_bytes > _MAX_GRAPH_MEMORY_BYTES: + raise ValueError("Estimated graph memory overflow") + if num_bytes > self.remaining_graph_memory_bytes: + used = self.graph_memory_limit_bytes - self.remaining_graph_memory_bytes raise ValueError( - f"Estimated container memory budget exceeded: requested {num_bytes} bytes, " - f"used {used} bytes, limit {self.container_memory_limit_bytes} bytes. " - "Increase Fory(..., max_container_memory_bytes=...) for trusted larger payloads." + f"Estimated graph memory budget exceeded: requested {num_bytes} bytes, " + f"used {used} bytes, limit {self.graph_memory_limit_bytes} bytes. " + "Increase Fory(..., max_graph_memory_bytes=...) for trusted larger payloads." ) - self.remaining_container_memory_bytes -= num_bytes + self.remaining_graph_memory_bytes -= num_bytes - cdef inline void reserve_container_memory_fast(self, int64_t num_bytes): + cdef inline void reserve_graph_memory_fast(self, int64_t num_bytes): cdef int64_t used - if num_bytes > self.remaining_container_memory_bytes: - used = self.container_memory_limit_bytes - self.remaining_container_memory_bytes + if num_bytes > self.remaining_graph_memory_bytes: + used = self.graph_memory_limit_bytes - self.remaining_graph_memory_bytes raise ValueError( - f"Estimated container memory budget exceeded: requested {num_bytes} bytes, " - f"used {used} bytes, limit {self.container_memory_limit_bytes} bytes. " - "Increase Fory(..., max_container_memory_bytes=...) for trusted larger payloads." + f"Estimated graph memory budget exceeded: requested {num_bytes} bytes, " + f"used {used} bytes, limit {self.graph_memory_limit_bytes} bytes. " + "Increase Fory(..., max_graph_memory_bytes=...) for trusted larger payloads." ) - self.remaining_container_memory_bytes -= num_bytes + self.remaining_graph_memory_bytes -= num_bytes - cpdef inline reserve_container_memory(self, int64_t num_bytes): - self.reserve_container_memory_c(num_bytes) + cpdef inline reserve_graph_memory(self, int64_t num_bytes): + self.reserve_graph_memory_c(num_bytes) - cdef inline void reserve_counted_container_memory_c( + cdef inline void reserve_counted_graph_memory_c( self, int64_t count, int64_t element_bytes, ): if count < 0 or element_bytes < 0: - raise ValueError("Estimated container memory is negative") - if element_bytes != 0 and count > _MAX_CONTAINER_MEMORY_BYTES // element_bytes: - raise ValueError("Estimated container memory overflow") - self.reserve_container_memory_c(count * element_bytes) + raise ValueError("Estimated graph memory is negative") + if element_bytes != 0 and count > _MAX_GRAPH_MEMORY_BYTES // element_bytes: + raise ValueError("Estimated graph memory overflow") + self.reserve_graph_memory_c(count * element_bytes) - cpdef inline reserve_counted_container_memory(self, int64_t count, int64_t element_bytes): - self.reserve_counted_container_memory_c(count, element_bytes) + cpdef inline reserve_counted_graph_memory(self, int64_t count, int64_t element_bytes): + self.reserve_counted_graph_memory_c(count, element_bytes) cpdef inline add_context_object(self, key, obj): self.context_objects[id(key)] = obj diff --git a/python/pyfory/context.py b/python/pyfory/context.py index f53384e31c..62ac7810a9 100644 --- a/python/pyfory/context.py +++ b/python/pyfory/context.py @@ -40,7 +40,7 @@ _KNOWN_ROOT_BUDGET_MULTIPLIER = 8 _KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 _STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 -_MAX_CONTAINER_MEMORY_BYTES = (1 << 63) - 1 +_MAX_GRAPH_MEMORY_BYTES = (1 << 63) - 1 def _mix64(x: int) -> int: @@ -474,9 +474,9 @@ class ReadContext: "field_nullable", "policy", "max_depth", - "max_container_memory_bytes", - "container_memory_limit_bytes", - "remaining_container_memory_bytes", + "max_graph_memory_bytes", + "graph_memory_limit_bytes", + "remaining_graph_memory_bytes", "ref_reader", "meta_string_reader", "meta_share_context", @@ -497,9 +497,9 @@ def __init__(self, config: Config, type_resolver): self.field_nullable = config.field_nullable self.policy = config.policy self.max_depth = config.max_depth - self.max_container_memory_bytes = config.max_container_memory_bytes - self.container_memory_limit_bytes = 0 - self.remaining_container_memory_bytes = 0 + self.max_graph_memory_bytes = config.max_graph_memory_bytes + self.graph_memory_limit_bytes = 0 + self.remaining_graph_memory_bytes = 0 self.ref_reader = MapRefReader() if self.track_ref else NoRefReader() self.meta_string_reader = MetaStringReader(type_resolver.shared_registry) self.meta_share_context = MetaShareReadContext() if config.scoped_meta_share_enabled else None @@ -531,8 +531,8 @@ def prepare( peer_out_of_band_enabled=False, root_input_bytes=None, ): - if self.max_container_memory_bytes > 0: - limit = self.max_container_memory_bytes + if self.max_graph_memory_bytes > 0: + limit = self.max_graph_memory_bytes elif buffer.has_input_stream(): limit = _STREAM_ROOT_BUDGET_BYTES else: @@ -540,15 +540,15 @@ def prepare( root_input_bytes = buffer.size() - buffer.get_reader_index() if root_input_bytes < 0: raise ValueError("root input byte count is negative") - if root_input_bytes > (_MAX_CONTAINER_MEMORY_BYTES - _KNOWN_ROOT_BUDGET_SLACK_BYTES) // _KNOWN_ROOT_BUDGET_MULTIPLIER: - raise ValueError("max_container_memory_bytes auto budget overflow") + if root_input_bytes > (_MAX_GRAPH_MEMORY_BYTES - _KNOWN_ROOT_BUDGET_SLACK_BYTES) // _KNOWN_ROOT_BUDGET_MULTIPLIER: + raise ValueError("max_graph_memory_bytes auto budget overflow") limit = root_input_bytes * _KNOWN_ROOT_BUDGET_MULTIPLIER + _KNOWN_ROOT_BUDGET_SLACK_BYTES self.buffer = buffer self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled - self.container_memory_limit_bytes = limit - self.remaining_container_memory_bytes = limit + self.graph_memory_limit_bytes = limit + self.remaining_graph_memory_bytes = limit self.depth = 0 def reset(self): @@ -562,31 +562,31 @@ def reset(self): self.buffers = None self.unsupported_objects = None self.peer_out_of_band_enabled = False - self.container_memory_limit_bytes = 0 - self.remaining_container_memory_bytes = 0 + self.graph_memory_limit_bytes = 0 + self.remaining_graph_memory_bytes = 0 self.depth = 0 - def reserve_container_memory(self, num_bytes): + def reserve_graph_memory(self, num_bytes): if num_bytes < 0: - raise ValueError("Estimated container memory is negative") - if num_bytes > _MAX_CONTAINER_MEMORY_BYTES: - raise ValueError("Estimated container memory overflow") - remaining = self.remaining_container_memory_bytes + raise ValueError("Estimated graph memory is negative") + if num_bytes > _MAX_GRAPH_MEMORY_BYTES: + raise ValueError("Estimated graph memory overflow") + remaining = self.remaining_graph_memory_bytes if num_bytes > remaining: - used = self.container_memory_limit_bytes - remaining + used = self.graph_memory_limit_bytes - remaining raise ValueError( - f"Estimated container memory budget exceeded: requested {num_bytes} bytes, " - f"used {used} bytes, limit {self.container_memory_limit_bytes} bytes. " - "Increase Fory(..., max_container_memory_bytes=...) for trusted larger payloads." + f"Estimated graph memory budget exceeded: requested {num_bytes} bytes, " + f"used {used} bytes, limit {self.graph_memory_limit_bytes} bytes. " + "Increase Fory(..., max_graph_memory_bytes=...) for trusted larger payloads." ) - self.remaining_container_memory_bytes = remaining - num_bytes + self.remaining_graph_memory_bytes = remaining - num_bytes - def reserve_counted_container_memory(self, count, element_bytes): + def reserve_counted_graph_memory(self, count, element_bytes): if count < 0 or element_bytes < 0: - raise ValueError("Estimated container memory is negative") - if element_bytes and count > _MAX_CONTAINER_MEMORY_BYTES // element_bytes: - raise ValueError("Estimated container memory overflow") - self.reserve_container_memory(count * element_bytes) + raise ValueError("Estimated graph memory is negative") + if element_bytes and count > _MAX_GRAPH_MEMORY_BYTES // element_bytes: + raise ValueError("Estimated graph memory overflow") + self.reserve_graph_memory(count * element_bytes) def add_context_object(self, key, obj): self.context_objects[id(key)] = obj diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 2e4fede422..54d72cc5cb 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -113,7 +113,7 @@ cdef class Config: max_type_meta_bytes: Maximum accepted body size in one received TypeDef. max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. - max_container_memory_bytes: Maximum estimated container-owned memory per root + max_graph_memory_bytes: Maximum estimated graph memory per root deserialization. -1 means auto; positive values are explicit byte limits. field_nullable: Treats struct/dataclass fields as nullable by default. policy: Deserialization policy used for security-sensitive checks. @@ -131,7 +131,7 @@ cdef class Config: cdef public int32_t max_type_meta_bytes cdef public int32_t max_schema_versions_per_type cdef public int32_t max_average_schema_versions_per_type - cdef public int64_t max_container_memory_bytes + cdef public int64_t max_graph_memory_bytes cdef public bint field_nullable cdef public object policy cdef public object meta_compressor @@ -150,7 +150,7 @@ cdef class Config: max_type_meta_bytes, max_schema_versions_per_type, max_average_schema_versions_per_type, - max_container_memory_bytes, + max_graph_memory_bytes, field_nullable, policy, meta_compressor, @@ -170,7 +170,7 @@ cdef class Config: max_type_meta_bytes: Maximum accepted body size in one received TypeDef. max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. - max_container_memory_bytes: Maximum estimated container-owned memory per root + max_graph_memory_bytes: Maximum estimated graph memory per root deserialization. -1 means auto; positive values are explicit byte limits. field_nullable: Treat all struct fields as nullable by default. policy: Deserialization policy implementation. @@ -192,16 +192,16 @@ cdef class Config: if max_average_schema_versions_per_type <= 0: raise ValueError("max_average_schema_versions_per_type must be a positive integer") if ( - not isinstance(max_container_memory_bytes, int) - or (max_container_memory_bytes != -1 and max_container_memory_bytes <= 0) - or max_container_memory_bytes > 9223372036854775807 + not isinstance(max_graph_memory_bytes, int) + or (max_graph_memory_bytes != -1 and max_graph_memory_bytes <= 0) + or max_graph_memory_bytes > 9223372036854775807 ): - raise ValueError("max_container_memory_bytes must be -1 or a positive 63-bit integer") + raise ValueError("max_graph_memory_bytes must be -1 or a positive 63-bit integer") self.max_type_fields = max_type_fields self.max_type_meta_bytes = max_type_meta_bytes self.max_schema_versions_per_type = max_schema_versions_per_type self.max_average_schema_versions_per_type = max_average_schema_versions_per_type - self.max_container_memory_bytes = max_container_memory_bytes + self.max_graph_memory_bytes = max_graph_memory_bytes self.field_nullable = field_nullable self.policy = policy self.meta_compressor = meta_compressor @@ -842,7 +842,7 @@ cdef class Fory: cdef public bint compatible cdef public bint field_nullable cdef public int32_t max_depth - cdef public int64_t max_container_memory_bytes + cdef public int64_t max_graph_memory_bytes cdef public object policy cdef public Config config cdef public TypeResolver type_resolver @@ -861,7 +861,7 @@ cdef class Fory: max_type_meta_bytes=4096, max_schema_versions_per_type=10, max_average_schema_versions_per_type=3, - max_container_memory_bytes=-1, + max_graph_memory_bytes=-1, policy=None, field_nullable=False, meta_compressor=None, @@ -880,7 +880,7 @@ cdef class Fory: max_type_meta_bytes: Maximum accepted body size in one received TypeDef. max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. - max_container_memory_bytes: Maximum estimated container-owned memory per root + max_graph_memory_bytes: Maximum estimated graph memory per root deserialization. -1 means auto; positive values are explicit byte limits. policy: Optional deserialization policy implementation. field_nullable: Treat struct fields as nullable by default. @@ -900,12 +900,12 @@ cdef class Fory: self.field_nullable = field_nullable self.max_depth = max_depth if ( - not isinstance(max_container_memory_bytes, int) - or (max_container_memory_bytes != -1 and max_container_memory_bytes <= 0) - or max_container_memory_bytes > 9223372036854775807 + not isinstance(max_graph_memory_bytes, int) + or (max_graph_memory_bytes != -1 and max_graph_memory_bytes <= 0) + or max_graph_memory_bytes > 9223372036854775807 ): - raise ValueError("max_container_memory_bytes must be -1 or a positive 63-bit integer") - self.max_container_memory_bytes = max_container_memory_bytes + raise ValueError("max_graph_memory_bytes must be -1 or a positive 63-bit integer") + self.max_graph_memory_bytes = max_graph_memory_bytes self.config = Config( xlang=xlang, track_ref=ref, @@ -918,7 +918,7 @@ cdef class Fory: max_type_meta_bytes=max_type_meta_bytes, max_schema_versions_per_type=max_schema_versions_per_type, max_average_schema_versions_per_type=max_average_schema_versions_per_type, - max_container_memory_bytes=max_container_memory_bytes, + max_graph_memory_bytes=max_graph_memory_bytes, field_nullable=field_nullable, policy=self.policy, meta_compressor=meta_compressor, @@ -1077,7 +1077,7 @@ cdef class Fory: cdef uint8_t bitmap cdef bint peer_out_of_band_enabled cdef int64_t root_input_bytes - cdef int64_t container_memory_limit + cdef int64_t graph_memory_limit if isinstance(buffer, bytes): buffer = Buffer(buffer) read_buffer = buffer @@ -1093,13 +1093,13 @@ cdef class Fory: raise ValueError("Out-of-band buffers are required by the root header") if not peer_out_of_band_enabled and buffers is not None: raise ValueError("Out-of-band buffers were provided for an in-band root payload") - if self.max_container_memory_bytes > 0: - container_memory_limit = self.max_container_memory_bytes + if self.max_graph_memory_bytes > 0: + graph_memory_limit = self.max_graph_memory_bytes elif read_buffer.has_input_stream(): - container_memory_limit = _STREAM_ROOT_BUDGET_BYTES + graph_memory_limit = _STREAM_ROOT_BUDGET_BYTES else: root_input_bytes = read_buffer.size() - reader_index - container_memory_limit = root_input_bytes * _KNOWN_ROOT_BUDGET_MULTIPLIER + _KNOWN_ROOT_BUDGET_SLACK_BYTES + graph_memory_limit = root_input_bytes * _KNOWN_ROOT_BUDGET_MULTIPLIER + _KNOWN_ROOT_BUDGET_SLACK_BYTES # Keep the root context setup inline. Top-level deserialize is a hot path, # so it should not pay an extra method call just to bind the active buffer. read_context.buffer = read_buffer @@ -1109,8 +1109,8 @@ cdef class Fory: iter(unsupported_objects) if unsupported_objects is not None else None ) read_context.peer_out_of_band_enabled = peer_out_of_band_enabled - read_context.container_memory_limit_bytes = container_memory_limit - read_context.remaining_container_memory_bytes = container_memory_limit + read_context.graph_memory_limit_bytes = graph_memory_limit + read_context.remaining_graph_memory_bytes = graph_memory_limit read_context.depth = 0 return read_context.read_ref() diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index 17ce24063f..5145b1e335 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -44,6 +44,7 @@ _WINDOWS = os.name == "nt" _REFERENCE_BYTES = struct.calcsize("P") +_OWNER_BYTES = 1 from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION from pyfory.types import TypeId @@ -935,7 +936,7 @@ def read(self, read_context): if dtype.kind == "O": length = read_context.read_varint32() _check_non_negative_size(length, "ndarray object") - read_context.reserve_container_memory(length * _REFERENCE_BYTES) + read_context.reserve_graph_memory(_OWNER_BYTES + length * _REFERENCE_BYTES) read_context.check_readable_bytes(length) items = [read_context.read_ref() for _ in range(length)] return np.array(items, dtype=object) @@ -1716,10 +1717,11 @@ def write(self, write_context, value): def read(self, read_context): policy = read_context.policy policy.authorize_instantiation(self.type_) - obj = self.type_.__new__(self.type_) - read_context.reference(obj) num_fields = read_context.read_var_uint32() _check_non_negative_size(num_fields, "object field") + read_context.reserve_graph_memory(_OWNER_BYTES + num_fields * _REFERENCE_BYTES) + obj = self.type_.__new__(self.type_) + read_context.reference(obj) state = {} for _ in range(num_fields): field_name = read_context.read_string() @@ -1733,10 +1735,11 @@ def read(self, read_context): class _DefaultPolicyObjectSerializer(ObjectSerializer): def read(self, read_context): - obj = self.type_.__new__(self.type_) - read_context.reference(obj) num_fields = read_context.read_var_uint32() _check_non_negative_size(num_fields, "object field") + read_context.reserve_graph_memory(_OWNER_BYTES + num_fields * _REFERENCE_BYTES) + obj = self.type_.__new__(self.type_) + read_context.reference(obj) for _ in range(num_fields): field_name = read_context.read_string() field_value = read_context.read_ref() diff --git a/python/pyfory/struct.pxi b/python/pyfory/struct.pxi index 96b53d0fa9..bb813b0a57 100644 --- a/python/pyfory/struct.pxi +++ b/python/pyfory/struct.pxi @@ -22,6 +22,8 @@ from cpython.unicode cimport PyUnicode_InternFromString cdef uint8_t _BASIC_FIELD_NOT_INLINE = 0xFF +cdef int64_t _STRUCT_OWNER_BYTES = 1 +cdef int64_t _STRUCT_REFERENCE_BYTES = sizeof(PyObject*) cdef struct FieldRuntimeInfo: @@ -422,6 +424,9 @@ cdef class DataClassSerializer(Serializer): f"Hash {read_hash} is not consistent with {self._hash} for type {self.type_}" ) + read_context.reserve_graph_memory_fast( + _STRUCT_OWNER_BYTES + self._field_runtime_infos.size() * _STRUCT_REFERENCE_BYTES + ) obj = self.type_.__new__(self.type_) read_context.reference(obj) if self._has_slots: diff --git a/python/pyfory/struct.py b/python/pyfory/struct.py index a8daf46687..478cda5b43 100644 --- a/python/pyfory/struct.py +++ b/python/pyfory/struct.py @@ -24,6 +24,7 @@ import inspect import logging import os +import struct import sys import typing from typing import List, Dict @@ -80,6 +81,9 @@ logger = logging.getLogger(__name__) +_REFERENCE_BYTES = struct.calcsize("P") +_OWNER_BYTES = 1 + _MISSING_DEFAULT_INT_TYPES = { int, TypeId.INT8, @@ -651,6 +655,7 @@ def read(self, read_context): raise TypeNotCompatibleError( f"Hash {hash_} is not consistent with {self._hash} for type {self.type_}", ) + read_context.reserve_graph_memory(_OWNER_BYTES + len(self._field_names) * _REFERENCE_BYTES) obj = self.type_.__new__(self.type_) read_context.reference(obj) obj_dict = obj.__dict__ if not self._has_slots else None diff --git a/python/pyfory/tests/test_container_memory_budget.py b/python/pyfory/tests/test_graph_memory_budget.py similarity index 73% rename from python/pyfory/tests/test_container_memory_budget.py rename to python/pyfory/tests/test_graph_memory_budget.py index ae2bc15288..d777b719bf 100644 --- a/python/pyfory/tests/test_container_memory_budget.py +++ b/python/pyfory/tests/test_graph_memory_budget.py @@ -16,6 +16,7 @@ # under the License. import array +import dataclasses import struct import pytest @@ -34,7 +35,8 @@ KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 REFERENCE_BYTES = struct.calcsize("P") -MAX_CONTAINER_MEMORY_BYTES = (1 << 63) - 1 +OWNER_BYTES = 1 +MAX_GRAPH_MEMORY_BYTES = (1 << 63) - 1 class OneByteStream: @@ -81,12 +83,25 @@ def recv_into(self, buffer, size=-1): return read_size +@dataclasses.dataclass +class BudgetItem: + value: int + + +class BudgetObject: + pass + + def collection_memory(num_elements): - return num_elements * REFERENCE_BYTES + return OWNER_BYTES + num_elements * REFERENCE_BYTES def map_memory(num_entries): - return num_entries * 2 * REFERENCE_BYTES + return OWNER_BYTES + num_entries * 2 * REFERENCE_BYTES + + +def object_memory(num_fields): + return OWNER_BYTES + num_fields * REFERENCE_BYTES def new_fory(limit=-1, *, xlang=True): @@ -95,14 +110,14 @@ def new_fory(limit=-1, *, xlang=True): ref=True, strict=False, compatible=xlang, - max_container_memory_bytes=limit, + max_graph_memory_bytes=limit, ) def expect_budget(value, budget, *, xlang=True): writer = new_fory(xlang=xlang) data = writer.serialize(value) - with pytest.raises(ValueError, match="Estimated container memory budget exceeded"): + with pytest.raises(ValueError, match="Estimated graph memory budget exceeded"): new_fory(budget - 1, xlang=xlang).deserialize(data) return new_fory(budget, xlang=xlang).deserialize(data) @@ -119,10 +134,10 @@ def test_known_length_auto_budget(): try: fory.read_context.prepare(Buffer(b"x" * root_input_bytes), root_input_bytes=root_input_bytes) expected = root_input_bytes * KNOWN_ROOT_BUDGET_MULTIPLIER + KNOWN_ROOT_BUDGET_SLACK_BYTES - assert fory.read_context.container_memory_limit_bytes == expected - fory.read_context.reserve_container_memory(expected) - with pytest.raises(ValueError, match="Estimated container memory budget exceeded"): - fory.read_context.reserve_container_memory(1) + assert fory.read_context.graph_memory_limit_bytes == expected + fory.read_context.reserve_graph_memory(expected) + with pytest.raises(ValueError, match="Estimated graph memory budget exceeded"): + fory.read_context.reserve_graph_memory(1) finally: fory.reset_read() @@ -132,7 +147,7 @@ def test_stream_auto_budget(): try: buffer = Buffer.from_stream(OneByteStream(b"streamed")) fory.read_context.prepare(buffer, root_input_bytes=1) - assert fory.read_context.container_memory_limit_bytes == STREAM_ROOT_BUDGET_BYTES + assert fory.read_context.graph_memory_limit_bytes == STREAM_ROOT_BUDGET_BYTES finally: fory.reset_read() @@ -145,16 +160,41 @@ def test_explicit_config_overrides_auto(): def test_nested_empty_containers_use_parent_storage(): value = [[]] - budget = collection_memory(1) + budget = collection_memory(1) + collection_memory(0) assert expect_budget(value, budget) == value def test_sibling_nested_containers_are_cumulative(): value = [[], [], []] - budget = collection_memory(3) + budget = collection_memory(3) + 3 * collection_memory(0) assert expect_budget(value, budget) == value +def test_empty_object_owner_is_charged(): + fory = new_fory(xlang=False) + fory.register_type(BudgetItem) + value = BudgetItem(1) + budget = object_memory(1) + data = fory.serialize(value) + with pytest.raises(ValueError, match="Estimated graph memory budget exceeded"): + reader = new_fory(budget - 1, xlang=False) + reader.register_type(BudgetItem) + reader.deserialize(data) + reader = new_fory(budget, xlang=False) + reader.register_type(BudgetItem) + assert reader.deserialize(data) == value + + +def test_dynamic_object_owner_is_charged(): + value = BudgetObject() + value.left = 1 + value.right = "x" + budget = object_memory(2) + restored = expect_budget(value, budget, xlang=False) + assert restored.left == value.left + assert restored.right == value.right + + def test_map_entry_budget_and_overflow(): value = {"a": 1} assert expect_budget(value, map_memory(1)) == value @@ -162,9 +202,9 @@ def test_map_entry_budget_and_overflow(): fory = new_fory(xlang=False) try: fory.read_context.prepare(Buffer(b""), root_input_bytes=0) - max_map_entries = MAX_CONTAINER_MEMORY_BYTES // (2 * REFERENCE_BYTES) - with pytest.raises(ValueError, match="Estimated container memory overflow"): - fory.read_context.reserve_counted_container_memory(max_map_entries + 1, 2 * REFERENCE_BYTES) + max_map_entries = MAX_GRAPH_MEMORY_BYTES // (2 * REFERENCE_BYTES) + with pytest.raises(ValueError, match="Estimated graph memory overflow"): + fory.read_context.reserve_counted_graph_memory(max_map_entries + 1, 2 * REFERENCE_BYTES) finally: fory.reset_read() @@ -206,12 +246,12 @@ def test_declared_large_list_still_needs_bytes(): fory.read_context.prepare(Buffer(varuint_payload(1000)), root_input_bytes=1) with pytest.raises(Exception) as exc_info: serializer.read(fory.read_context) - assert "Estimated container memory" not in str(exc_info.value) + assert "Estimated graph memory" not in str(exc_info.value) finally: fory.reset_read() @pytest.mark.parametrize("limit", [0, -2, 1 << 63]) def test_invalid_config(limit): - with pytest.raises(ValueError, match="max_container_memory_bytes"): + with pytest.raises(ValueError, match="max_graph_memory_bytes"): new_fory(limit) diff --git a/rust/fory-core/src/config.rs b/rust/fory-core/src/config.rs index f5d003cd0f..64c76d98f0 100644 --- a/rust/fory-core/src/config.rs +++ b/rust/fory-core/src/config.rs @@ -40,9 +40,9 @@ pub struct Config { /// When enabled, shared references and circular references are tracked /// and preserved during serialization/deserialization. pub track_ref: bool, - /// Maximum estimated container-owned memory accepted during one root deserialization. + /// Maximum estimated graph memory accepted during one root deserialization. /// `-1` selects the automatic input-shaped limit. - pub max_container_memory_bytes: i64, + pub max_graph_memory_bytes: i64, /// Maximum accepted field count in one received struct TypeMeta. pub max_type_fields: u32, /// Maximum accepted body size in one received TypeMeta. @@ -64,7 +64,7 @@ impl Default for Config { max_dyn_depth: 5, check_struct_version: false, track_ref: false, - max_container_memory_bytes: -1, + max_graph_memory_bytes: -1, max_type_fields: 512, max_type_meta_bytes: 4096, max_schema_versions_per_type: 10, @@ -127,10 +127,10 @@ impl Config { self.track_ref } - /// Get maximum estimated container-owned memory per root deserialization. + /// Get maximum estimated graph memory per root deserialization. #[inline(always)] - pub fn max_container_memory_bytes(&self) -> i64 { - self.max_container_memory_bytes + pub fn max_graph_memory_bytes(&self) -> i64 { + self.max_graph_memory_bytes } /// Get maximum accepted field count in one received struct TypeMeta. diff --git a/rust/fory-core/src/context.rs b/rust/fory-core/src/context.rs index aabfc6c8a3..549370b5bf 100644 --- a/rust/fory-core/src/context.rs +++ b/rust/fory-core/src/context.rs @@ -33,7 +33,7 @@ use std::rc::Rc; const KNOWN_ROOT_BUDGET_MULTIPLIER: usize = 8; const KNOWN_ROOT_BUDGET_SLACK_BYTES: usize = 64 * 1024; -const MAX_CONTAINER_LEN: usize = u32::MAX as usize; +const MAX_GRAPH_COUNT: usize = u32::MAX as usize; /// Thread-local context cache with fast path for single Fory instance. /// Uses (cached_id, context) for O(1) access when using same Fory instance repeatedly. @@ -363,9 +363,9 @@ pub struct ReadContext<'a> { max_dyn_depth: u32, check_struct_version: bool, check_string_read: bool, - max_container_memory_bytes: i64, - container_memory_limit_bytes: usize, - remaining_container_memory_bytes: usize, + max_graph_memory_bytes: i64, + graph_memory_limit_bytes: usize, + remaining_graph_memory_bytes: usize, // Context-specific fields pub reader: Reader<'a>, @@ -395,9 +395,9 @@ impl<'a> ReadContext<'a> { max_dyn_depth: config.max_dyn_depth, check_struct_version: config.check_struct_version, check_string_read: config.check_string_read, - max_container_memory_bytes: config.max_container_memory_bytes, - container_memory_limit_bytes: 0, - remaining_container_memory_bytes: 0, + max_graph_memory_bytes: config.max_graph_memory_bytes, + graph_memory_limit_bytes: 0, + remaining_graph_memory_bytes: 0, reader: Reader::default(), meta_resolver: MetaReaderResolver::default(), meta_string_resolver: MetaStringReaderResolver::default(), @@ -454,77 +454,77 @@ impl<'a> ReadContext<'a> { } #[inline(always)] - pub(crate) fn init_container_memory_budget( + pub(crate) fn init_graph_memory_budget( &mut self, root_input_bytes: usize, ) -> Result<(), Error> { - let limit = if self.max_container_memory_bytes > 0 { - usize::try_from(self.max_container_memory_bytes).map_err(|_| { - container_memory_error("max_container_memory_bytes does not fit usize") - })? + let limit = if self.max_graph_memory_bytes > 0 { + usize::try_from(self.max_graph_memory_bytes) + .map_err(|_| graph_memory_error("max_graph_memory_bytes does not fit usize"))? } else { if root_input_bytes > (usize::MAX - KNOWN_ROOT_BUDGET_SLACK_BYTES) / KNOWN_ROOT_BUDGET_MULTIPLIER { - return Err(container_memory_error( - "root input size overflows automatic container memory budget", + return Err(graph_memory_error( + "root input size overflows automatic graph memory budget", )); } root_input_bytes * KNOWN_ROOT_BUDGET_MULTIPLIER + KNOWN_ROOT_BUDGET_SLACK_BYTES }; - self.container_memory_limit_bytes = limit; - self.remaining_container_memory_bytes = limit; + self.graph_memory_limit_bytes = limit; + self.remaining_graph_memory_bytes = limit; Ok(()) } #[inline(always)] - pub(crate) fn reserve_counted_container_memory( + pub(crate) fn reserve_counted_graph_memory( &mut self, len: u32, elem_bytes: usize, ) -> Result { let len = len as usize; - self.reserve_counted_memory(len, elem_bytes)?; + self.reserve_counted_graph_bytes(len, elem_bytes)?; Ok(len) } #[inline(always)] - pub(crate) fn reserve_container_bytes(&mut self, bytes: usize) -> Result<(), Error> { - let remaining = self.remaining_container_memory_bytes; + #[doc(hidden)] + pub fn reserve_graph_memory(&mut self, bytes: usize) -> Result<(), Error> { + let remaining = self.remaining_graph_memory_bytes; if bytes > remaining { - return Err(container_memory_exceeded( + return Err(graph_memory_exceeded( bytes, remaining, - self.container_memory_limit_bytes, + self.graph_memory_limit_bytes, )); } - self.remaining_container_memory_bytes = remaining - bytes; + self.remaining_graph_memory_bytes = remaining - bytes; Ok(()) } #[inline(always)] - fn reserve_counted_memory(&mut self, len: usize, elem_bytes: usize) -> Result<(), Error> { + fn reserve_counted_graph_bytes(&mut self, len: usize, elem_bytes: usize) -> Result<(), Error> { if len == 0 { return Ok(()); } - if elem_bytes <= usize::MAX / MAX_CONTAINER_LEN { - return self.reserve_container_bytes(len * elem_bytes); + if elem_bytes <= usize::MAX / MAX_GRAPH_COUNT { + return self.reserve_graph_memory(len * elem_bytes); } - self.reserve_counted_memory_checked(len, elem_bytes) + self.reserve_counted_graph_checked(len, elem_bytes) } #[cold] #[inline(never)] - fn reserve_counted_memory_checked( + fn reserve_counted_graph_checked( &mut self, len: usize, elem_bytes: usize, ) -> Result<(), Error> { let bytes = match len.checked_mul(elem_bytes) { Some(bytes) => bytes, - None => return Err(container_memory_overflow(len, elem_bytes)), + None => return Err(graph_memory_overflow(len, elem_bytes)), }; - self.reserve_container_bytes(bytes) + self.reserve_graph_memory(bytes) } #[inline(always)] @@ -639,24 +639,24 @@ impl<'a> ReadContext<'a> { #[cold] #[inline(never)] -fn container_memory_error(message: &'static str) -> Error { +fn graph_memory_error(message: &'static str) -> Error { Error::invalid_data(message) } #[cold] #[inline(never)] -fn container_memory_overflow(len: usize, elem_bytes: usize) -> Error { +fn graph_memory_overflow(len: usize, elem_bytes: usize) -> Error { Error::invalid_data(format!( - "container memory estimate overflows: length={} elementBytes={}", + "graph memory estimate overflows: length={} elementBytes={}", len, elem_bytes )) } #[cold] #[inline(never)] -fn container_memory_exceeded(bytes: usize, remaining: usize, limit: usize) -> Error { +fn graph_memory_exceeded(bytes: usize, remaining: usize, limit: usize) -> Error { Error::invalid_data(format!( - "estimated container memory request {} bytes exceeds max_container_memory_bytes remaining budget {} bytes out of effective limit {} bytes", + "estimated graph memory request {} bytes exceeds max_graph_memory_bytes remaining budget {} bytes out of effective limit {} bytes", bytes, remaining, limit )) } diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index 8eb9d3794f..fd3bfdb827 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -261,15 +261,15 @@ impl ForyBuilder { self } - /// Sets the maximum estimated container-owned memory accepted during one root deserialization. + /// Sets the maximum estimated graph memory accepted during one root deserialization. /// /// Use `-1` for the automatic input-shaped limit. Positive values are explicit byte limits. - pub fn max_container_memory_bytes(mut self, max_bytes: i64) -> Self { + pub fn max_graph_memory_bytes(mut self, max_bytes: i64) -> Self { assert!( max_bytes == -1 || max_bytes > 0, - "max_container_memory_bytes must be positive or -1 for auto" + "max_graph_memory_bytes must be positive or -1 for auto" ); - self.config.max_container_memory_bytes = max_bytes; + self.config.max_graph_memory_bytes = max_bytes; self } @@ -1000,7 +1000,7 @@ impl Fory { self.with_read_context(|context| { let outlive_buffer = unsafe { mem::transmute::<&[u8], &[u8]>(bf) }; context.attach_reader(Reader::new(outlive_buffer)); - let result = match context.init_container_memory_budget(bf.len()) { + let result = match context.init_graph_memory_budget(bf.len()) { Ok(()) => self.deserialize_with_context(context), Err(err) => { context.reset(); @@ -1070,7 +1070,7 @@ impl Fory { new_reader.set_cursor(reader.cursor); let root_input_bytes = reader.bf.len().saturating_sub(reader.cursor); context.attach_reader(new_reader); - let result = match context.init_container_memory_budget(root_input_bytes) { + let result = match context.init_graph_memory_budget(root_input_bytes) { Ok(()) => self.deserialize_with_context(context), Err(err) => { context.reset(); @@ -1135,6 +1135,10 @@ impl Fory { RefMode::NullOnly }; // TypeMeta is read inline during deserialization (streaming protocol) + let root_graph_self_size = T::fory_graph_self_size(); + if root_graph_self_size != 0 { + context.reserve_graph_memory(root_graph_self_size)?; + } let result = ::fory_read(context, ref_mode, true); context.ref_reader.resolve_callbacks(); result diff --git a/rust/fory-core/src/resolver/type_resolver.rs b/rust/fory-core/src/resolver/type_resolver.rs index 41e6c53262..4813180594 100644 --- a/rust/fory-core/src/resolver/type_resolver.rs +++ b/rust/fory-core/src/resolver/type_resolver.rs @@ -949,6 +949,10 @@ impl TypeResolver { ref_mode: RefMode, read_type_info: bool, ) -> Result, Error> { + let graph_self_size = T2::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } Ok(Box::new(T2::fory_read(context, ref_mode, read_type_info)?)) } @@ -971,6 +975,10 @@ impl TypeResolver { fn read_data( context: &mut ReadContext, ) -> Result, Error> { + let graph_self_size = T2::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } match T2::fory_read_data(context) { Ok(v) => Ok(Box::new(v)), Err(e) => Err(e), @@ -1002,6 +1010,10 @@ impl TypeResolver { context: &mut ReadContext, type_info: Rc, ) -> Result, Error> { + let graph_self_size = T2::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } Ok(Box::new(T2::fory_read_compatible(context, type_info)?)) } @@ -1192,6 +1204,10 @@ impl TypeResolver { ref_mode: RefMode, read_type_info: bool, ) -> Result, Error> { + let graph_self_size = T2::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } Ok(Box::new(T2::fory_read(context, ref_mode, read_type_info)?)) } @@ -1214,6 +1230,10 @@ impl TypeResolver { fn read_data( context: &mut ReadContext, ) -> Result, Error> { + let graph_self_size = T2::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } match T2::fory_read_data(context) { Ok(v) => Ok(Box::new(v)), Err(e) => Err(e), diff --git a/rust/fory-core/src/serializer/arc.rs b/rust/fory-core/src/serializer/arc.rs index 1e69ae15c3..25677c9c34 100644 --- a/rust/fory-core/src/serializer/arc.rs +++ b/rust/fory-core/src/serializer/arc.rs @@ -217,6 +217,10 @@ fn read_arc_inner( ) -> Result { // Read type info if needed, then read data directly // No recursive ref handling needed since Arc only wraps allowed types + let graph_self_size = T::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } if let Some(typeinfo) = typeinfo { return T::fory_read_with_type_info(context, RefMode::None, typeinfo); } diff --git a/rust/fory-core/src/serializer/array.rs b/rust/fory-core/src/serializer/array.rs index 35e6a31069..c0257beb13 100644 --- a/rust/fory-core/src/serializer/array.rs +++ b/rust/fory-core/src/serializer/array.rs @@ -269,6 +269,18 @@ impl Serializer for [T; N] { } } + #[inline(always)] + fn fory_graph_self_size() -> usize + where + Self: Sized, + { + if is_primitive_type::() { + 0 + } else { + mem::size_of::() + } + } + #[inline(always)] fn fory_static_type_id() -> TypeId where diff --git a/rust/fory-core/src/serializer/box_.rs b/rust/fory-core/src/serializer/box_.rs index da7a378eb2..1ccd22eb65 100644 --- a/rust/fory-core/src/serializer/box_.rs +++ b/rust/fory-core/src/serializer/box_.rs @@ -29,6 +29,10 @@ impl Serializer for Box { where Self: Sized + ForyDefault, { + let graph_self_size = T::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } Ok(Box::new(T::fory_read_data(context)?)) } diff --git a/rust/fory-core/src/serializer/codec.rs b/rust/fory-core/src/serializer/codec.rs index f34c5b8fd9..302019c757 100644 --- a/rust/fory-core/src/serializer/codec.rs +++ b/rust/fory-core/src/serializer/codec.rs @@ -474,6 +474,11 @@ pub trait Codec: 'static { std::mem::size_of::() } + #[inline(always)] + fn graph_storage_size() -> usize { + std::mem::size_of::() + } + fn write_field(value: &T, context: &mut WriteContext) -> Result<(), Error>; fn read_field(context: &mut ReadContext) -> Result; @@ -615,6 +620,11 @@ where T::fory_reserved_space() + SIZE_OF_REF_AND_TYPE } + #[inline(always)] + fn graph_storage_size() -> usize { + T::fory_graph_storage_size() + } + #[inline(always)] fn write_field(value: &T, context: &mut WriteContext) -> Result<(), Error> { T::fory_write( @@ -1700,7 +1710,7 @@ where fn read_data(context: &mut ReadContext) -> Result, Error> { let len = context.reader.read_var_u32()?; - context.reserve_counted_container_memory(len, std::mem::size_of::())?; + context.reserve_counted_graph_memory(len, C::graph_storage_size())?; if len == 0 { return Ok(Vec::new()); } @@ -1729,7 +1739,7 @@ where remote_field_type: &FieldType, ) -> Result, Error> { let len = context.reader.read_var_u32()?; - context.reserve_counted_container_memory(len, std::mem::size_of::())?; + context.reserve_counted_graph_memory(len, C::graph_storage_size())?; if len == 0 { return Ok(Vec::new()); } @@ -2272,10 +2282,10 @@ where fn read_data(context: &mut ReadContext) -> Result, Error> { let len = context.reader.read_var_u32()?; - let elem_bytes = std::mem::size_of::() - .checked_add(std::mem::size_of::()) - .ok_or_else(|| Error::invalid_data("container memory estimate overflows"))?; - context.reserve_counted_container_memory(len, elem_bytes)?; + let elem_bytes = KC::graph_storage_size() + .checked_add(VC::graph_storage_size()) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + context.reserve_counted_graph_memory(len, elem_bytes)?; if len == 0 { return Ok(HashMap::new()); } @@ -2295,10 +2305,10 @@ where remote_field_type: &FieldType, ) -> Result, Error> { let len = context.reader.read_var_u32()?; - let elem_bytes = std::mem::size_of::() - .checked_add(std::mem::size_of::()) - .ok_or_else(|| Error::invalid_data("container memory estimate overflows"))?; - let capacity = context.reserve_counted_container_memory(len, elem_bytes)?; + let elem_bytes = KC::graph_storage_size() + .checked_add(VC::graph_storage_size()) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + let capacity = context.reserve_counted_graph_memory(len, elem_bytes)?; if len == 0 { return Ok(HashMap::new()); } diff --git a/rust/fory-core/src/serializer/collection.rs b/rust/fory-core/src/serializer/collection.rs index cdda3506bc..4b94f8ac32 100644 --- a/rust/fory-core/src/serializer/collection.rs +++ b/rust/fory-core/src/serializer/collection.rs @@ -239,7 +239,7 @@ where C: FromIterator, { let len = context.reader.read_var_u32()?; - let len_usize = context.reserve_counted_container_memory(len, std::mem::size_of::())?; + let len_usize = context.reserve_counted_graph_memory(len, T::fory_graph_storage_size())?; if len == 0 { return Ok(C::from_iter(std::iter::empty())); } @@ -258,7 +258,9 @@ where (header & IS_SAME_TYPE) != 0, Error::type_error("Type inconsistent, target type is not polymorphic") ); - context.reader.check_bound(len_usize)?; + if std::mem::size_of::() != 0 { + context.reader.check_bound(len_usize)?; + } if !has_null { (0..len) .map(|_| T::fory_read_data(context)) @@ -282,7 +284,7 @@ where T: Serializer + ForyDefault, { let len = context.reader.read_var_u32()?; - let len_usize = context.reserve_counted_container_memory(len, std::mem::size_of::())?; + let len_usize = context.reserve_counted_graph_memory(len, T::fory_graph_storage_size())?; if len == 0 { return Ok(Vec::new()); } @@ -299,7 +301,9 @@ where (header & IS_SAME_TYPE) != 0, Error::type_error("Type inconsistent, target type is not polymorphic") ); - context.reader.check_bound(len_usize)?; + if std::mem::size_of::() != 0 { + context.reader.check_bound(len_usize)?; + } let mut vec = Vec::with_capacity(len_usize); if !has_null { for _ in 0..len { @@ -729,7 +733,7 @@ where { let element_type = generic_field_type(remote_field_type, 0, "list")?; let len = context.reader.read_var_u32()?; - let len_usize = context.reserve_counted_container_memory(len, std::mem::size_of::())?; + let len_usize = context.reserve_counted_graph_memory(len, T::fory_graph_storage_size())?; if len == 0 { return Ok(Vec::new()); } diff --git a/rust/fory-core/src/serializer/core.rs b/rust/fory-core/src/serializer/core.rs index 2f155792d1..255f9c8a86 100644 --- a/rust/fory-core/src/serializer/core.rs +++ b/rust/fory-core/src/serializer/core.rs @@ -1048,6 +1048,31 @@ pub trait Serializer: 'static { Self::fory_is_shared_ref() } + /// Shallow value-owner self storage for root and box/shared-reference allocation sites. + /// + /// Value serializers must not reserve this unconditionally because parent structs, + /// arrays, maps, and collections may already own the inline storage. + #[inline(always)] + fn fory_graph_self_size() -> usize + where + Self: Sized, + { + 0 + } + + #[inline(always)] + fn fory_graph_storage_size() -> usize + where + Self: Sized, + { + let inline_size = std::mem::size_of::(); + if inline_size == 0 { + Self::fory_graph_self_size() + } else { + inline_size + } + } + /// Get the static Fory type ID for this type. /// /// Type IDs are Fory's internal type identification system, separate from diff --git a/rust/fory-core/src/serializer/heap.rs b/rust/fory-core/src/serializer/heap.rs index 7203c9545e..4c7b9483ba 100644 --- a/rust/fory-core/src/serializer/heap.rs +++ b/rust/fory-core/src/serializer/heap.rs @@ -58,6 +58,10 @@ impl Serializer for BinaryHeap { mem::size_of::() } + fn fory_graph_self_size() -> usize { + mem::size_of::() + } + fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::SET) } diff --git a/rust/fory-core/src/serializer/list.rs b/rust/fory-core/src/serializer/list.rs index 435ec97e9e..175ac41e21 100644 --- a/rust/fory-core/src/serializer/list.rs +++ b/rust/fory-core/src/serializer/list.rs @@ -154,6 +154,15 @@ impl Serializer for Vec { } } + #[inline(always)] + fn fory_graph_self_size() -> usize { + if is_primitive_type::() { + 0 + } else { + mem::size_of::() + } + } + #[inline(always)] fn fory_get_type_id(_: &TypeResolver) -> Result { let id = get_primitive_type_id::(); @@ -235,6 +244,11 @@ impl Serializer for VecDeque { mem::size_of::() } + #[inline(always)] + fn fory_graph_self_size() -> usize { + mem::size_of::() + } + #[inline(always)] fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::LIST) @@ -298,6 +312,11 @@ impl Serializer for LinkedList { mem::size_of::() } + #[inline(always)] + fn fory_graph_self_size() -> usize { + mem::size_of::() + } + #[inline(always)] fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::LIST) diff --git a/rust/fory-core/src/serializer/map.rs b/rust/fory-core/src/serializer/map.rs index b732f348df..8264d79dec 100644 --- a/rust/fory-core/src/serializer/map.rs +++ b/rust/fory-core/src/serializer/map.rs @@ -553,10 +553,10 @@ impl Result { let len = context.reader.read_var_u32()?; - let elem_bytes = std::mem::size_of::() - .checked_add(std::mem::size_of::()) - .ok_or_else(|| Error::invalid_data("container memory estimate overflows"))?; - let capacity = context.reserve_counted_container_memory(len, elem_bytes)?; + let elem_bytes = K::fory_graph_storage_size() + .checked_add(V::fory_graph_storage_size()) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + let capacity = context.reserve_counted_graph_memory(len, elem_bytes)?; if len == 0 { return Ok(HashMap::new()); } @@ -657,6 +657,10 @@ impl() } + fn fory_graph_self_size() -> usize { + size_of::() + } + fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::MAP) } @@ -709,10 +713,10 @@ impl Result { let len = context.reader.read_var_u32()?; - let elem_bytes = std::mem::size_of::() - .checked_add(std::mem::size_of::()) - .ok_or_else(|| Error::invalid_data("container memory estimate overflows"))?; - let len_usize = context.reserve_counted_container_memory(len, elem_bytes)?; + let elem_bytes = K::fory_graph_storage_size() + .checked_add(V::fory_graph_storage_size()) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + let len_usize = context.reserve_counted_graph_memory(len, elem_bytes)?; if len == 0 { return Ok(BTreeMap::new()); } @@ -812,6 +816,10 @@ impl() } + fn fory_graph_self_size() -> usize { + size_of::() + } + fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::MAP) } diff --git a/rust/fory-core/src/serializer/rc.rs b/rust/fory-core/src/serializer/rc.rs index 7c0f46861e..eb8522d503 100644 --- a/rust/fory-core/src/serializer/rc.rs +++ b/rust/fory-core/src/serializer/rc.rs @@ -205,6 +205,10 @@ fn read_rc_inner( ) -> Result { // Read type info if needed, then read data directly // No recursive ref handling needed since Rc only wraps allowed types + let graph_self_size = T::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } if let Some(typeinfo) = typeinfo { return T::fory_read_with_type_info(context, RefMode::None, typeinfo); } diff --git a/rust/fory-core/src/serializer/set.rs b/rust/fory-core/src/serializer/set.rs index d49bc5559f..856fd66c9a 100644 --- a/rust/fory-core/src/serializer/set.rs +++ b/rust/fory-core/src/serializer/set.rs @@ -58,6 +58,10 @@ impl Serializer for HashSet< mem::size_of::() } + fn fory_graph_self_size() -> usize { + mem::size_of::() + } + fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::SET) } @@ -113,6 +117,10 @@ impl Serializer for BTreeSet { mem::size_of::() } + fn fory_graph_self_size() -> usize { + mem::size_of::() + } + fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::SET) } diff --git a/rust/fory-core/src/serializer/tuple.rs b/rust/fory-core/src/serializer/tuple.rs index bbda7383b2..6dc730a0b9 100644 --- a/rust/fory-core/src/serializer/tuple.rs +++ b/rust/fory-core/src/serializer/tuple.rs @@ -181,6 +181,11 @@ impl Serializer for (T0,) { mem::size_of::() } + #[inline(always)] + fn fory_graph_self_size() -> usize { + mem::size_of::() + } + #[inline(always)] fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::LIST) @@ -449,6 +454,11 @@ macro_rules! impl_tuple_serializer { mem::size_of::() // Size for length } + #[inline(always)] + fn fory_graph_self_size() -> usize { + mem::size_of::() + } + #[inline(always)] fn fory_get_type_id(_: &TypeResolver) -> Result { Ok(TypeId::LIST) diff --git a/rust/fory-derive/src/object/serializer.rs b/rust/fory-derive/src/object/serializer.rs index 9624ebce10..922b32cf34 100644 --- a/rust/fory-derive/src/object/serializer.rs +++ b/rust/fory-derive/src/object/serializer.rs @@ -119,6 +119,7 @@ pub fn derive_serializer( read_type_info_ts, reserved_space_ts, static_type_id_ts, + graph_self_size_ts, ) = match &ast.data { syn::Data::Struct(s) => { let source_fields = source_fields(&s.fields); @@ -132,6 +133,10 @@ pub fn derive_serializer( read::gen_read_type_info(), write::gen_reserved_space(&source_fields), quote! { fory_core::TypeId::STRUCT }, + quote! { + let bytes = ::std::mem::size_of::(); + if bytes == 0 { 1 } else { bytes } + }, ) } syn::Data::Enum(e) => ( @@ -144,6 +149,7 @@ pub fn derive_serializer( derive_enum::gen_read_type_info(e), derive_enum::gen_reserved_space(), derive_enum::gen_static_type_id(e), + quote! { 0 }, ), syn::Data::Union(_) => { panic!("Union is not supported") @@ -226,6 +232,14 @@ pub fn derive_serializer( #reserved_space_ts } + #[inline(always)] + fn fory_graph_self_size() -> usize + where + Self: Sized, + { + #graph_self_size_ts + } + #[inline(always)] fn fory_write(&self, context: &mut fory_core::WriteContext, ref_mode: fory_core::RefMode, write_type_info: bool, _: bool) -> ::std::result::Result<(), fory_core::error::Error> { #write_ts @@ -289,6 +303,10 @@ fn generate_send_sync_tokens(ast: &syn::DeriveInput) -> SendSyncTokens { context: &mut fory_core::ReadContext, type_info: ::std::rc::Rc, ) -> ::std::result::Result<::std::boxed::Box, fory_core::error::Error> { + let graph_self_size = ::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } let value = ::fory_read_compatible(context, type_info)?; ::std::result::Result::Ok(fory_core::serializer::box_send_sync(value)) } @@ -305,6 +323,10 @@ fn generate_send_sync_tokens(ast: &syn::DeriveInput) -> SendSyncTokens { where Self: Sized + fory_core::ForyDefault, { + let graph_self_size = ::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } let value = ::fory_read_data(context)?; ::std::result::Result::Ok(fory_core::serializer::box_send_sync(value)) } diff --git a/rust/tests/tests/mod.rs b/rust/tests/tests/mod.rs index 74f62c87d3..1a3c78828f 100644 --- a/rust/tests/tests/mod.rs +++ b/rust/tests/tests/mod.rs @@ -18,7 +18,7 @@ mod compatible; mod test_any; mod test_collection; -mod test_container_memory_budget; mod test_field_meta; +mod test_graph_memory_budget; mod test_max_dyn_depth; mod test_tuple; diff --git a/rust/tests/tests/test_container_memory_budget.rs b/rust/tests/tests/test_graph_memory_budget.rs similarity index 67% rename from rust/tests/tests/test_container_memory_budget.rs rename to rust/tests/tests/test_graph_memory_budget.rs index 8759ee1d82..ef0bb7474b 100644 --- a/rust/tests/tests/test_container_memory_budget.rs +++ b/rust/tests/tests/test_graph_memory_budget.rs @@ -18,6 +18,7 @@ use fory_core::{Error, Fory, Reader}; use fory_derive::ForyStruct; use std::collections::HashMap; +use std::mem; use std::panic; #[derive(ForyStruct, Debug, PartialEq)] @@ -32,6 +33,9 @@ struct BudgetItem { right: u64, } +#[derive(ForyStruct, Debug, PartialEq)] +struct BudgetEmpty; + #[derive(ForyStruct, Debug)] struct ListWireInts { values: Vec>, @@ -42,26 +46,27 @@ struct DenseWireInts { values: Vec, } -fn fory_with_budget(max_container_memory_bytes: i64) -> Fory { +fn fory_with_budget(max_graph_memory_bytes: i64) -> Fory { let mut fory = Fory::builder() .xlang(false) .compatible(false) - .max_container_memory_bytes(max_container_memory_bytes) + .max_graph_memory_bytes(max_graph_memory_bytes) .build(); fory.register_by_name::("BudgetSiblings") .unwrap(); fory.register_by_name::("BudgetItem").unwrap(); + fory.register_by_name::("BudgetEmpty").unwrap(); fory } -fn compatible_fory(max_container_memory_bytes: i64) -> Fory +fn compatible_fory(max_graph_memory_bytes: i64) -> Fory where T: fory_core::Serializer + fory_core::StructSerializer + fory_core::ForyDefault, { let mut fory = Fory::builder() .xlang(false) .compatible(true) - .max_container_memory_bytes(max_container_memory_bytes) + .max_graph_memory_bytes(max_graph_memory_bytes) .build(); fory.register::(88_001).unwrap(); fory @@ -74,7 +79,7 @@ fn compact_empty_lists(count: usize) -> Vec> { fn assert_budget_error(err: Error, effective_limit: usize) { let message = err.to_string(); assert!( - message.contains("estimated container memory request"), + message.contains("estimated graph memory request"), "{message}" ); assert!( @@ -85,10 +90,10 @@ fn assert_budget_error(err: Error, effective_limit: usize) { #[test] fn config_validation() { - assert!(panic::catch_unwind(|| Fory::builder().max_container_memory_bytes(0)).is_err()); - assert!(panic::catch_unwind(|| Fory::builder().max_container_memory_bytes(-2)).is_err()); - let _ = Fory::builder().max_container_memory_bytes(-1).build(); - let _ = Fory::builder().max_container_memory_bytes(1).build(); + assert!(panic::catch_unwind(|| Fory::builder().max_graph_memory_bytes(0)).is_err()); + assert!(panic::catch_unwind(|| Fory::builder().max_graph_memory_bytes(-2)).is_err()); + let _ = Fory::builder().max_graph_memory_bytes(-1).build(); + let _ = Fory::builder().max_graph_memory_bytes(1).build(); } #[test] @@ -129,24 +134,54 @@ fn explicit_override() { let bytes = writer.serialize(&value).unwrap(); assert!(writer.deserialize::>>(&bytes).is_err()); - let vec_bytes = std::mem::size_of::>(); - let estimate = value.len() * vec_bytes; + let vec_bytes = mem::size_of::>(); + let estimate = mem::size_of::>>() + value.len() * vec_bytes; let explicit = fory_with_budget(estimate as i64); let decoded: Vec> = explicit.deserialize(&bytes).unwrap(); assert_eq!(decoded, value); } #[test] -fn empty_container_has_no_dynamic_storage() { +fn empty_collection_owner_self() { let value: Vec = Vec::new(); let writer = fory_with_budget(-1); let bytes = writer.serialize(&value).unwrap(); - let limited = fory_with_budget(1); + let limited = fory_with_budget((mem::size_of::>() - 1) as i64); + assert!(limited.deserialize::>(&bytes).is_err()); + + let limited = fory_with_budget(mem::size_of::>() as i64); let decoded: Vec = limited.deserialize(&bytes).unwrap(); assert!(decoded.is_empty()); } +#[test] +fn empty_struct_owner_self() { + let value = BudgetEmpty; + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + + assert_eq!( + fory_with_budget(1) + .deserialize::(&bytes) + .unwrap(), + value + ); + + let values = vec![BudgetEmpty, BudgetEmpty, BudgetEmpty]; + let bytes = writer.serialize(&values).unwrap(); + let required = mem::size_of::>() + values.len(); + assert!(fory_with_budget((required - 1) as i64) + .deserialize::>(&bytes) + .is_err()); + assert_eq!( + fory_with_budget(required as i64) + .deserialize::>(&bytes) + .unwrap(), + values + ); +} + #[test] fn sibling_cumulative_budget() { let value = BudgetSiblings { @@ -155,11 +190,12 @@ fn sibling_cumulative_budget() { }; let writer = fory_with_budget(-1); let bytes = writer.serialize(&value).unwrap(); - let one_vec = std::mem::size_of::() as i64; + let root = mem::size_of::() as i64; + let one_vec = mem::size_of::() as i64; - let limited = fory_with_budget(one_vec); + let limited = fory_with_budget(root + one_vec); assert!(limited.deserialize::(&bytes).is_err()); - let enough = fory_with_budget(one_vec * 2); + let enough = fory_with_budget(root + one_vec * 2); assert_eq!(enough.deserialize::(&bytes).unwrap(), value); } @@ -168,7 +204,9 @@ fn map_budget() { let value: HashMap = HashMap::from([("a".to_string(), 1)]); let writer = fory_with_budget(-1); let bytes = writer.serialize(&value).unwrap(); - let required = (std::mem::size_of::() + std::mem::size_of::()) as i64; + let required = (mem::size_of::>() + + mem::size_of::() + + mem::size_of::()) as i64; let limited = fory_with_budget(required - 1); assert!(limited.deserialize::>(&bytes).is_err()); @@ -190,12 +228,37 @@ fn inline_value_vec_budget() { .collect::>(); let writer = fory_with_budget(-1); let bytes = writer.serialize(&value).unwrap(); - let under_inline = value.len() * std::mem::size_of::(); + let under_inline = mem::size_of::>() + value.len() * mem::size_of::(); let limited = fory_with_budget(under_inline as i64); assert!(limited.deserialize::>(&bytes).is_err()); } +#[test] +fn box_vector_owner_self() { + let value = Box::new( + (0..4) + .map(|i| BudgetItem { + left: i, + right: i + 1, + }) + .collect::>(), + ); + let writer = fory_with_budget(-1); + let bytes = writer.serialize(&value).unwrap(); + let required = mem::size_of::>() + value.len() * mem::size_of::(); + + assert!(fory_with_budget((required - 1) as i64) + .deserialize::>>(&bytes) + .is_err()); + assert_eq!( + fory_with_budget(required as i64) + .deserialize::>>(&bytes) + .unwrap(), + value + ); +} + #[test] fn compatible_list_array_budget() { let value = ListWireInts { @@ -204,10 +267,11 @@ fn compatible_list_array_budget() { let writer = compatible_fory::(-1); let bytes = writer.serialize(&value).unwrap(); - let limited = compatible_fory::((64 * std::mem::size_of::() - 1) as i64); + let required = mem::size_of::() + 64 * mem::size_of::(); + let limited = compatible_fory::((required - 1) as i64); assert!(limited.deserialize::(&bytes).is_err()); - let enough = compatible_fory::(i64::MAX); + let enough = compatible_fory::(required as i64); let decoded = enough.deserialize::(&bytes).unwrap(); assert_eq!( decoded, diff --git a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala index 025ff14c82..c20ee9e2b0 100644 --- a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala +++ b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala @@ -109,6 +109,21 @@ object ForySerializerMacros { annotations.foldRight(boxed)((annotation, current) => AnnotatedType(current, annotation)) } + def graphFieldBytes(tpe: TypeRepr): Long = { + val base = peelAnnotations(tpe.widen)._1.dealias + if base =:= TypeRepr.of[Boolean] then 1L + else if base =:= TypeRepr.of[Byte] then 1L + else if base =:= TypeRepr.of[Char] then 2L + else if base =:= TypeRepr.of[Short] then 2L + else if base =:= TypeRepr.of[Int] then 4L + else if base =:= TypeRepr.of[Float] then 4L + else if base =:= TypeRepr.of[Long] then 8L + else if base =:= TypeRepr.of[Double] then 8L + else 4L + } + + val objectGraphMemoryBytes: Long = 1L + fields.map(field => graphFieldBytes(field.sourceType)).sum + def classFor(tpe: TypeRepr): Expr[Class[?]] = { val normalized = peelAnnotations(tpe.widen)._1.dealias val fullName = normalized.typeSymbol.fullName @@ -1129,7 +1144,8 @@ object ForySerializerMacros { Block( localDefs ++ maskDefs, Block( - readLoop.asTerm :: defaultAssignments.toList, + '{ $readContextExpr.reserveGraphMemory(${ Expr(objectGraphMemoryBytes) }) }.asTerm :: + readLoop.asTerm :: defaultAssignments.toList, constructFromLocals(localFields, instantiatorExpr, fieldAccessorsExpr).asTerm)) .asExprOf[T] } @@ -1151,6 +1167,7 @@ object ForySerializerMacros { if $resolverExpr.checkClassVersion() then { $serializerExpr.checkClassVersion(buffer.readInt32(), $classVersionHashExpr) } + $readContextExpr.reserveGraphMemory(${ Expr(objectGraphMemoryBytes) }) val obj = $instantiatorExpr.newInstance() $readContextExpr.reference(obj) var i = 0 @@ -1176,6 +1193,7 @@ object ForySerializerMacros { if $resolverExpr.checkClassVersion() then { $serializerExpr.checkClassVersion(buffer.readInt32(), $classVersionHashExpr) } + $readContextExpr.reserveGraphMemory(${ Expr(objectGraphMemoryBytes) }) val values = new Array[Any]($descriptorsExpr.size()) var i = 0 while i < $allFieldsExpr.length do { @@ -1204,6 +1222,7 @@ object ForySerializerMacros { if $sameSchemaCompatibleExpr then { $serializerExpr.read($readContextExpr) } else { + $readContextExpr.reserveGraphMemory(${ Expr(objectGraphMemoryBytes) }) val obj = $instantiatorExpr.newInstance() $readContextExpr.reference(obj) val remoteFields = $serializerExpr.getRemoteFields() diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala index 8b688bcb37..6bfcefaa9b 100644 --- a/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala +++ b/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala @@ -53,7 +53,7 @@ abstract class AbstractScalaCollectionSerializer[A, T <: Iterable[A]]( value: T): util.Collection[_] override def newCollection(readContext: ReadContext): util.Collection[_] = { - val numElements = readCollectionSize(readContext) + val numElements = readCollectionSize(readContext, readContext.getBuffer) setNumElements(numElements) val factory = readContext.readRef().asInstanceOf[Factory[A, T]] val builder = factory.newBuilder diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala index 3891361615..145c4d47e8 100644 --- a/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala +++ b/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala @@ -50,7 +50,7 @@ abstract class AbstractScalaMapSerializer[K, V, T](typeResolver: TypeResolver, c def onMapWrite(writeContext: WriteContext, value: T): util.Map[_, _] override def newMap(readContext: ReadContext): util.Map[_, _] = { - val numElements = readMapSize(readContext) + val numElements = readMapSize(readContext, readContext.getBuffer) setNumElements(numElements) val factory = readContext.readRef().asInstanceOf[Factory[(K, V), T]] val builder = factory.newBuilder diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala index 9439f3493e..4e34d77a20 100644 --- a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala +++ b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala @@ -43,7 +43,7 @@ abstract class AbstractScalaXlangCollectionSerializer[A, T <: scala.collection.I } override def newCollection(readContext: ReadContext): util.Collection[_] = { - val numElements = readCollectionSize(readContext) + val numElements = readCollectionSize(readContext, readContext.getBuffer) setNumElements(numElements) val builder = newBuilder(numElements) if (ScalaXlangCollectionShape.hasOptionElement(readContext)) { @@ -364,7 +364,7 @@ abstract class AbstractScalaXlangMapSerializer[K, V, T <: scala.collection.Map[K } override def newMap(readContext: ReadContext): util.Map[_, _] = { - val numElements = readMapSize(readContext) + val numElements = readMapSize(readContext, readContext.getBuffer) setNumElements(numElements) val builder = ScalaXlangCollectionShape.mapBuilder[K, V, T](cls, numElements) diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala index cdae015153..085d0cc5a9 100644 --- a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala @@ -91,23 +91,23 @@ class CollectionSerializerTest extends AnyWordSpec with Matchers { } } - "fory scala container memory budget" should { - def runtime(maxContainerMemoryBytes: Long = -1): Fory = { + "fory scala graph memory budget" should { + def runtime(maxGraphMemoryBytes: Long = -1): Fory = { val builder = ForyScala.builder() .withXlang(false) .withRefTracking(true) .requireClassRegistration(false) .suppressClassRegistrationWarnings(false) .withSerializerFactory(new ScalaSerializerFactory()) - if (maxContainerMemoryBytes > 0) { - builder.withMaxContainerMemoryBytes(maxContainerMemoryBytes) + if (maxGraphMemoryBytes > 0) { + builder.withMaxGraphMemoryBytes(maxGraphMemoryBytes) } builder.build() } "reserve scala collection storage" in { val writer = runtime() - val reader = runtime(maxContainerMemoryBytes = 23) + val reader = runtime(maxGraphMemoryBytes = 23) intercept[InsecureException] { reader.deserialize(writer.serialize(List.fill(6)("v"))) } @@ -115,7 +115,7 @@ class CollectionSerializerTest extends AnyWordSpec with Matchers { "reserve scala map storage" in { val writer = runtime() - val reader = runtime(maxContainerMemoryBytes = 23) + val reader = runtime(maxGraphMemoryBytes = 23) intercept[InsecureException] { reader.deserialize(writer.serialize(Map("a" -> 1, "b" -> 2, "c" -> 3))) } diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala index f8f979cd6c..6269061555 100644 --- a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaXlangSerializerTest.scala @@ -122,7 +122,7 @@ class ScalaXlangSerializerTest extends AnyWordSpec with Matchers { copiedCyclic(0) shouldBe theSameInstanceAs(copiedCyclic) } - "enforce container memory budget" in { + "enforce graph memory budget" in { val writer = fory val reader = ForyScala.builder() .withXlang(true) @@ -130,7 +130,7 @@ class ScalaXlangSerializerTest extends AnyWordSpec with Matchers { .withRefCopy(true) .requireClassRegistration(false) .suppressClassRegistrationWarnings(false) - .withMaxContainerMemoryBytes(23) + .withMaxGraphMemoryBytes(23) .build() intercept[InsecureException] { diff --git a/swift/Sources/Fory/AnySerializer.swift b/swift/Sources/Fory/AnySerializer.swift index 9a436ddce4..eca738293e 100644 --- a/swift/Sources/Fory/AnySerializer.swift +++ b/swift/Sources/Fory/AnySerializer.swift @@ -18,729 +18,753 @@ import Foundation private let anyReferenceBytes = 4 +private let anyArrayOwnerBytes = max(1, MemoryLayout<[Any]>.stride) @inline(__always) private func reserveAnyReferenceArrayMemory(_ context: ReadContext, count: Int) throws { - try context.reserveCountedContainerMemory(count: count, elementBytes: anyReferenceBytes) + let (slotBytes, overflow) = count.multipliedReportingOverflow(by: anyReferenceBytes) + if overflow { + try context.reserveGraphMemory(-1) + } + let (bytes, addOverflow) = anyArrayOwnerBytes.addingReportingOverflow(slotBytes) + if addOverflow { + try context.reserveGraphMemory(-1) + } + try context.reserveGraphMemory(bytes) } @inline(__always) -private func reserveAnyReferenceMapMemory(_ context: ReadContext, count: Int) throws { - try context.reserveCountedContainerMemory(count: count, elementBytes: 2 * anyReferenceBytes) +private func reserveAnyReferenceMapMemory(_ context: ReadContext, _ type: Map.Type, count: Int) + throws +{ + let (slotBytes, overflow) = count.multipliedReportingOverflow(by: 2 * anyReferenceBytes) + if overflow { + try context.reserveGraphMemory(-1) + } + let ownerBytes = max(1, MemoryLayout.stride) + let (bytes, addOverflow) = ownerBytes.addingReportingOverflow(slotBytes) + if addOverflow { + try context.reserveGraphMemory(-1) + } + try context.reserveGraphMemory(bytes) } public struct ForyAnyNullValue: Serializer { - public init() {} + public init() {} - public static func foryDefault() -> ForyAnyNullValue { - ForyAnyNullValue() - } + public static func foryDefault() -> ForyAnyNullValue { + ForyAnyNullValue() + } - public static var staticTypeId: TypeId { - .none - } + public static var staticTypeId: TypeId { + .none + } - public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { - context.writeStaticTypeInfo(staticTypeId) - } + public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { + context.writeStaticTypeInfo(staticTypeId) + } - public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try context.readStaticTypeInfo(staticTypeId) - } + public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try context.readStaticTypeInfo(staticTypeId) + } - public var foryIsNone: Bool { - true - } + public var foryIsNone: Bool { + true + } - public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { - _ = context - _ = hasGenerics - } + public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { + _ = context + _ = hasGenerics + } - public static func foryReadData(_ context: ReadContext) throws -> ForyAnyNullValue { - _ = context - return ForyAnyNullValue() - } + public static func foryReadData(_ context: ReadContext) throws -> ForyAnyNullValue { + _ = context + return ForyAnyNullValue() + } } extension AnyHashable: Serializer { - public static func foryDefault() -> AnyHashable { - AnyHashable(Int32(0)) - } - - public static var staticTypeId: TypeId { - .unknown - } - - public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { - try writeAnyPayload(base, context: context, hasGenerics: hasGenerics) - } - - public static func foryReadData(_ context: ReadContext) throws -> AnyHashable { - _ = context - throw ForyError.invalidData( - "dynamic AnyHashable key read requires type info; foryReadData should not be called directly" - ) - } - - public static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> AnyHashable { - let typeInfo = remoteTypeInfo - if typeInfo.typeID == .none { - throw ForyError.invalidData("dynamic AnyHashable key cannot be null") - } - let decoded = try context.readAnyValue(typeInfo: typeInfo) - return try toAnyHashableKey(decoded) - } - - public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { - _ = context - throw ForyError.invalidData("dynamic AnyHashable key type info is runtime-only") - } - - public func foryWriteTypeInfo(_ context: WriteContext) throws { - try writeAnyTypeInfo(base, context: context) - } - - public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try context.readTypeInfo() + public static func foryDefault() -> AnyHashable { + AnyHashable(Int32(0)) + } + + public static var staticTypeId: TypeId { + .unknown + } + + public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { + try writeAnyPayload(base, context: context, hasGenerics: hasGenerics) + } + + public static func foryReadData(_ context: ReadContext) throws -> AnyHashable { + _ = context + throw ForyError.invalidData( + "dynamic AnyHashable key read requires type info; foryReadData should not be called directly" + ) + } + + public static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws + -> AnyHashable + { + let typeInfo = remoteTypeInfo + if typeInfo.typeID == .none { + throw ForyError.invalidData("dynamic AnyHashable key cannot be null") + } + let decoded = try context.readAnyValue(typeInfo: typeInfo) + return try toAnyHashableKey(decoded) + } + + public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { + _ = context + throw ForyError.invalidData("dynamic AnyHashable key type info is runtime-only") + } + + public func foryWriteTypeInfo(_ context: WriteContext) throws { + try writeAnyTypeInfo(base, context: context) + } + + public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try context.readTypeInfo() + } + + public func foryWrite( + _ context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool, + hasGenerics: Bool + ) throws { + if refMode != .none { + context.buffer.writeInt8(RefFlag.notNullValue.rawValue) } - - public func foryWrite( - _ context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool, - hasGenerics: Bool - ) throws { - if refMode != .none { - context.buffer.writeInt8(RefFlag.notNullValue.rawValue) - } - if writeTypeInfo { - try foryWriteTypeInfo(context) - } - try foryWriteData(context, hasGenerics: hasGenerics) + if writeTypeInfo { + try foryWriteTypeInfo(context) } + try foryWriteData(context, hasGenerics: hasGenerics) + } } private protocol OptionalTypeMarker { - static var noneValue: Self { get } + static var noneValue: Self { get } } extension Optional: OptionalTypeMarker { - static var noneValue: Wrapped? { nil } + static var noneValue: Wrapped? { nil } } struct SerializableAny: Serializer { - var value: Any = ForyAnyNullValue() + var value: Any = ForyAnyNullValue() - init(_ value: Any) { - self.value = value - } + init(_ value: Any) { + self.value = value + } - static func foryDefault() -> SerializableAny { - SerializableAny(ForyAnyNullValue()) - } + static func foryDefault() -> SerializableAny { + SerializableAny(ForyAnyNullValue()) + } - static var staticTypeId: TypeId { - .unknown - } + static var staticTypeId: TypeId { + .unknown + } - static var isNullableType: Bool { - true - } + static var isNullableType: Bool { + true + } - static var isRefType: Bool { - true - } + static var isRefType: Bool { + true + } - var foryIsNone: Bool { - value is ForyAnyNullValue - } + var foryIsNone: Bool { + value is ForyAnyNullValue + } - static func wrapped(_ value: Any?) -> SerializableAny { - guard let value else { - return .foryDefault() - } - guard let unwrapped = unwrapOptionalAny(value) else { - return .foryDefault() - } - if unwrapped is NSNull { - return .foryDefault() - } - return SerializableAny(unwrapped) + static func wrapped(_ value: Any?) -> SerializableAny { + guard let value else { + return .foryDefault() } - - func anyValue() -> Any? { - foryIsNone ? nil : value + guard let unwrapped = unwrapOptionalAny(value) else { + return .foryDefault() } - - func anyValueForCollection() -> Any { - foryIsNone ? NSNull() : value + if unwrapped is NSNull { + return .foryDefault() } + return SerializableAny(unwrapped) + } - func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { - if foryIsNone { - return - } - try writeAnyPayload(value, context: context, hasGenerics: hasGenerics) - } + func anyValue() -> Any? { + foryIsNone ? nil : value + } + + func anyValueForCollection() -> Any { + foryIsNone ? NSNull() : value + } - static func foryReadData(_ context: ReadContext) throws -> SerializableAny { - _ = context - throw ForyError.invalidData( - "dynamic Any read requires type info; foryReadData should not be called directly" - ) + func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { + if foryIsNone { + return } + try writeAnyPayload(value, context: context, hasGenerics: hasGenerics) + } - static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> SerializableAny { - let typeInfo = remoteTypeInfo - if typeInfo.typeID == .none { - return .foryDefault() - } - return SerializableAny(try context.readAnyValue(typeInfo: typeInfo)) + static func foryReadData(_ context: ReadContext) throws -> SerializableAny { + _ = context + throw ForyError.invalidData( + "dynamic Any read requires type info; foryReadData should not be called directly" + ) + } + + static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws + -> SerializableAny + { + let typeInfo = remoteTypeInfo + if typeInfo.typeID == .none { + return .foryDefault() } + return SerializableAny(try context.readAnyValue(typeInfo: typeInfo)) + } - static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { - _ = context - throw ForyError.invalidData("dynamic Any value type info is runtime-only") + static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { + _ = context + throw ForyError.invalidData("dynamic Any value type info is runtime-only") + } + + func foryWriteTypeInfo(_ context: WriteContext) throws { + if foryIsNone { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.none.rawValue)) + return } + try writeAnyTypeInfo(value, context: context) + } - func foryWriteTypeInfo(_ context: WriteContext) throws { - if foryIsNone { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.none.rawValue)) - return - } - try writeAnyTypeInfo(value, context: context) + static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try context.readTypeInfo() + } + + func foryWrite( + _ context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool, + hasGenerics: Bool + ) throws { + if refMode != .none { + if foryIsNone { + context.buffer.writeInt8(RefFlag.null.rawValue) + return + } + context.buffer.writeInt8(RefFlag.notNullValue.rawValue) } - static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try context.readTypeInfo() + if writeTypeInfo { + try foryWriteTypeInfo(context) } + try foryWriteData(context, hasGenerics: hasGenerics) + } - func foryWrite( - _ context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool, - hasGenerics: Bool - ) throws { - if refMode != .none { - if foryIsNone { - context.buffer.writeInt8(RefFlag.null.rawValue) - return - } - context.buffer.writeInt8(RefFlag.notNullValue.rawValue) + static func foryRead( + _ context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> SerializableAny { + @inline(__always) + func requireDynamicTypeInfo() throws -> TypeInfo { + if readTypeInfo { + guard let remoteTypeInfo = try foryReadTypeInfo(context) else { + throw ForyError.invalidData("dynamic Any value requires type info") } - - if writeTypeInfo { - try foryWriteTypeInfo(context) + return remoteTypeInfo + } + guard let remoteTypeInfo = context.getTypeInfo(for: Self.self) else { + throw ForyError.invalidData("dynamic Any value requires type info") + } + return remoteTypeInfo + } + + if refMode != .none { + let rawFlag = try context.buffer.readInt8() + guard let flag = RefFlag(rawValue: rawFlag) else { + throw ForyError.refError("invalid ref flag \(rawFlag)") + } + + switch flag { + case .null: + return .foryDefault() + case .ref: + let refID = try context.buffer.readVarUInt32() + let referenced = try context.refReader.readRefValue(refID) + if let value = referenced as? SerializableAny { + return value } - try foryWriteData(context, hasGenerics: hasGenerics) - } - - static func foryRead( - _ context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> SerializableAny { - @inline(__always) - func requireDynamicTypeInfo() throws -> TypeInfo { - if readTypeInfo { - guard let remoteTypeInfo = try foryReadTypeInfo(context) else { - throw ForyError.invalidData("dynamic Any value requires type info") - } - return remoteTypeInfo - } - guard let remoteTypeInfo = context.getTypeInfo(for: Self.self) else { - throw ForyError.invalidData("dynamic Any value requires type info") - } - return remoteTypeInfo + if referenced is NSNull { + return .foryDefault() } - - if refMode != .none { - let rawFlag = try context.buffer.readInt8() - guard let flag = RefFlag(rawValue: rawFlag) else { - throw ForyError.refError("invalid ref flag \(rawFlag)") - } - - switch flag { - case .null: - return .foryDefault() - case .ref: - let refID = try context.buffer.readVarUInt32() - let referenced = try context.refReader.readRefValue(refID) - if let value = referenced as? SerializableAny { - return value - } - if referenced is NSNull { - return .foryDefault() - } - return SerializableAny(referenced) - case .refValue: - let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil - let remoteTypeInfo = try requireDynamicTypeInfo() - let value = try foryReadCompatibleData(context, remoteTypeInfo: remoteTypeInfo) - if let reservedRefID { - if let object = value.value as AnyObject? { - context.refReader.storeRef(object, at: reservedRefID) - } else { - context.refReader.storeRef(value, at: reservedRefID) - } - } - return value - case .notNullValue: - break - } + return SerializableAny(referenced) + case .refValue: + let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil + let remoteTypeInfo = try requireDynamicTypeInfo() + let value = try foryReadCompatibleData(context, remoteTypeInfo: remoteTypeInfo) + if let reservedRefID { + if let object = value.value as AnyObject? { + context.refReader.storeRef(object, at: reservedRefID) + } else { + context.refReader.storeRef(value, at: reservedRefID) + } } - - return try foryReadCompatibleData(context, remoteTypeInfo: requireDynamicTypeInfo()) + return value + case .notNullValue: + break + } } + + return try foryReadCompatibleData(context, remoteTypeInfo: requireDynamicTypeInfo()) + } } private func unwrapOptionalAny(_ value: Any) -> Any? { - let mirror = Mirror(reflecting: value) - guard mirror.displayStyle == .optional else { - return value - } - guard let (_, child) = mirror.children.first else { - return nil - } - return child + let mirror = Mirror(reflecting: value) + guard mirror.displayStyle == .optional else { + return value + } + guard let (_, child) = mirror.children.first else { + return nil + } + return child } private func toAnyHashableKey(_ value: Any) throws -> AnyHashable { - if let anyHashable = value as? AnyHashable { - return anyHashable - } - if value is ForyAnyNullValue { - throw ForyError.invalidData("dynamic AnyHashable key cannot be null") - } - guard let hashableValue = value as? any Hashable else { - throw ForyError.invalidData("dynamic AnyHashable key must be Hashable, got \(type(of: value))") - } - return AnyHashable(hashableValue) + if let anyHashable = value as? AnyHashable { + return anyHashable + } + if value is ForyAnyNullValue { + throw ForyError.invalidData("dynamic AnyHashable key cannot be null") + } + guard let hashableValue = value as? any Hashable else { + throw ForyError.invalidData("dynamic AnyHashable key must be Hashable, got \(type(of: value))") + } + return AnyHashable(hashableValue) } @inline(never) private func hasExactRuntimeType(_ value: Any, _: T.Type) -> Bool { - Swift.type(of: value) == T.self + Swift.type(of: value) == T.self } @inline(never) private func writePrimitiveArrayAnyTypeInfo(_ value: Any, context: WriteContext) -> Bool { - if hasExactRuntimeType(value, [Bool].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.boolArray.rawValue)) - return true - } - if hasExactRuntimeType(value, [Int8].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int8Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [Int16].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int16Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [Int32].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int32Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [Int64].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int64Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [UInt8].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint8Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [UInt16].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint16Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [UInt32].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint32Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [UInt64].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint64Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [Float16].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.float16Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [BFloat16].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.bfloat16Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [Float].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.float32Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [Double].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.float64Array.rawValue)) - return true - } - return false + if hasExactRuntimeType(value, [Bool].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.boolArray.rawValue)) + return true + } + if hasExactRuntimeType(value, [Int8].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int8Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [Int16].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int16Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [Int32].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int32Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [Int64].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int64Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [UInt8].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint8Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [UInt16].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint16Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [UInt32].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint32Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [UInt64].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint64Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [Float16].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.float16Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [BFloat16].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.bfloat16Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [Float].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.float32Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [Double].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.float64Array.rawValue)) + return true + } + return false } @inline(never) private func writePrimitiveArrayAnyPayload(_ value: Any, context: WriteContext) -> Bool { - if hasExactRuntimeType(value, [Bool].self), let array = value as? [Bool] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Int8].self), let array = value as? [Int8] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Int16].self), let array = value as? [Int16] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Int32].self), let array = value as? [Int32] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Int64].self), let array = value as? [Int64] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [UInt8].self), let array = value as? [UInt8] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [UInt16].self), let array = value as? [UInt16] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [UInt32].self), let array = value as? [UInt32] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [UInt64].self), let array = value as? [UInt64] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Float16].self), let array = value as? [Float16] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [BFloat16].self), let array = value as? [BFloat16] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Float].self), let array = value as? [Float] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Double].self), let array = value as? [Double] { - writePrimitiveArray(array, context: context) - return true - } - return false + if hasExactRuntimeType(value, [Bool].self), let array = value as? [Bool] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Int8].self), let array = value as? [Int8] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Int16].self), let array = value as? [Int16] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Int32].self), let array = value as? [Int32] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Int64].self), let array = value as? [Int64] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [UInt8].self), let array = value as? [UInt8] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [UInt16].self), let array = value as? [UInt16] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [UInt32].self), let array = value as? [UInt32] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [UInt64].self), let array = value as? [UInt64] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Float16].self), let array = value as? [Float16] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [BFloat16].self), let array = value as? [BFloat16] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Float].self), let array = value as? [Float] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Double].self), let array = value as? [Double] { + writePrimitiveArray(array, context: context) + return true + } + return false } private func writeAnyTypeInfo(_ value: Any, context: WriteContext) throws { - if writePrimitiveArrayAnyTypeInfo(value, context: context) { - return - } - - if let serializer = value as? any Serializer { - try serializer.foryWriteTypeInfo(context) - return - } - - if value is [Any] { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.list.rawValue)) - return - } - if value is [String: Any] || value is [Int32: Any] || value is [AnyHashable: Any] { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.map.rawValue)) - return - } - - throw ForyError.invalidData("unsupported dynamic Any runtime type \(type(of: value))") + if writePrimitiveArrayAnyTypeInfo(value, context: context) { + return + } + + if let serializer = value as? any Serializer { + try serializer.foryWriteTypeInfo(context) + return + } + + if value is [Any] { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.list.rawValue)) + return + } + if value is [String: Any] || value is [Int32: Any] || value is [AnyHashable: Any] { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.map.rawValue)) + return + } + + throw ForyError.invalidData("unsupported dynamic Any runtime type \(type(of: value))") } private func writeAnyPayload(_ value: Any, context: WriteContext, hasGenerics: Bool) throws { - try context.enterDynamicAnyDepth() - defer { context.leaveDynamicAnyDepth() } + try context.enterDynamicAnyDepth() + defer { context.leaveDynamicAnyDepth() } - if writePrimitiveArrayAnyPayload(value, context: context) { - return - } + if writePrimitiveArrayAnyPayload(value, context: context) { + return + } - if let serializer = value as? any Serializer { - if type(of: serializer).isRefType { - try serializer.foryWrite( - context, - refMode: .tracking, - writeTypeInfo: false, - hasGenerics: hasGenerics - ) - } else { - try serializer.foryWriteData(context, hasGenerics: hasGenerics) - } - return - } - if let list = value as? [Any] { - try writeListOfAny(list, context: context, refMode: .none, hasGenerics: hasGenerics) - return - } - if let map = value as? [String: Any] { - // Always include key type info for dynamic map payload. - try writeMapStringToAny(map, context: context, refMode: .none, hasGenerics: false) - return - } - if let map = value as? [Int32: Any] { - // Always include key type info for dynamic map payload. - try writeMapInt32ToAny(map, context: context, refMode: .none, hasGenerics: false) - return - } - if let map = value as? [AnyHashable: Any] { - // Always include key type info for dynamic map payload. - try writeMapAnyHashableToAny(map, context: context, refMode: .none, hasGenerics: false) - return - } - throw ForyError.invalidData("unsupported dynamic Any runtime type \(type(of: value))") + if let serializer = value as? any Serializer { + if type(of: serializer).isRefType { + try serializer.foryWrite( + context, + refMode: .tracking, + writeTypeInfo: false, + hasGenerics: hasGenerics + ) + } else { + try serializer.foryWriteData(context, hasGenerics: hasGenerics) + } + return + } + if let list = value as? [Any] { + try writeListOfAny(list, context: context, refMode: .none, hasGenerics: hasGenerics) + return + } + if let map = value as? [String: Any] { + // Always include key type info for dynamic map payload. + try writeMapStringToAny(map, context: context, refMode: .none, hasGenerics: false) + return + } + if let map = value as? [Int32: Any] { + // Always include key type info for dynamic map payload. + try writeMapInt32ToAny(map, context: context, refMode: .none, hasGenerics: false) + return + } + if let map = value as? [AnyHashable: Any] { + // Always include key type info for dynamic map payload. + try writeMapAnyHashableToAny(map, context: context, refMode: .none, hasGenerics: false) + return + } + throw ForyError.invalidData("unsupported dynamic Any runtime type \(type(of: value))") } public func castAnyDynamicValue(_ value: Any?, to type: T.Type) throws -> T { - _ = type - func castNilSentinel(_ sentinel: Any) throws -> T { - guard let casted = sentinel as? T else { - throw ForyError.invalidData("cannot cast dynamic Any value to \(type)") - } - return casted + _ = type + func castNilSentinel(_ sentinel: Any) throws -> T { + guard let casted = sentinel as? T else { + throw ForyError.invalidData("cannot cast dynamic Any value to \(type)") } + return casted + } - if value == nil { - if T.self == Any.self { - return try castNilSentinel(ForyAnyNullValue()) - } - if T.self == AnyObject.self { - return try castNilSentinel(NSNull()) - } - if T.self == (any Serializer).self { - return try castNilSentinel(ForyAnyNullValue()) - } - if let optionalType = T.self as? any OptionalTypeMarker.Type { - return try castNilSentinel(optionalType.noneValue) - } + if value == nil { + if T.self == Any.self { + return try castNilSentinel(ForyAnyNullValue()) } - - guard let typed = value as? T else { - throw ForyError.invalidData("cannot cast dynamic Any value to \(type)") + if T.self == AnyObject.self { + return try castNilSentinel(NSNull()) } - return typed + if T.self == (any Serializer).self { + return try castNilSentinel(ForyAnyNullValue()) + } + if let optionalType = T.self as? any OptionalTypeMarker.Type { + return try castNilSentinel(optionalType.noneValue) + } + } + + guard let typed = value as? T else { + throw ForyError.invalidData("cannot cast dynamic Any value to \(type)") + } + return typed } public func writeAny( - _ value: Any?, - context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool = true, - hasGenerics: Bool = false + _ value: Any?, + context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool = true, + hasGenerics: Bool = false ) throws { - try SerializableAny.wrapped(value).foryWrite( - context, - refMode: refMode, - writeTypeInfo: writeTypeInfo, - hasGenerics: hasGenerics - ) + try SerializableAny.wrapped(value).foryWrite( + context, + refMode: refMode, + writeTypeInfo: writeTypeInfo, + hasGenerics: hasGenerics + ) } public func readAny( - context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool = true + context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool = true ) throws -> Any? { - try SerializableAny.foryRead(context, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() + try SerializableAny.foryRead(context, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() } public func writeListOfAny( - _ value: [Any]?, - context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool = false, - hasGenerics: Bool = true + _ value: [Any]?, + context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool = false, + hasGenerics: Bool = true ) throws { - let wrapped = value?.map { SerializableAny.wrapped($0) } - try wrapped.foryWrite( - context, - refMode: refMode, - writeTypeInfo: writeTypeInfo, - hasGenerics: hasGenerics - ) + let wrapped = value?.map { SerializableAny.wrapped($0) } + try wrapped.foryWrite( + context, + refMode: refMode, + writeTypeInfo: writeTypeInfo, + hasGenerics: hasGenerics + ) } public func readListOfAny( - context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool = false + context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool = false ) throws -> [Any]? { - let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead( - context, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceArrayMemory(context, count: wrapped.count) - return wrapped.map { $0.anyValueForCollection() } + let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead( + context, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceArrayMemory(context, count: wrapped.count) + return wrapped.map { $0.anyValueForCollection() } } public func writeMapStringToAny( - _ value: [String: Any]?, - context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool = false, - hasGenerics: Bool = true + _ value: [String: Any]?, + context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool = false, + hasGenerics: Bool = true ) throws { - let wrapped = value?.reduce(into: [String: SerializableAny]()) { result, pair in - result[pair.key] = SerializableAny.wrapped(pair.value) - } - try wrapped.foryWrite( - context, - refMode: refMode, - writeTypeInfo: writeTypeInfo, - hasGenerics: hasGenerics - ) + let wrapped = value?.reduce(into: [String: SerializableAny]()) { result, pair in + result[pair.key] = SerializableAny.wrapped(pair.value) + } + try wrapped.foryWrite( + context, + refMode: refMode, + writeTypeInfo: writeTypeInfo, + hasGenerics: hasGenerics + ) } public func readMapStringToAny( - context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool = false + context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool = false ) throws -> [String: Any]? { - let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead( - context, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceMapMemory(context, count: wrapped.count) - var map: [String: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead( + context, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceMapMemory(context, [String: Any].self, count: wrapped.count) + var map: [String: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map } public func writeMapInt32ToAny( - _ value: [Int32: Any]?, - context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool = false, - hasGenerics: Bool = true + _ value: [Int32: Any]?, + context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool = false, + hasGenerics: Bool = true ) throws { - let wrapped = value?.reduce(into: [Int32: SerializableAny]()) { result, pair in - result[pair.key] = SerializableAny.wrapped(pair.value) - } - try wrapped.foryWrite( - context, - refMode: refMode, - writeTypeInfo: writeTypeInfo, - hasGenerics: hasGenerics - ) + let wrapped = value?.reduce(into: [Int32: SerializableAny]()) { result, pair in + result[pair.key] = SerializableAny.wrapped(pair.value) + } + try wrapped.foryWrite( + context, + refMode: refMode, + writeTypeInfo: writeTypeInfo, + hasGenerics: hasGenerics + ) } public func readMapInt32ToAny( - context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool = false + context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool = false ) throws -> [Int32: Any]? { - let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead( - context, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceMapMemory(context, count: wrapped.count) - var map: [Int32: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead( + context, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceMapMemory(context, [Int32: Any].self, count: wrapped.count) + var map: [Int32: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map } public func writeMapAnyHashableToAny( - _ value: [AnyHashable: Any]?, - context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool = false, - hasGenerics: Bool = true + _ value: [AnyHashable: Any]?, + context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool = false, + hasGenerics: Bool = true ) throws { - let wrapped = value?.reduce(into: [AnyHashable: SerializableAny]()) { result, pair in - result[pair.key] = SerializableAny.wrapped(pair.value) - } - try wrapped.foryWrite( - context, - refMode: refMode, - writeTypeInfo: writeTypeInfo, - hasGenerics: hasGenerics - ) + let wrapped = value?.reduce(into: [AnyHashable: SerializableAny]()) { result, pair in + result[pair.key] = SerializableAny.wrapped(pair.value) + } + try wrapped.foryWrite( + context, + refMode: refMode, + writeTypeInfo: writeTypeInfo, + hasGenerics: hasGenerics + ) } public func readMapAnyHashableToAny( - context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool = false + context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool = false ) throws -> [AnyHashable: Any]? { - let wrapped: [AnyHashable: SerializableAny]? = try [AnyHashable: SerializableAny]?.foryRead( - context, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceMapMemory(context, count: wrapped.count) - var map: [AnyHashable: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + let wrapped: [AnyHashable: SerializableAny]? = try [AnyHashable: SerializableAny]?.foryRead( + context, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceMapMemory(context, [AnyHashable: Any].self, count: wrapped.count) + var map: [AnyHashable: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map } func readDynamicAnyMapValue(context: ReadContext) throws -> Any { - let map = try readMapAnyHashableToAny(context: context, refMode: .none) ?? [:] - if map.isEmpty { - try reserveAnyReferenceMapMemory(context, count: 0) - return [String: Any]() - } - try reserveAnyReferenceMapMemory(context, count: map.count) - var stringMap: [String: Any] = [:] - stringMap.reserveCapacity(map.count) - for pair in map { - guard let key = pair.key.base as? String else { - stringMap.removeAll(keepingCapacity: false) - break - } - stringMap[key] = pair.value - } - if stringMap.count == map.count { - return stringMap - } - - try reserveAnyReferenceMapMemory(context, count: map.count) - var int32Map: [Int32: Any] = [:] - int32Map.reserveCapacity(map.count) - for pair in map { - guard let key = pair.key.base as? Int32 else { - return map - } - int32Map[key] = pair.value - } - if int32Map.count == map.count { - return int32Map - } - - return map + let map = try readMapAnyHashableToAny(context: context, refMode: .none) ?? [:] + if map.isEmpty { + try reserveAnyReferenceMapMemory(context, [String: Any].self, count: 0) + return [String: Any]() + } + try reserveAnyReferenceMapMemory(context, [String: Any].self, count: map.count) + var stringMap: [String: Any] = [:] + stringMap.reserveCapacity(map.count) + for pair in map { + guard let key = pair.key.base as? String else { + stringMap.removeAll(keepingCapacity: false) + break + } + stringMap[key] = pair.value + } + if stringMap.count == map.count { + return stringMap + } + + try reserveAnyReferenceMapMemory(context, [Int32: Any].self, count: map.count) + var int32Map: [Int32: Any] = [:] + int32Map.reserveCapacity(map.count) + for pair in map { + guard let key = pair.key.base as? Int32 else { + return map + } + int32Map[key] = pair.value + } + if int32Map.count == map.count { + return int32Map + } + + return map } diff --git a/swift/Sources/Fory/CollectionSerializers.swift b/swift/Sources/Fory/CollectionSerializers.swift index 86d41af3a0..3c2ed59f74 100644 --- a/swift/Sources/Fory/CollectionSerializers.swift +++ b/swift/Sources/Fory/CollectionSerializers.swift @@ -18,1052 +18,1086 @@ import Foundation enum CollectionHeader { - static let trackingRef: UInt8 = 0b0000_0001 - static let hasNull: UInt8 = 0b0000_0010 - static let declaredElementType: UInt8 = 0b0000_0100 - static let sameType: UInt8 = 0b0000_1000 + static let trackingRef: UInt8 = 0b0000_0001 + static let hasNull: UInt8 = 0b0000_0010 + static let declaredElementType: UInt8 = 0b0000_0100 + static let sameType: UInt8 = 0b0000_1000 } enum MapHeader { - static let trackingKeyRef: UInt8 = 0b0000_0001 - static let keyNull: UInt8 = 0b0000_0010 - static let declaredKeyType: UInt8 = 0b0000_0100 + static let trackingKeyRef: UInt8 = 0b0000_0001 + static let keyNull: UInt8 = 0b0000_0010 + static let declaredKeyType: UInt8 = 0b0000_0100 - static let trackingValueRef: UInt8 = 0b0000_1000 - static let valueNull: UInt8 = 0b0001_0000 - static let declaredValueType: UInt8 = 0b0010_0000 + static let trackingValueRef: UInt8 = 0b0000_1000 + static let valueNull: UInt8 = 0b0001_0000 + static let declaredValueType: UInt8 = 0b0010_0000 } -private let containerReferenceBytes = 4 +private let storedReferenceBytes = 4 @inline(__always) -private func containerElementBytes(_ type: Element.Type) -> Int { - type.isRefType ? containerReferenceBytes : max(1, MemoryLayout.stride) +private func storedElementBytes(_ type: Element.Type) -> Int { + type.isRefType ? storedReferenceBytes : max(1, MemoryLayout.stride) } @inline(__always) -private func reserveContainerArrayMemory( - _ context: ReadContext, - _ type: Element.Type, - count: Int +private func reserveGraphArrayMemory( + _ context: ReadContext, + _ type: Element.Type, + count: Int ) throws { - try context.reserveCountedContainerMemory( - count: count, - elementBytes: containerElementBytes(type) - ) + try context.reserveCountedGraphMemory( + count: count, + elementBytes: storedElementBytes(type) + ) } @inline(__always) -private func reserveContainerMapMemory( - _ context: ReadContext, - key: Key.Type, - value: Value.Type, - count: Int +private func reserveGraphMapMemory( + _ context: ReadContext, + key: Key.Type, + value: Value.Type, + count: Int ) throws { - let keyBytes = containerElementBytes(key) - let valueBytes = containerElementBytes(value) - let (elementBytes, overflow) = keyBytes.addingReportingOverflow(valueBytes) - if overflow { - try context.reserveContainerMemory(-1) - } - try context.reserveCountedContainerMemory(count: count, elementBytes: elementBytes) + let keyBytes = storedElementBytes(key) + let valueBytes = storedElementBytes(value) + let (elementBytes, overflow) = keyBytes.addingReportingOverflow(valueBytes) + if overflow { + try context.reserveGraphMemory(-1) + } + try context.reserveCountedGraphMemory(count: count, elementBytes: elementBytes) } private func primitiveArrayTypeID(for _: Element.Type) -> TypeId? { - if Element.self == UInt8.self { return .uint8Array } - if Element.self == Bool.self { return .boolArray } - if Element.self == Int8.self { return .int8Array } - if Element.self == Int16.self { return .int16Array } - if Element.self == Int32.self { return .int32Array } - if Element.self == Int64.self { return .int64Array } - if Element.self == UInt16.self { return .uint16Array } - if Element.self == UInt32.self { return .uint32Array } - if Element.self == UInt64.self { return .uint64Array } - if Element.self == Float16.self { return .float16Array } - if Element.self == BFloat16.self { return .bfloat16Array } - if Element.self == Float.self { return .float32Array } - if Element.self == Double.self { return .float64Array } - return nil + if Element.self == UInt8.self { return .uint8Array } + if Element.self == Bool.self { return .boolArray } + if Element.self == Int8.self { return .int8Array } + if Element.self == Int16.self { return .int16Array } + if Element.self == Int32.self { return .int32Array } + if Element.self == Int64.self { return .int64Array } + if Element.self == UInt16.self { return .uint16Array } + if Element.self == UInt32.self { return .uint32Array } + if Element.self == UInt64.self { return .uint64Array } + if Element.self == Float16.self { return .float16Array } + if Element.self == BFloat16.self { return .bfloat16Array } + if Element.self == Float.self { return .float32Array } + if Element.self == Double.self { return .float64Array } + return nil } private let hostIsLittleEndian = Int(littleEndian: 1) == 1 @inline(__always) private func uncheckedArrayCast(_ array: [From], to _: To.Type) -> [To] { - assert(From.self == To.self) - return unsafeBitCast(array, to: [To].self) + assert(From.self == To.self) + return unsafeBitCast(array, to: [To].self) } @inline(__always) private func readArrayUninitialized( - count: Int, - _ initializer: (UnsafeMutablePointer) throws -> Void + count: Int, + _ initializer: (UnsafeMutablePointer) throws -> Void ) rethrows -> [Element] { - try [Element](unsafeUninitializedCapacity: count) { destination, initializedCount in - if count > 0 { - try initializer(destination.baseAddress!) - } - initializedCount = count + try [Element](unsafeUninitializedCapacity: count) { destination, initializedCount in + if count > 0 { + try initializer(destination.baseAddress!) } + initializedCount = count + } } func writePrimitiveArray(_ value: [Element], context: WriteContext) { - if Element.self == UInt8.self { - let bytes = uncheckedArrayCast(value, to: UInt8.self) - context.buffer.writeVarUInt32(UInt32(bytes.count)) - context.buffer.writeBytes(bytes) - return - } - - if Element.self == Bool.self { - let bools = uncheckedArrayCast(value, to: Bool.self) - context.buffer.writeVarUInt32(UInt32(bools.count)) - for item in bools { - context.buffer.writeUInt8(item ? 1 : 0) - } - return - } - - if Element.self == Int8.self { - let values = uncheckedArrayCast(value, to: Int8.self) - context.buffer.writeVarUInt32(UInt32(values.count)) - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - return + if Element.self == UInt8.self { + let bytes = uncheckedArrayCast(value, to: UInt8.self) + context.buffer.writeVarUInt32(UInt32(bytes.count)) + context.buffer.writeBytes(bytes) + return + } + + if Element.self == Bool.self { + let bools = uncheckedArrayCast(value, to: Bool.self) + context.buffer.writeVarUInt32(UInt32(bools.count)) + for item in bools { + context.buffer.writeUInt8(item ? 1 : 0) } - - if Element.self == Int16.self { - let values = uncheckedArrayCast(value, to: Int16.self) - context.buffer.writeVarUInt32(UInt32(values.count * 2)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeInt16(item) - } - } - return + return + } + + if Element.self == Int8.self { + let values = uncheckedArrayCast(value, to: Int8.self) + context.buffer.writeVarUInt32(UInt32(values.count)) + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) } + return + } - if Element.self == Int32.self { - let values = uncheckedArrayCast(value, to: Int32.self) - context.buffer.writeVarUInt32(UInt32(values.count * 4)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeInt32(item) - } - } - return + if Element.self == Int16.self { + let values = uncheckedArrayCast(value, to: Int16.self) + context.buffer.writeVarUInt32(UInt32(values.count * 2)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeInt16(item) + } } + return + } - if Element.self == UInt32.self { - let values = uncheckedArrayCast(value, to: UInt32.self) - context.buffer.writeVarUInt32(UInt32(values.count * 4)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeUInt32(item) - } - } - return + if Element.self == Int32.self { + let values = uncheckedArrayCast(value, to: Int32.self) + context.buffer.writeVarUInt32(UInt32(values.count * 4)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeInt32(item) + } } + return + } - if Element.self == Int64.self { - let values = uncheckedArrayCast(value, to: Int64.self) - context.buffer.writeVarUInt32(UInt32(values.count * 8)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeInt64(item) - } - } - return + if Element.self == UInt32.self { + let values = uncheckedArrayCast(value, to: UInt32.self) + context.buffer.writeVarUInt32(UInt32(values.count * 4)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeUInt32(item) + } } + return + } - if Element.self == UInt64.self { - let values = uncheckedArrayCast(value, to: UInt64.self) - context.buffer.writeVarUInt32(UInt32(values.count * 8)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeUInt64(item) - } - } - return + if Element.self == Int64.self { + let values = uncheckedArrayCast(value, to: Int64.self) + context.buffer.writeVarUInt32(UInt32(values.count * 8)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeInt64(item) + } } + return + } - if Element.self == UInt16.self { - let values = uncheckedArrayCast(value, to: UInt16.self) - context.buffer.writeVarUInt32(UInt32(values.count * 2)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeUInt16(item) - } - } - return + if Element.self == UInt64.self { + let values = uncheckedArrayCast(value, to: UInt64.self) + context.buffer.writeVarUInt32(UInt32(values.count * 8)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeUInt64(item) + } } + return + } - if Element.self == Float16.self { - let values = uncheckedArrayCast(value, to: Float16.self) - context.buffer.writeVarUInt32(UInt32(values.count * 2)) - for item in values { - context.buffer.writeUInt16(item.bitPattern) - } - return + if Element.self == UInt16.self { + let values = uncheckedArrayCast(value, to: UInt16.self) + context.buffer.writeVarUInt32(UInt32(values.count * 2)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeUInt16(item) + } } - - if Element.self == BFloat16.self { - let values = uncheckedArrayCast(value, to: BFloat16.self) - context.buffer.writeVarUInt32(UInt32(values.count * 2)) - for item in values { - context.buffer.writeUInt16(item.rawValue) - } - return + return + } + + if Element.self == Float16.self { + let values = uncheckedArrayCast(value, to: Float16.self) + context.buffer.writeVarUInt32(UInt32(values.count * 2)) + for item in values { + context.buffer.writeUInt16(item.bitPattern) } - - if Element.self == Float.self { - let values = uncheckedArrayCast(value, to: Float.self) - context.buffer.writeVarUInt32(UInt32(values.count * 4)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeFloat32(item) - } - } - return + return + } + + if Element.self == BFloat16.self { + let values = uncheckedArrayCast(value, to: BFloat16.self) + context.buffer.writeVarUInt32(UInt32(values.count * 2)) + for item in values { + context.buffer.writeUInt16(item.rawValue) } + return + } - let values = uncheckedArrayCast(value, to: Double.self) - context.buffer.writeVarUInt32(UInt32(values.count * 8)) + if Element.self == Float.self { + let values = uncheckedArrayCast(value, to: Float.self) + context.buffer.writeVarUInt32(UInt32(values.count * 4)) if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } } else { - for item in values { - context.buffer.writeFloat64(item) - } + for item in values { + context.buffer.writeFloat32(item) + } } + return + } + + let values = uncheckedArrayCast(value, to: Double.self) + context.buffer.writeVarUInt32(UInt32(values.count * 8)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeFloat64(item) + } + } } @inline(__always) private func preparePrimitiveArray( - _ context: ReadContext, - reserveContainerStorage: Bool, - type: Element.Type, - count: Int, - label: String + _ context: ReadContext, + reserveGraphStorage: Bool, + type: Element.Type, + count: Int, + label: String ) throws { - try context.ensureCollectionLength(count, label: label) - if reserveContainerStorage { - try reserveContainerArrayMemory(context, type, count: count) - } + try context.ensureCollectionLength(count, label: label) + if reserveGraphStorage { + try reserveGraphArrayMemory(context, type, count: count) + } } func readPrimitiveArray( - _ context: ReadContext, - reserveContainerStorage: Bool = false + _ context: ReadContext, + reserveGraphStorage: Bool = false ) throws -> [Element] { - let byteSize = Int(try context.buffer.readVarUInt32()) - try context.ensureRemainingBytes(byteSize, label: "primitive_array_bytes") - - if Element.self == UInt8.self { - try preparePrimitiveArray(context, reserveContainerStorage: reserveContainerStorage, type: Element.self, count: byteSize, label: "uint8_array") - let bytes = try context.buffer.readBytes(count: byteSize) - return uncheckedArrayCast(bytes, to: Element.self) + let byteSize = Int(try context.buffer.readVarUInt32()) + try context.ensureRemainingBytes(byteSize, label: "primitive_array_bytes") + + if Element.self == UInt8.self { + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: byteSize, + label: "uint8_array") + let bytes = try context.buffer.readBytes(count: byteSize) + return uncheckedArrayCast(bytes, to: Element.self) + } + + if Element.self == Bool.self { + try preparePrimitiveArray( + context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: byteSize, + label: "bool_array") + let out = try readArrayUninitialized(count: byteSize) { destination in + for index in 0.. [Element] { - [] + public static func foryDefault() -> [Element] { + [] + } + + public static var staticTypeId: TypeId { + .list + } + + public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: staticTypeId.rawValue)) + } + + public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + let rawTypeID = try context.buffer.readVarUInt32() + guard let actualTypeID = TypeId(rawValue: rawTypeID) else { + throw ForyError.invalidData("unknown type id \(rawTypeID)") } - public static var staticTypeId: TypeId { - .list + let expectedTypeID = staticTypeId + if actualTypeID != expectedTypeID { + throw ForyError.typeMismatch(expected: expectedTypeID.rawValue, actual: rawTypeID) } + return nil + } - public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: staticTypeId.rawValue)) + public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { + let buffer = context.buffer + buffer.writeVarUInt32(UInt32(self.count)) + if self.isEmpty { + return } - public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - let rawTypeID = try context.buffer.readVarUInt32() - guard let actualTypeID = TypeId(rawValue: rawTypeID) else { - throw ForyError.invalidData("unknown type id \(rawTypeID)") - } + let hasNull = Element.isNullableType && self.contains(where: { $0.foryIsNone }) + let trackRef = context.trackRef && Element.isRefType + let declaredElementType = hasGenerics && !TypeId.needsTypeInfoForField(Element.staticTypeId) + let dynamicElementType = Element.staticTypeId == .unknown - let expectedTypeID = staticTypeId - if actualTypeID != expectedTypeID { - throw ForyError.typeMismatch(expected: expectedTypeID.rawValue, actual: rawTypeID) - } - return nil + var header: UInt8 = dynamicElementType ? 0 : CollectionHeader.sameType + if trackRef { + header |= CollectionHeader.trackingRef + } + if hasNull { + header |= CollectionHeader.hasNull + } + if declaredElementType { + header |= CollectionHeader.declaredElementType } - public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { - let buffer = context.buffer - buffer.writeVarUInt32(UInt32(self.count)) - if self.isEmpty { - return - } + buffer.writeUInt8(header) + if !dynamicElementType && !declaredElementType { + try Element.foryWriteStaticTypeInfo(context) + } - let hasNull = Element.isNullableType && self.contains(where: { $0.foryIsNone }) - let trackRef = context.trackRef && Element.isRefType - let declaredElementType = hasGenerics && !TypeId.needsTypeInfoForField(Element.staticTypeId) - let dynamicElementType = Element.staticTypeId == .unknown + if dynamicElementType { + let refMode: RefMode + if trackRef { + refMode = .tracking + } else if hasNull { + refMode = .nullOnly + } else { + refMode = .none + } + for element in self { + try element.foryWrite( + context, refMode: refMode, writeTypeInfo: true, hasGenerics: hasGenerics) + } + return + } - var header: UInt8 = dynamicElementType ? 0 : CollectionHeader.sameType - if trackRef { - header |= CollectionHeader.trackingRef - } - if hasNull { - header |= CollectionHeader.hasNull - } - if declaredElementType { - header |= CollectionHeader.declaredElementType + if trackRef { + for element in self { + try element.foryWrite( + context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) + } + } else if hasNull { + for element in self { + if element.foryIsNone { + buffer.writeInt8(RefFlag.null.rawValue) + } else { + buffer.writeInt8(RefFlag.notNullValue.rawValue) + try element.foryWriteData(context, hasGenerics: hasGenerics) } + } + } else { + for element in self { + try element.foryWriteData(context, hasGenerics: hasGenerics) + } + } + } + + public static func foryReadData(_ context: ReadContext) throws -> [Element] { + let buffer = context.buffer + let length = Int(try buffer.readVarUInt32()) + try context.ensureCollectionLength(length, label: "array") + if length == 0 { + try reserveGraphArrayMemory(context, Element.self, count: length) + return [] + } - buffer.writeUInt8(header) - if !dynamicElementType && !declaredElementType { - try Element.foryWriteStaticTypeInfo(context) + let header = try buffer.readUInt8() + let trackRef = (header & CollectionHeader.trackingRef) != 0 + let hasNull = (header & CollectionHeader.hasNull) != 0 + let declared = (header & CollectionHeader.declaredElementType) != 0 + let sameType = (header & CollectionHeader.sameType) != 0 + if !sameType { + try reserveGraphArrayMemory(context, Element.self, count: length) + try context.ensureRemainingBytes(length, label: "array") + if trackRef { + return try readArrayUninitialized(count: length) { destination in + for index in 0.. [Element] { - let buffer = context.buffer - let length = Int(try buffer.readVarUInt32()) - try context.ensureCollectionLength(length, label: "array") - if length == 0 { - try reserveContainerArrayMemory(context, Element.self, count: length) - return [] + let elementTypeInfo = declared ? nil : try Element.foryReadTypeInfo(context) + try reserveGraphArrayMemory(context, Element.self, count: length) + try context.ensureRemainingBytes(length, label: "array") + return try context.withTypeInfo(elementTypeInfo, for: Element.self) { + if trackRef { + return try readArrayUninitialized(count: length) { destination in + for index in 0.. Set { [] } + public static func foryDefault() -> Set { [] } - public static var staticTypeId: TypeId { .set } + public static var staticTypeId: TypeId { .set } - public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { - context.writeStaticTypeInfo(staticTypeId) - } + public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { + context.writeStaticTypeInfo(staticTypeId) + } - public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try context.readStaticTypeInfo(staticTypeId) - } + public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try context.readStaticTypeInfo(staticTypeId) + } - public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { - try Array(self).foryWriteData(context, hasGenerics: hasGenerics) - } + public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { + try Array(self).foryWriteData(context, hasGenerics: hasGenerics) + } - public static func foryReadData(_ context: ReadContext) throws -> Set { - let values = try [Element].foryReadData(context) - try reserveContainerArrayMemory(context, Element.self, count: values.count) - return Set(values) - } + public static func foryReadData(_ context: ReadContext) throws -> Set { + let values = try [Element].foryReadData(context) + try reserveGraphArrayMemory(context, Element.self, count: values.count) + return Set(values) + } } extension Dictionary: Serializer where Key: Serializer & Hashable, Value: Serializer { - public static func foryDefault() -> [Key: Value] { [:] } + public static func foryDefault() -> [Key: Value] { [:] } - public static var staticTypeId: TypeId { .map } + public static var staticTypeId: TypeId { .map } - public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { - context.writeStaticTypeInfo(staticTypeId) - } + public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { + context.writeStaticTypeInfo(staticTypeId) + } - public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try context.readStaticTypeInfo(staticTypeId) + public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try context.readStaticTypeInfo(staticTypeId) + } + + public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { + context.buffer.writeVarUInt32(UInt32(self.count)) + if self.isEmpty { + return } - public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { - context.buffer.writeVarUInt32(UInt32(self.count)) - if self.isEmpty { - return + let trackKeyRef = context.trackRef && Key.isRefType + let trackValueRef = context.trackRef && Value.isRefType + let keyDeclared = hasGenerics && !TypeId.needsTypeInfoForField(Key.staticTypeId) + let valueDeclared = hasGenerics && !TypeId.needsTypeInfoForField(Value.staticTypeId) + let keyDynamicType = Key.staticTypeId == .unknown + let valueDynamicType = Value.staticTypeId == .unknown + + if keyDynamicType || valueDynamicType { + for pair in self { + let keyIsNil = pair.key.foryIsNone + let valueIsNil = pair.value.foryIsNone + var header: UInt8 = 0 + if trackKeyRef { + header |= MapHeader.trackingKeyRef + } + if trackValueRef { + header |= MapHeader.trackingValueRef + } + if keyIsNil { + header |= MapHeader.keyNull + } else if !keyDynamicType && keyDeclared { + header |= MapHeader.declaredKeyType } + if valueIsNil { + header |= MapHeader.valueNull + } else if !valueDynamicType && valueDeclared { + header |= MapHeader.declaredValueType + } + context.buffer.writeUInt8(header) - let trackKeyRef = context.trackRef && Key.isRefType - let trackValueRef = context.trackRef && Value.isRefType - let keyDeclared = hasGenerics && !TypeId.needsTypeInfoForField(Key.staticTypeId) - let valueDeclared = hasGenerics && !TypeId.needsTypeInfoForField(Value.staticTypeId) - let keyDynamicType = Key.staticTypeId == .unknown - let valueDynamicType = Value.staticTypeId == .unknown - - if keyDynamicType || valueDynamicType { - for pair in self { - let keyIsNil = pair.key.foryIsNone - let valueIsNil = pair.value.foryIsNone - var header: UInt8 = 0 - if trackKeyRef { - header |= MapHeader.trackingKeyRef - } - if trackValueRef { - header |= MapHeader.trackingValueRef - } - if keyIsNil { - header |= MapHeader.keyNull - } else if !keyDynamicType && keyDeclared { - header |= MapHeader.declaredKeyType - } - if valueIsNil { - header |= MapHeader.valueNull - } else if !valueDynamicType && valueDeclared { - header |= MapHeader.declaredValueType - } - context.buffer.writeUInt8(header) - - if keyIsNil && valueIsNil { - continue - } - if keyIsNil { - if !valueDeclared { - if valueDynamicType { - try pair.value.foryWriteTypeInfo(context) - } else { - try Value.foryWriteStaticTypeInfo(context) - } - } - if trackValueRef { - try pair.value.foryWrite( - context, - refMode: .tracking, - writeTypeInfo: false, - hasGenerics: hasGenerics - ) - } else { - try pair.value.foryWriteData(context, hasGenerics: hasGenerics) - } - continue - } - - if valueIsNil { - if !keyDeclared { - if keyDynamicType { - try pair.key.foryWriteTypeInfo(context) - } else { - try Key.foryWriteStaticTypeInfo(context) - } - } - if trackKeyRef { - try pair.key.foryWrite( - context, - refMode: .tracking, - writeTypeInfo: false, - hasGenerics: hasGenerics - ) - } else { - try pair.key.foryWriteData(context, hasGenerics: hasGenerics) - } - continue - } - - context.buffer.writeUInt8(1) - - if !keyDeclared { - if keyDynamicType { - try pair.key.foryWriteTypeInfo(context) - } else { - try Key.foryWriteStaticTypeInfo(context) - } - } - if !valueDeclared { - if valueDynamicType { - try pair.value.foryWriteTypeInfo(context) - } else { - try Value.foryWriteStaticTypeInfo(context) - } - } - - if trackKeyRef { - try pair.key.foryWrite( - context, - refMode: .tracking, - writeTypeInfo: false, - hasGenerics: hasGenerics - ) - } else { - try pair.key.foryWriteData(context, hasGenerics: hasGenerics) - } - if trackValueRef { - try pair.value.foryWrite( - context, - refMode: .tracking, - writeTypeInfo: false, - hasGenerics: hasGenerics - ) - } else { - try pair.value.foryWriteData(context, hasGenerics: hasGenerics) - } + if keyIsNil && valueIsNil { + continue + } + if keyIsNil { + if !valueDeclared { + if valueDynamicType { + try pair.value.foryWriteTypeInfo(context) + } else { + try Value.foryWriteStaticTypeInfo(context) } - return + } + if trackValueRef { + try pair.value.foryWrite( + context, + refMode: .tracking, + writeTypeInfo: false, + hasGenerics: hasGenerics + ) + } else { + try pair.value.foryWriteData(context, hasGenerics: hasGenerics) + } + continue } - var iterator = makeIterator() - var pendingPair = iterator.next() - - while let pair = pendingPair { - let keyIsNil = pair.key.foryIsNone - let valueIsNil = pair.value.foryIsNone - - if keyIsNil || valueIsNil { - var header: UInt8 = 0 - if trackKeyRef { - header |= MapHeader.trackingKeyRef - } - if trackValueRef { - header |= MapHeader.trackingValueRef - } - if keyIsNil { header |= MapHeader.keyNull } - if valueIsNil { header |= MapHeader.valueNull } - if !keyIsNil && keyDeclared { header |= MapHeader.declaredKeyType } - if !valueIsNil && valueDeclared { header |= MapHeader.declaredValueType } - - context.buffer.writeUInt8(header) - if !keyIsNil { - if !keyDeclared { - try Key.foryWriteStaticTypeInfo(context) - } - if trackKeyRef { - try pair.key.foryWrite(context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) - } else { - try pair.key.foryWriteData(context, hasGenerics: hasGenerics) - } - } - if !valueIsNil { - if !valueDeclared { - try Value.foryWriteStaticTypeInfo(context) - } - if trackValueRef { - try pair.value.foryWrite(context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) - } else { - try pair.value.foryWriteData(context, hasGenerics: hasGenerics) - } - } - pendingPair = iterator.next() - continue + if valueIsNil { + if !keyDeclared { + if keyDynamicType { + try pair.key.foryWriteTypeInfo(context) + } else { + try Key.foryWriteStaticTypeInfo(context) } + } + if trackKeyRef { + try pair.key.foryWrite( + context, + refMode: .tracking, + writeTypeInfo: false, + hasGenerics: hasGenerics + ) + } else { + try pair.key.foryWriteData(context, hasGenerics: hasGenerics) + } + continue + } - var header: UInt8 = 0 - if trackKeyRef { header |= MapHeader.trackingKeyRef } - if trackValueRef { header |= MapHeader.trackingValueRef } - if keyDeclared { header |= MapHeader.declaredKeyType } - if valueDeclared { header |= MapHeader.declaredValueType } - - context.buffer.writeUInt8(header) - let chunkSizeOffset = context.buffer.count - context.buffer.writeUInt8(0) + context.buffer.writeUInt8(1) - if !keyDeclared { - try Key.foryWriteStaticTypeInfo(context) - } - if !valueDeclared { - try Value.foryWriteStaticTypeInfo(context) - } + if !keyDeclared { + if keyDynamicType { + try pair.key.foryWriteTypeInfo(context) + } else { + try Key.foryWriteStaticTypeInfo(context) + } + } + if !valueDeclared { + if valueDynamicType { + try pair.value.foryWriteTypeInfo(context) + } else { + try Value.foryWriteStaticTypeInfo(context) + } + } - var chunkSize: UInt8 = 0 - while chunkSize < UInt8.max, let current = pendingPair { - if current.key.foryIsNone || current.value.foryIsNone { - break - } - if trackKeyRef { - try current.key.foryWrite(context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) - } else { - try current.key.foryWriteData(context, hasGenerics: hasGenerics) - } - if trackValueRef { - try current.value.foryWrite(context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) - } else { - try current.value.foryWriteData(context, hasGenerics: hasGenerics) - } - chunkSize &+= 1 - pendingPair = iterator.next() - } - context.buffer.setByte(at: chunkSizeOffset, to: chunkSize) + if trackKeyRef { + try pair.key.foryWrite( + context, + refMode: .tracking, + writeTypeInfo: false, + hasGenerics: hasGenerics + ) + } else { + try pair.key.foryWriteData(context, hasGenerics: hasGenerics) + } + if trackValueRef { + try pair.value.foryWrite( + context, + refMode: .tracking, + writeTypeInfo: false, + hasGenerics: hasGenerics + ) + } else { + try pair.value.foryWriteData(context, hasGenerics: hasGenerics) } + } + return } - public static func foryReadData(_ context: ReadContext) throws -> [Key: Value] { - let totalLength = Int(try context.buffer.readVarUInt32()) - try context.ensureCollectionLength(totalLength, label: "map") - if totalLength == 0 { - try reserveContainerMapMemory(context, key: Key.self, value: Value.self, count: totalLength) - return [:] - } + var iterator = makeIterator() + var pendingPair = iterator.next() - try reserveContainerMapMemory(context, key: Key.self, value: Value.self, count: totalLength) - try context.ensureRemainingBytes(totalLength, label: "map") - var map: [Key: Value] = [:] - map.reserveCapacity(totalLength) - let keyDynamicType = Key.staticTypeId == .unknown - let valueDynamicType = Value.staticTypeId == .unknown - if keyDynamicType || valueDynamicType { - var dynamicReadCount = 0 - while dynamicReadCount < totalLength { - let header = try context.buffer.readUInt8() - let trackKeyRef = (header & MapHeader.trackingKeyRef) != 0 - let keyNull = (header & MapHeader.keyNull) != 0 - let keyDeclared = (header & MapHeader.declaredKeyType) != 0 - - let trackValueRef = (header & MapHeader.trackingValueRef) != 0 - let valueNull = (header & MapHeader.valueNull) != 0 - let valueDeclared = (header & MapHeader.declaredValueType) != 0 - - if keyNull && valueNull { - map[Key.foryDefault()] = Value.foryDefault() - dynamicReadCount += 1 - continue - } - - if keyNull { - let value = try Value.foryRead( - context, - refMode: trackValueRef ? .tracking : .none, - readTypeInfo: valueDynamicType || !valueDeclared - ) - map[Key.foryDefault()] = value - dynamicReadCount += 1 - continue - } - - if valueNull { - let key = try Key.foryRead( - context, - refMode: trackKeyRef ? .tracking : .none, - readTypeInfo: keyDynamicType || !keyDeclared - ) - map[key] = Value.foryDefault() - dynamicReadCount += 1 - continue - } - - let chunkSize = Int(try context.buffer.readUInt8()) - if chunkSize > (totalLength - dynamicReadCount) { - throw ForyError.invalidData("map dynamic chunk size exceeds remaining entries") - } - let keyTypeInfo = keyDeclared ? nil : try Key.foryReadTypeInfo(context) - let valueTypeInfo = valueDeclared ? nil : try Value.foryReadTypeInfo(context) - for _ in 0.. [Key: Value] { + let totalLength = Int(try context.buffer.readVarUInt32()) + try context.ensureCollectionLength(totalLength, label: "map") + if totalLength == 0 { + try reserveGraphMapMemory(context, key: Key.self, value: Value.self, count: totalLength) + return [:] + } - var readCount = 0 - while readCount < totalLength { - let header = try context.buffer.readUInt8() - let trackKeyRef = (header & MapHeader.trackingKeyRef) != 0 - let keyNull = (header & MapHeader.keyNull) != 0 - let keyDeclared = (header & MapHeader.declaredKeyType) != 0 - - let trackValueRef = (header & MapHeader.trackingValueRef) != 0 - let valueNull = (header & MapHeader.valueNull) != 0 - let valueDeclared = (header & MapHeader.declaredValueType) != 0 - - if keyNull && valueNull { - map[Key.foryDefault()] = Value.foryDefault() - readCount += 1 - continue - } + try reserveGraphMapMemory(context, key: Key.self, value: Value.self, count: totalLength) + try context.ensureRemainingBytes(totalLength, label: "map") + var map: [Key: Value] = [:] + map.reserveCapacity(totalLength) + let keyDynamicType = Key.staticTypeId == .unknown + let valueDynamicType = Value.staticTypeId == .unknown + if keyDynamicType || valueDynamicType { + var dynamicReadCount = 0 + while dynamicReadCount < totalLength { + let header = try context.buffer.readUInt8() + let trackKeyRef = (header & MapHeader.trackingKeyRef) != 0 + let keyNull = (header & MapHeader.keyNull) != 0 + let keyDeclared = (header & MapHeader.declaredKeyType) != 0 + + let trackValueRef = (header & MapHeader.trackingValueRef) != 0 + let valueNull = (header & MapHeader.valueNull) != 0 + let valueDeclared = (header & MapHeader.declaredValueType) != 0 + + if keyNull && valueNull { + map[Key.foryDefault()] = Value.foryDefault() + dynamicReadCount += 1 + continue + } - if keyNull { - let value = try Value.foryRead( - context, - refMode: trackValueRef ? .tracking : .none, - readTypeInfo: !valueDeclared - ) - map[Key.foryDefault()] = value - readCount += 1 - continue - } + if keyNull { + let value = try Value.foryRead( + context, + refMode: trackValueRef ? .tracking : .none, + readTypeInfo: valueDynamicType || !valueDeclared + ) + map[Key.foryDefault()] = value + dynamicReadCount += 1 + continue + } - if valueNull { - let key = try Key.foryRead( - context, - refMode: trackKeyRef ? .tracking : .none, - readTypeInfo: !keyDeclared - ) - map[key] = Value.foryDefault() - readCount += 1 - continue - } + if valueNull { + let key = try Key.foryRead( + context, + refMode: trackKeyRef ? .tracking : .none, + readTypeInfo: keyDynamicType || !keyDeclared + ) + map[key] = Value.foryDefault() + dynamicReadCount += 1 + continue + } - let chunkSize = Int(try context.buffer.readUInt8()) - if chunkSize > (totalLength - readCount) { - throw ForyError.invalidData("map chunk size exceeds remaining entries") - } - let keyTypeInfo = keyDeclared ? nil : try Key.foryReadTypeInfo(context) - let valueTypeInfo = valueDeclared ? nil : try Value.foryReadTypeInfo(context) - for _ in 0.. (totalLength - dynamicReadCount) { + throw ForyError.invalidData("map dynamic chunk size exceeds remaining entries") + } + let keyTypeInfo = keyDeclared ? nil : try Key.foryReadTypeInfo(context) + let valueTypeInfo = valueDeclared ? nil : try Value.foryReadTypeInfo(context) + for _ in 0.. (totalLength - readCount) { + throw ForyError.invalidData("map chunk size exceeds remaining entries") + } + let keyTypeInfo = keyDeclared ? nil : try Key.foryReadTypeInfo(context) + let valueTypeInfo = valueDeclared ? nil : try Value.foryReadTypeInfo(context) + for _ in 0..(_ codec: ElementCodec.Type) -> Int { - codec.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) + codec.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) } @inline(__always) private func serializerElementBytes(_ type: Element.Type) -> Int { - type.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) + type.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) } @inline(__always) private func reserveFieldArrayStorage( - _ context: ReadContext, - _ codec: ElementCodec.Type, - count: Int + _ context: ReadContext, + _ codec: ElementCodec.Type, + count: Int ) throws { - try context.reserveCountedContainerMemory(count: count, elementBytes: fieldElementBytes(codec)) + try context.reserveCountedGraphMemory(count: count, elementBytes: fieldElementBytes(codec)) } @inline(__always) private func reserveSerializerArrayMemory( - _ context: ReadContext, - _ type: Element.Type, - count: Int + _ context: ReadContext, + _ type: Element.Type, + count: Int ) throws { - try context.reserveCountedContainerMemory(count: count, elementBytes: serializerElementBytes(type)) + try context.reserveCountedGraphMemory(count: count, elementBytes: serializerElementBytes(type)) } @inline(__always) private func reserveFieldMapStorage( - _ context: ReadContext, - key: KeyCodec.Type, - value: ValueCodec.Type, - count: Int + _ context: ReadContext, + key: KeyCodec.Type, + value: ValueCodec.Type, + count: Int ) throws { - let keyBytes = fieldElementBytes(key) - let valueBytes = fieldElementBytes(value) - let (elementBytes, overflow) = keyBytes.addingReportingOverflow(valueBytes) - if overflow { - try context.reserveContainerMemory(-1) - } - try context.reserveCountedContainerMemory(count: count, elementBytes: elementBytes) + let keyBytes = fieldElementBytes(key) + let valueBytes = fieldElementBytes(value) + let (elementBytes, overflow) = keyBytes.addingReportingOverflow(valueBytes) + if overflow { + try context.reserveGraphMemory(-1) + } + try context.reserveCountedGraphMemory(count: count, elementBytes: elementBytes) } public protocol FieldCodec { - associatedtype Value - - static var typeId: TypeId { get } - static var defaultValue: Value { get } - static var isNullableType: Bool { get } - static var isRefType: Bool { get } - - static func isNone(_ value: Value) -> Bool - static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType - static func writePayload(_ value: Value, _ context: WriteContext) throws - static func readPayload(_ context: ReadContext) throws -> Value - static func writeStaticTypeInfo(_ context: WriteContext) throws - static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? - static func withTypeInfo(_ typeInfo: TypeInfo?, _ context: ReadContext, _ body: () throws -> R) rethrows -> R - static func readCompatibleField( - _ context: ReadContext, - remoteFieldType: TypeMeta.FieldType, - refMode: RefMode - ) throws -> Value + associatedtype Value + + static var typeId: TypeId { get } + static var defaultValue: Value { get } + static var isNullableType: Bool { get } + static var isRefType: Bool { get } + + static func isNone(_ value: Value) -> Bool + static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType + static func writePayload(_ value: Value, _ context: WriteContext) throws + static func readPayload(_ context: ReadContext) throws -> Value + static func writeStaticTypeInfo(_ context: WriteContext) throws + static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? + static func withTypeInfo(_ typeInfo: TypeInfo?, _ context: ReadContext, _ body: () throws -> R) + rethrows -> R + static func readCompatibleField( + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Value } -public extension FieldCodec { - static var isNullableType: Bool { false } - static var isRefType: Bool { false } - - static func isNone(_: Value) -> Bool { false } +extension FieldCodec { + public static var isNullableType: Bool { false } + public static var isRefType: Bool { false } - static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - TypeMeta.FieldType(typeID: typeId.rawValue, nullable: nullable, trackRef: trackRef) - } + public static func isNone(_: Value) -> Bool { false } - static func writeStaticTypeInfo(_ context: WriteContext) throws { - context.writeStaticTypeInfo(typeId) - } + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + TypeMeta.FieldType(typeID: typeId.rawValue, nullable: nullable, trackRef: trackRef) + } - static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try context.readStaticTypeInfo(typeId) - } + public static func writeStaticTypeInfo(_ context: WriteContext) throws { + context.writeStaticTypeInfo(typeId) + } - static func withTypeInfo( - _ typeInfo: TypeInfo?, - _ context: ReadContext, - _ body: () throws -> R - ) rethrows -> R { - _ = typeInfo - _ = context - return try body() - } + public static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try context.readStaticTypeInfo(typeId) + } - static func readCompatibleField( - _ context: ReadContext, - remoteFieldType: TypeMeta.FieldType, - refMode: RefMode - ) throws -> Value { - try read( - context, - refMode: refMode, - readTypeInfo: TypeId.needsTypeInfoForField(TypeId(rawValue: remoteFieldType.typeID) ?? .unknown) - ) - } - - static func write( - _ value: Value, - _ context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool - ) throws { - if refMode != .none { - if refMode == .tracking, isRefType, let object = value as AnyObject? { - if context.refWriter.tryWriteRef(buffer: context.buffer, object: object) { - return - } - } else { - context.buffer.writeInt8(RefFlag.notNullValue.rawValue) - } - } - - if writeTypeInfo { - try writeStaticTypeInfo(context) + public static func withTypeInfo( + _ typeInfo: TypeInfo?, + _ context: ReadContext, + _ body: () throws -> R + ) rethrows -> R { + _ = typeInfo + _ = context + return try body() + } + + public static func readCompatibleField( + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Value { + try read( + context, + refMode: refMode, + readTypeInfo: TypeId.needsTypeInfoForField( + TypeId(rawValue: remoteFieldType.typeID) ?? .unknown) + ) + } + + public static func write( + _ value: Value, + _ context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool + ) throws { + if refMode != .none { + if refMode == .tracking, isRefType, let object = value as AnyObject? { + if context.refWriter.tryWriteRef(buffer: context.buffer, object: object) { + return } - try writePayload(value, context) + } else { + context.buffer.writeInt8(RefFlag.notNullValue.rawValue) + } } - static func read( - _ context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> Value { - switch refMode { - case .none: - if readTypeInfo { - let typeInfo = try Self.readTypeInfo(context) - return try withTypeInfo(typeInfo, context) { - try readPayload(context) - } - } - return try readPayload(context) - case .nullOnly: - let rawFlag = try context.buffer.readInt8() - switch rawFlag { - case RefFlag.null.rawValue: - return defaultValue - case RefFlag.notNullValue.rawValue: - if readTypeInfo { - let typeInfo = try Self.readTypeInfo(context) - return try withTypeInfo(typeInfo, context) { - try readPayload(context) - } - } - return try readPayload(context) - case RefFlag.refValue.rawValue: - if context.trackRef { - let reservedRefID = context.refReader.reserveRefID() - let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - context.refReader.storeRef(value, at: reservedRefID) - return value - } - return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - case RefFlag.ref.rawValue: - let refID = try context.buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: Value.self) - default: - throw ForyError.refError("invalid ref flag \(rawFlag)") - } - case .tracking: - let rawFlag = try context.buffer.readInt8() - guard let flag = RefFlag(rawValue: rawFlag) else { - throw ForyError.refError("invalid ref flag \(rawFlag)") - } - switch flag { - case .null: - return defaultValue - case .ref: - let refID = try context.buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: Value.self) - case .refValue: - let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil - let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - if let reservedRefID { - context.refReader.storeRef(value, at: reservedRefID) - } - return value - case .notNullValue: - return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - } - } + if writeTypeInfo { + try writeStaticTypeInfo(context) } + try writePayload(value, context) + } - private static func readPayloadAfterTypeInfo( - _ context: ReadContext, - readTypeInfo: Bool - ) throws -> Value { + public static func read( + _ context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> Value { + switch refMode { + case .none: + if readTypeInfo { + let typeInfo = try Self.readTypeInfo(context) + return try withTypeInfo(typeInfo, context) { + try readPayload(context) + } + } + return try readPayload(context) + case .nullOnly: + let rawFlag = try context.buffer.readInt8() + switch rawFlag { + case RefFlag.null.rawValue: + return defaultValue + case RefFlag.notNullValue.rawValue: if readTypeInfo { - let typeInfo = try Self.readTypeInfo(context) - return try withTypeInfo(typeInfo, context) { - try readPayload(context) - } + let typeInfo = try Self.readTypeInfo(context) + return try withTypeInfo(typeInfo, context) { + try readPayload(context) + } } return try readPayload(context) + case RefFlag.refValue.rawValue: + if context.trackRef { + let reservedRefID = context.refReader.reserveRefID() + let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + context.refReader.storeRef(value, at: reservedRefID) + return value + } + return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + case RefFlag.ref.rawValue: + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: Value.self) + default: + throw ForyError.refError("invalid ref flag \(rawFlag)") + } + case .tracking: + let rawFlag = try context.buffer.readInt8() + guard let flag = RefFlag(rawValue: rawFlag) else { + throw ForyError.refError("invalid ref flag \(rawFlag)") + } + switch flag { + case .null: + return defaultValue + case .ref: + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: Value.self) + case .refValue: + let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil + let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } + return value + case .notNullValue: + return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + } } + } + + private static func readPayloadAfterTypeInfo( + _ context: ReadContext, + readTypeInfo: Bool + ) throws -> Value { + if readTypeInfo { + let typeInfo = try Self.readTypeInfo(context) + return try withTypeInfo(typeInfo, context) { + try readPayload(context) + } + } + return try readPayload(context) + } } private enum FieldCodecDefault { - static func readCompatibleField( - codec _: Codec.Type, - _ context: ReadContext, - remoteFieldType: TypeMeta.FieldType, - refMode: RefMode - ) throws -> Codec.Value { - try Codec.read( - context, - refMode: refMode, - readTypeInfo: TypeId.needsTypeInfoForField(TypeId(rawValue: remoteFieldType.typeID) ?? .unknown) - ) - } + static func readCompatibleField( + codec _: Codec.Type, + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Codec.Value { + try Codec.read( + context, + refMode: refMode, + readTypeInfo: TypeId.needsTypeInfoForField( + TypeId(rawValue: remoteFieldType.typeID) ?? .unknown) + ) + } } public enum SerializerCodec: FieldCodec { - public typealias Value = T + public typealias Value = T - public static var typeId: TypeId { T.staticTypeId } - public static var defaultValue: T { T.foryDefault() } - public static var isNullableType: Bool { T.isNullableType } - public static var isRefType: Bool { T.isRefType } + public static var typeId: TypeId { T.staticTypeId } + public static var defaultValue: T { T.foryDefault() } + public static var isNullableType: Bool { T.isNullableType } + public static var isRefType: Bool { T.isRefType } - public static func isNone(_ value: T) -> Bool { - value.foryIsNone - } + public static func isNone(_ value: T) -> Bool { + value.foryIsNone + } - public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - let fieldTypeID = T.staticTypeId == .structType ? TypeId.compatibleStruct.rawValue : T.staticTypeId.rawValue - return TypeMeta.FieldType(typeID: fieldTypeID, nullable: nullable, trackRef: trackRef) - } + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + let fieldTypeID = + T.staticTypeId == .structType ? TypeId.compatibleStruct.rawValue : T.staticTypeId.rawValue + return TypeMeta.FieldType(typeID: fieldTypeID, nullable: nullable, trackRef: trackRef) + } - public static func writePayload(_ value: T, _ context: WriteContext) throws { - try value.foryWriteData(context, hasGenerics: false) - } + public static func writePayload(_ value: T, _ context: WriteContext) throws { + try value.foryWriteData(context, hasGenerics: false) + } - public static func readPayload(_ context: ReadContext) throws -> T { - try T.foryReadPayload(context, readTypeInfo: false) - } + public static func readPayload(_ context: ReadContext) throws -> T { + try T.foryReadPayload(context, readTypeInfo: false) + } - public static func writeStaticTypeInfo(_ context: WriteContext) throws { - try T.foryWriteStaticTypeInfo(context) - } + public static func writeStaticTypeInfo(_ context: WriteContext) throws { + try T.foryWriteStaticTypeInfo(context) + } - public static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try T.foryReadTypeInfo(context) - } + public static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try T.foryReadTypeInfo(context) + } - public static func withTypeInfo( - _ typeInfo: TypeInfo?, - _ context: ReadContext, - _ body: () throws -> R - ) rethrows -> R { - try context.withTypeInfo(typeInfo, for: T.self, body) - } + public static func withTypeInfo( + _ typeInfo: TypeInfo?, + _ context: ReadContext, + _ body: () throws -> R + ) rethrows -> R { + try context.withTypeInfo(typeInfo, for: T.self, body) + } - public static func write( - _ value: T, - _ context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool - ) throws { - try value.foryWrite(context, refMode: refMode, writeTypeInfo: writeTypeInfo, hasGenerics: false) - } + public static func write( + _ value: T, + _ context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool + ) throws { + try value.foryWrite(context, refMode: refMode, writeTypeInfo: writeTypeInfo, hasGenerics: false) + } - public static func read( - _ context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> T { - try T.foryRead(context, refMode: refMode, readTypeInfo: readTypeInfo) - } + public static func read( + _ context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> T { + try T.foryRead(context, refMode: refMode, readTypeInfo: readTypeInfo) + } } public enum OptionalFieldCodec: FieldCodec { - public typealias Value = WrappedCodec.Value? + public typealias Value = WrappedCodec.Value? - public static var typeId: TypeId { WrappedCodec.typeId } - public static var defaultValue: Value { nil } - public static var isNullableType: Bool { true } - public static var isRefType: Bool { WrappedCodec.isRefType } + public static var typeId: TypeId { WrappedCodec.typeId } + public static var defaultValue: Value { nil } + public static var isNullableType: Bool { true } + public static var isRefType: Bool { WrappedCodec.isRefType } - public static func isNone(_ value: Value) -> Bool { - value == nil - } + public static func isNone(_ value: Value) -> Bool { + value == nil + } - public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - WrappedCodec.fieldType(nullable: nullable, trackRef: trackRef) - } + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + WrappedCodec.fieldType(nullable: nullable, trackRef: trackRef) + } - public static func writePayload(_ value: Value, _ context: WriteContext) throws { - guard let value else { - throw ForyError.invalidData("Option.none cannot write raw payload") - } - try WrappedCodec.writePayload(value, context) + public static func writePayload(_ value: Value, _ context: WriteContext) throws { + guard let value else { + throw ForyError.invalidData("Option.none cannot write raw payload") } + try WrappedCodec.writePayload(value, context) + } - public static func readPayload(_ context: ReadContext) throws -> Value { - try WrappedCodec.readPayload(context) - } + public static func readPayload(_ context: ReadContext) throws -> Value { + try WrappedCodec.readPayload(context) + } - public static func writeStaticTypeInfo(_ context: WriteContext) throws { - try WrappedCodec.writeStaticTypeInfo(context) - } + public static func writeStaticTypeInfo(_ context: WriteContext) throws { + try WrappedCodec.writeStaticTypeInfo(context) + } - public static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try WrappedCodec.readTypeInfo(context) - } + public static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try WrappedCodec.readTypeInfo(context) + } - public static func withTypeInfo( - _ typeInfo: TypeInfo?, - _ context: ReadContext, - _ body: () throws -> R - ) rethrows -> R { - try WrappedCodec.withTypeInfo(typeInfo, context, body) - } + public static func withTypeInfo( + _ typeInfo: TypeInfo?, + _ context: ReadContext, + _ body: () throws -> R + ) rethrows -> R { + try WrappedCodec.withTypeInfo(typeInfo, context, body) + } - public static func write( - _ value: Value, - _ context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool - ) throws { - switch refMode { - case .none: - guard let value else { - throw ForyError.invalidData("Option.none with RefMode.none") - } - try WrappedCodec.write(value, context, refMode: .none, writeTypeInfo: writeTypeInfo) - case .nullOnly: - guard let value else { - context.buffer.writeInt8(RefFlag.null.rawValue) - return - } - context.buffer.writeInt8(RefFlag.notNullValue.rawValue) - try WrappedCodec.write(value, context, refMode: .none, writeTypeInfo: writeTypeInfo) - case .tracking: - guard let value else { - context.buffer.writeInt8(RefFlag.null.rawValue) - return - } - try WrappedCodec.write(value, context, refMode: .tracking, writeTypeInfo: writeTypeInfo) - } + public static func write( + _ value: Value, + _ context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool + ) throws { + switch refMode { + case .none: + guard let value else { + throw ForyError.invalidData("Option.none with RefMode.none") + } + try WrappedCodec.write(value, context, refMode: .none, writeTypeInfo: writeTypeInfo) + case .nullOnly: + guard let value else { + context.buffer.writeInt8(RefFlag.null.rawValue) + return + } + context.buffer.writeInt8(RefFlag.notNullValue.rawValue) + try WrappedCodec.write(value, context, refMode: .none, writeTypeInfo: writeTypeInfo) + case .tracking: + guard let value else { + context.buffer.writeInt8(RefFlag.null.rawValue) + return + } + try WrappedCodec.write(value, context, refMode: .tracking, writeTypeInfo: writeTypeInfo) } + } - public static func read( - _ context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> Value { - switch refMode { - case .none: - return try WrappedCodec.read(context, refMode: .none, readTypeInfo: readTypeInfo) - case .nullOnly: - let refFlag = try context.buffer.readInt8() - if refFlag == RefFlag.null.rawValue { - return nil - } - return try WrappedCodec.read(context, refMode: .none, readTypeInfo: readTypeInfo) - case .tracking: - let refFlag = try context.buffer.readInt8() - if refFlag == RefFlag.null.rawValue { - return nil - } - context.buffer.moveBack(1) - return try WrappedCodec.read(context, refMode: .tracking, readTypeInfo: readTypeInfo) - } + public static func read( + _ context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> Value { + switch refMode { + case .none: + return try WrappedCodec.read(context, refMode: .none, readTypeInfo: readTypeInfo) + case .nullOnly: + let refFlag = try context.buffer.readInt8() + if refFlag == RefFlag.null.rawValue { + return nil + } + return try WrappedCodec.read(context, refMode: .none, readTypeInfo: readTypeInfo) + case .tracking: + let refFlag = try context.buffer.readInt8() + if refFlag == RefFlag.null.rawValue { + return nil + } + context.buffer.moveBack(1) + return try WrappedCodec.read(context, refMode: .tracking, readTypeInfo: readTypeInfo) } + } } public enum BoolCodec: FieldCodec { - public static let typeId: TypeId = .bool - public static let defaultValue = false - public static func writePayload(_ value: Bool, _ context: WriteContext) { - context.buffer.writeUInt8(value ? 1 : 0) - } - public static func readPayload(_ context: ReadContext) throws -> Bool { - try context.buffer.readUInt8() != 0 - } + public static let typeId: TypeId = .bool + public static let defaultValue = false + public static func writePayload(_ value: Bool, _ context: WriteContext) { + context.buffer.writeUInt8(value ? 1 : 0) + } + public static func readPayload(_ context: ReadContext) throws -> Bool { + try context.buffer.readUInt8() != 0 + } } public enum Int8Codec: FieldCodec { - public static let typeId: TypeId = .int8 - public static let defaultValue = Int8(0) - public static func writePayload(_ value: Int8, _ context: WriteContext) { - context.buffer.writeInt8(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int8 { - try context.buffer.readInt8() - } + public static let typeId: TypeId = .int8 + public static let defaultValue = Int8(0) + public static func writePayload(_ value: Int8, _ context: WriteContext) { + context.buffer.writeInt8(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int8 { + try context.buffer.readInt8() + } } public enum Int16Codec: FieldCodec { - public static let typeId: TypeId = .int16 - public static let defaultValue = Int16(0) - public static func writePayload(_ value: Int16, _ context: WriteContext) { - context.buffer.writeInt16(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int16 { - try context.buffer.readInt16() - } + public static let typeId: TypeId = .int16 + public static let defaultValue = Int16(0) + public static func writePayload(_ value: Int16, _ context: WriteContext) { + context.buffer.writeInt16(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int16 { + try context.buffer.readInt16() + } } public enum Int32VarintCodec: FieldCodec { - public static let typeId: TypeId = .varint32 - public static let defaultValue = Int32(0) - public static func writePayload(_ value: Int32, _ context: WriteContext) { - context.buffer.writeVarInt32(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int32 { - try context.buffer.readVarInt32() - } + public static let typeId: TypeId = .varint32 + public static let defaultValue = Int32(0) + public static func writePayload(_ value: Int32, _ context: WriteContext) { + context.buffer.writeVarInt32(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int32 { + try context.buffer.readVarInt32() + } } public enum Int32FixedCodec: FieldCodec { - public static let typeId: TypeId = .int32 - public static let defaultValue = Int32(0) - public static func writePayload(_ value: Int32, _ context: WriteContext) { - context.buffer.writeInt32(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int32 { - try context.buffer.readInt32() - } + public static let typeId: TypeId = .int32 + public static let defaultValue = Int32(0) + public static func writePayload(_ value: Int32, _ context: WriteContext) { + context.buffer.writeInt32(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int32 { + try context.buffer.readInt32() + } } public enum Int64VarintCodec: FieldCodec { - public static let typeId: TypeId = .varint64 - public static let defaultValue = Int64(0) - public static func writePayload(_ value: Int64, _ context: WriteContext) { - context.buffer.writeVarInt64(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int64 { - try context.buffer.readVarInt64() - } + public static let typeId: TypeId = .varint64 + public static let defaultValue = Int64(0) + public static func writePayload(_ value: Int64, _ context: WriteContext) { + context.buffer.writeVarInt64(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int64 { + try context.buffer.readVarInt64() + } } public enum Int64FixedCodec: FieldCodec { - public static let typeId: TypeId = .int64 - public static let defaultValue = Int64(0) - public static func writePayload(_ value: Int64, _ context: WriteContext) { - context.buffer.writeInt64(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int64 { - try context.buffer.readInt64() - } + public static let typeId: TypeId = .int64 + public static let defaultValue = Int64(0) + public static func writePayload(_ value: Int64, _ context: WriteContext) { + context.buffer.writeInt64(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int64 { + try context.buffer.readInt64() + } } public enum Int64TaggedCodec: FieldCodec { - public static let typeId: TypeId = .taggedInt64 - public static let defaultValue = Int64(0) - public static func writePayload(_ value: Int64, _ context: WriteContext) { - context.buffer.writeTaggedInt64(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int64 { - try context.buffer.readTaggedInt64() - } + public static let typeId: TypeId = .taggedInt64 + public static let defaultValue = Int64(0) + public static func writePayload(_ value: Int64, _ context: WriteContext) { + context.buffer.writeTaggedInt64(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int64 { + try context.buffer.readTaggedInt64() + } } public enum UInt8Codec: FieldCodec { - public static let typeId: TypeId = .uint8 - public static let defaultValue = UInt8(0) - public static func writePayload(_ value: UInt8, _ context: WriteContext) { - context.buffer.writeUInt8(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt8 { - try context.buffer.readUInt8() - } + public static let typeId: TypeId = .uint8 + public static let defaultValue = UInt8(0) + public static func writePayload(_ value: UInt8, _ context: WriteContext) { + context.buffer.writeUInt8(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt8 { + try context.buffer.readUInt8() + } } public enum UInt16Codec: FieldCodec { - public static let typeId: TypeId = .uint16 - public static let defaultValue = UInt16(0) - public static func writePayload(_ value: UInt16, _ context: WriteContext) { - context.buffer.writeUInt16(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt16 { - try context.buffer.readUInt16() - } + public static let typeId: TypeId = .uint16 + public static let defaultValue = UInt16(0) + public static func writePayload(_ value: UInt16, _ context: WriteContext) { + context.buffer.writeUInt16(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt16 { + try context.buffer.readUInt16() + } } public enum UInt32VarintCodec: FieldCodec { - public static let typeId: TypeId = .varUInt32 - public static let defaultValue = UInt32(0) - public static func writePayload(_ value: UInt32, _ context: WriteContext) { - context.buffer.writeVarUInt32(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt32 { - try context.buffer.readVarUInt32() - } + public static let typeId: TypeId = .varUInt32 + public static let defaultValue = UInt32(0) + public static func writePayload(_ value: UInt32, _ context: WriteContext) { + context.buffer.writeVarUInt32(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt32 { + try context.buffer.readVarUInt32() + } } public enum UInt32FixedCodec: FieldCodec { - public static let typeId: TypeId = .uint32 - public static let defaultValue = UInt32(0) - public static func writePayload(_ value: UInt32, _ context: WriteContext) { - context.buffer.writeUInt32(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt32 { - try context.buffer.readUInt32() - } + public static let typeId: TypeId = .uint32 + public static let defaultValue = UInt32(0) + public static func writePayload(_ value: UInt32, _ context: WriteContext) { + context.buffer.writeUInt32(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt32 { + try context.buffer.readUInt32() + } } public enum UInt64VarintCodec: FieldCodec { - public static let typeId: TypeId = .varUInt64 - public static let defaultValue = UInt64(0) - public static func writePayload(_ value: UInt64, _ context: WriteContext) { - context.buffer.writeVarUInt64(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt64 { - try context.buffer.readVarUInt64() - } + public static let typeId: TypeId = .varUInt64 + public static let defaultValue = UInt64(0) + public static func writePayload(_ value: UInt64, _ context: WriteContext) { + context.buffer.writeVarUInt64(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt64 { + try context.buffer.readVarUInt64() + } } public enum UInt64FixedCodec: FieldCodec { - public static let typeId: TypeId = .uint64 - public static let defaultValue = UInt64(0) - public static func writePayload(_ value: UInt64, _ context: WriteContext) { - context.buffer.writeUInt64(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt64 { - try context.buffer.readUInt64() - } + public static let typeId: TypeId = .uint64 + public static let defaultValue = UInt64(0) + public static func writePayload(_ value: UInt64, _ context: WriteContext) { + context.buffer.writeUInt64(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt64 { + try context.buffer.readUInt64() + } } public enum UInt64TaggedCodec: FieldCodec { - public static let typeId: TypeId = .taggedUInt64 - public static let defaultValue = UInt64(0) - public static func writePayload(_ value: UInt64, _ context: WriteContext) { - context.buffer.writeTaggedUInt64(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt64 { - try context.buffer.readTaggedUInt64() - } + public static let typeId: TypeId = .taggedUInt64 + public static let defaultValue = UInt64(0) + public static func writePayload(_ value: UInt64, _ context: WriteContext) { + context.buffer.writeTaggedUInt64(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt64 { + try context.buffer.readTaggedUInt64() + } } public enum IntVarintCodec: FieldCodec { - public static let typeId: TypeId = .varint64 - public static let defaultValue = Int(0) - public static func writePayload(_ value: Int, _ context: WriteContext) { - context.buffer.writeVarInt64(Int64(value)) - } - public static func readPayload(_ context: ReadContext) throws -> Int { - Int(try context.buffer.readVarInt64()) - } + public static let typeId: TypeId = .varint64 + public static let defaultValue = Int(0) + public static func writePayload(_ value: Int, _ context: WriteContext) { + context.buffer.writeVarInt64(Int64(value)) + } + public static func readPayload(_ context: ReadContext) throws -> Int { + Int(try context.buffer.readVarInt64()) + } } public enum IntFixedCodec: FieldCodec { - public static let typeId: TypeId = .int64 - public static let defaultValue = Int(0) - public static func writePayload(_ value: Int, _ context: WriteContext) { - context.buffer.writeInt64(Int64(value)) - } - public static func readPayload(_ context: ReadContext) throws -> Int { - Int(try context.buffer.readInt64()) - } + public static let typeId: TypeId = .int64 + public static let defaultValue = Int(0) + public static func writePayload(_ value: Int, _ context: WriteContext) { + context.buffer.writeInt64(Int64(value)) + } + public static func readPayload(_ context: ReadContext) throws -> Int { + Int(try context.buffer.readInt64()) + } } public enum IntTaggedCodec: FieldCodec { - public static let typeId: TypeId = .taggedInt64 - public static let defaultValue = Int(0) - public static func writePayload(_ value: Int, _ context: WriteContext) { - context.buffer.writeTaggedInt64(Int64(value)) - } - public static func readPayload(_ context: ReadContext) throws -> Int { - Int(try context.buffer.readTaggedInt64()) - } + public static let typeId: TypeId = .taggedInt64 + public static let defaultValue = Int(0) + public static func writePayload(_ value: Int, _ context: WriteContext) { + context.buffer.writeTaggedInt64(Int64(value)) + } + public static func readPayload(_ context: ReadContext) throws -> Int { + Int(try context.buffer.readTaggedInt64()) + } } public enum UIntVarintCodec: FieldCodec { - public static let typeId: TypeId = .varUInt64 - public static let defaultValue = UInt(0) - public static func writePayload(_ value: UInt, _ context: WriteContext) { - context.buffer.writeVarUInt64(UInt64(value)) - } - public static func readPayload(_ context: ReadContext) throws -> UInt { - UInt(try context.buffer.readVarUInt64()) - } + public static let typeId: TypeId = .varUInt64 + public static let defaultValue = UInt(0) + public static func writePayload(_ value: UInt, _ context: WriteContext) { + context.buffer.writeVarUInt64(UInt64(value)) + } + public static func readPayload(_ context: ReadContext) throws -> UInt { + UInt(try context.buffer.readVarUInt64()) + } } public enum UIntFixedCodec: FieldCodec { - public static let typeId: TypeId = .uint64 - public static let defaultValue = UInt(0) - public static func writePayload(_ value: UInt, _ context: WriteContext) { - context.buffer.writeUInt64(UInt64(value)) - } - public static func readPayload(_ context: ReadContext) throws -> UInt { - UInt(try context.buffer.readUInt64()) - } + public static let typeId: TypeId = .uint64 + public static let defaultValue = UInt(0) + public static func writePayload(_ value: UInt, _ context: WriteContext) { + context.buffer.writeUInt64(UInt64(value)) + } + public static func readPayload(_ context: ReadContext) throws -> UInt { + UInt(try context.buffer.readUInt64()) + } } public enum UIntTaggedCodec: FieldCodec { - public static let typeId: TypeId = .taggedUInt64 - public static let defaultValue = UInt(0) - public static func writePayload(_ value: UInt, _ context: WriteContext) { - context.buffer.writeTaggedUInt64(UInt64(value)) - } - public static func readPayload(_ context: ReadContext) throws -> UInt { - UInt(try context.buffer.readTaggedUInt64()) - } + public static let typeId: TypeId = .taggedUInt64 + public static let defaultValue = UInt(0) + public static func writePayload(_ value: UInt, _ context: WriteContext) { + context.buffer.writeTaggedUInt64(UInt64(value)) + } + public static func readPayload(_ context: ReadContext) throws -> UInt { + UInt(try context.buffer.readTaggedUInt64()) + } } public enum Float16Codec: FieldCodec { - public static let typeId: TypeId = .float16 - public static let defaultValue = Float16(0) - public static func writePayload(_ value: Float16, _ context: WriteContext) { - context.buffer.writeUInt16(value.bitPattern) - } - public static func readPayload(_ context: ReadContext) throws -> Float16 { - Float16(bitPattern: try context.buffer.readUInt16()) - } + public static let typeId: TypeId = .float16 + public static let defaultValue = Float16(0) + public static func writePayload(_ value: Float16, _ context: WriteContext) { + context.buffer.writeUInt16(value.bitPattern) + } + public static func readPayload(_ context: ReadContext) throws -> Float16 { + Float16(bitPattern: try context.buffer.readUInt16()) + } } public enum BFloat16Codec: FieldCodec { - public static let typeId: TypeId = .bfloat16 - public static let defaultValue = BFloat16() - public static func writePayload(_ value: BFloat16, _ context: WriteContext) { - context.buffer.writeUInt16(value.rawValue) - } - public static func readPayload(_ context: ReadContext) throws -> BFloat16 { - BFloat16(rawValue: try context.buffer.readUInt16()) - } + public static let typeId: TypeId = .bfloat16 + public static let defaultValue = BFloat16() + public static func writePayload(_ value: BFloat16, _ context: WriteContext) { + context.buffer.writeUInt16(value.rawValue) + } + public static func readPayload(_ context: ReadContext) throws -> BFloat16 { + BFloat16(rawValue: try context.buffer.readUInt16()) + } } public enum FloatCodec: FieldCodec { - public static let typeId: TypeId = .float32 - public static let defaultValue = Float(0) - public static func writePayload(_ value: Float, _ context: WriteContext) { - context.buffer.writeFloat32(value) - } - public static func readPayload(_ context: ReadContext) throws -> Float { - try context.buffer.readFloat32() - } + public static let typeId: TypeId = .float32 + public static let defaultValue = Float(0) + public static func writePayload(_ value: Float, _ context: WriteContext) { + context.buffer.writeFloat32(value) + } + public static func readPayload(_ context: ReadContext) throws -> Float { + try context.buffer.readFloat32() + } } public enum DoubleCodec: FieldCodec { - public static let typeId: TypeId = .float64 - public static let defaultValue = Double(0) - public static func writePayload(_ value: Double, _ context: WriteContext) { - context.buffer.writeFloat64(value) - } - public static func readPayload(_ context: ReadContext) throws -> Double { - try context.buffer.readFloat64() - } + public static let typeId: TypeId = .float64 + public static let defaultValue = Double(0) + public static func writePayload(_ value: Double, _ context: WriteContext) { + context.buffer.writeFloat64(value) + } + public static func readPayload(_ context: ReadContext) throws -> Double { + try context.buffer.readFloat64() + } } public typealias StringCodec = SerializerCodec @@ -679,1115 +683,1194 @@ public typealias DecimalCodec = SerializerCodec public typealias DataCodec = SerializerCodec public enum ListFieldCodec: FieldCodec { - public typealias Value = [ElementCodec.Value] - - public static var typeId: TypeId { .list } - public static var defaultValue: Value { [] } - - public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - return TypeMeta.FieldType( - typeID: TypeId.list.rawValue, - nullable: nullable, - trackRef: trackRef, - generics: [ - ElementCodec.fieldType( - nullable: ElementCodec.isNullableType, - trackRef: trackRef && ElementCodec.isRefType) - ] - ) - } - - public static func writePayload(_ value: Value, _ context: WriteContext) throws { - try writeCollectionPayload(value, context, elementCodec: ElementCodec.self) - } - - public static func readPayload(_ context: ReadContext) throws -> Value { - return try readCollectionPayload(context, elementCodec: ElementCodec.self) - } - - public static func readCompatibleField( - _ context: ReadContext, - remoteFieldType: TypeMeta.FieldType, - refMode: RefMode - ) throws -> Value { - if isCompatiblePackedArrayTypeID(remoteFieldType.typeID, elementCodec: ElementCodec.self) { - return try readCompatiblePackedArrayField(context, refMode: refMode, elementCodec: ElementCodec.self) - } - return try FieldCodecDefault.readCompatibleField( - codec: Self.self, - context, - remoteFieldType: remoteFieldType, - refMode: refMode - ) - } + public typealias Value = [ElementCodec.Value] + + public static var typeId: TypeId { .list } + public static var defaultValue: Value { [] } + + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + return TypeMeta.FieldType( + typeID: TypeId.list.rawValue, + nullable: nullable, + trackRef: trackRef, + generics: [ + ElementCodec.fieldType( + nullable: ElementCodec.isNullableType, + trackRef: trackRef && ElementCodec.isRefType) + ] + ) + } + + public static func writePayload(_ value: Value, _ context: WriteContext) throws { + try writeCollectionPayload(value, context, elementCodec: ElementCodec.self) + } + + public static func readPayload(_ context: ReadContext) throws -> Value { + return try readCollectionPayload(context, elementCodec: ElementCodec.self) + } + + public static func readCompatibleField( + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Value { + if isCompatiblePackedArrayTypeID(remoteFieldType.typeID, elementCodec: ElementCodec.self) { + return try readCompatiblePackedArrayField( + context, refMode: refMode, elementCodec: ElementCodec.self) + } + return try FieldCodecDefault.readCompatibleField( + codec: Self.self, + context, + remoteFieldType: remoteFieldType, + refMode: refMode + ) + } } public enum ArrayFieldCodec: FieldCodec { - public typealias Value = [ElementCodec.Value] + public typealias Value = [ElementCodec.Value] - public static var typeId: TypeId { - guard let typeID = packedArrayTypeID(for: ElementCodec.self) else { - preconditionFailure("ArrayFieldCodec requires a non-null numeric or bool element codec") - } - return typeID + public static var typeId: TypeId { + guard let typeID = packedArrayTypeID(for: ElementCodec.self) else { + preconditionFailure("ArrayFieldCodec requires a non-null numeric or bool element codec") } + return typeID + } - public static var defaultValue: Value { [] } + public static var defaultValue: Value { [] } - public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - TypeMeta.FieldType(typeID: typeId.rawValue, nullable: nullable, trackRef: trackRef) - } + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + TypeMeta.FieldType(typeID: typeId.rawValue, nullable: nullable, trackRef: trackRef) + } - public static func writePayload(_ value: Value, _ context: WriteContext) throws { - if try writePackedArrayPayload(value, context, elementCodec: ElementCodec.self) { - return - } - throw ForyError.invalidData("unsupported array field element codec \(ElementCodec.self)") + public static func writePayload(_ value: Value, _ context: WriteContext) throws { + if try writePackedArrayPayload(value, context, elementCodec: ElementCodec.self) { + return } + throw ForyError.invalidData("unsupported array field element codec \(ElementCodec.self)") + } - public static func readPayload(_ context: ReadContext) throws -> Value { - if let value = try readPackedArrayPayload(context, elementCodec: ElementCodec.self) { - return value - } - throw ForyError.invalidData("unsupported array field element codec \(ElementCodec.self)") + public static func readPayload(_ context: ReadContext) throws -> Value { + if let value = try readPackedArrayPayload(context, elementCodec: ElementCodec.self) { + return value } + throw ForyError.invalidData("unsupported array field element codec \(ElementCodec.self)") + } - public static func readCompatibleField( - _ context: ReadContext, - remoteFieldType: TypeMeta.FieldType, - refMode: RefMode - ) throws -> Value { - if remoteFieldType.typeID == TypeId.list.rawValue, - let element = remoteFieldType.generics.first, - let localArrayTypeID = packedArrayTypeID(for: ElementCodec.self), - TypeId.listElementTypeID(element.typeID, matchesDenseArrayTypeID: localArrayTypeID.rawValue) - { - return try readListPayloadAsArray( - context, - refMode: refMode, - elementCodec: ElementCodec.self, - remoteElementTypeID: element.typeID - ) - } - return try FieldCodecDefault.readCompatibleField( - codec: Self.self, - context, - remoteFieldType: remoteFieldType, - refMode: refMode - ) + public static func readCompatibleField( + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Value { + if remoteFieldType.typeID == TypeId.list.rawValue, + let element = remoteFieldType.generics.first, + let localArrayTypeID = packedArrayTypeID(for: ElementCodec.self), + TypeId.listElementTypeID(element.typeID, matchesDenseArrayTypeID: localArrayTypeID.rawValue) + { + return try readListPayloadAsArray( + context, + refMode: refMode, + elementCodec: ElementCodec.self, + remoteElementTypeID: element.typeID + ) + } + return try FieldCodecDefault.readCompatibleField( + codec: Self.self, + context, + remoteFieldType: remoteFieldType, + refMode: refMode + ) + } + + public static func write( + _ value: Value, + _ context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool + ) throws { + if refMode == .none, !writeTypeInfo { + try writePayload(value, context) + return } - - public static func write( - _ value: Value, - _ context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool - ) throws { - if refMode == .none, !writeTypeInfo { - try writePayload(value, context) - return - } - if refMode != .none { - context.buffer.writeInt8(RefFlag.notNullValue.rawValue) - } - if writeTypeInfo { - try writeStaticTypeInfo(context) - } - try writePayload(value, context) + if refMode != .none { + context.buffer.writeInt8(RefFlag.notNullValue.rawValue) } - - public static func read( - _ context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> Value { - switch refMode { - case .none: - return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - case .nullOnly: - let rawFlag = try context.buffer.readInt8() - switch rawFlag { - case RefFlag.null.rawValue: - return defaultValue - case RefFlag.notNullValue.rawValue: - return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - case RefFlag.refValue.rawValue: - if context.trackRef { - let reservedRefID = context.refReader.reserveRefID() - let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - context.refReader.storeRef(value, at: reservedRefID) - return value - } - return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - case RefFlag.ref.rawValue: - let refID = try context.buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: Value.self) - default: - throw ForyError.refError("invalid ref flag \(rawFlag)") - } - case .tracking: - let rawFlag = try context.buffer.readInt8() - guard let flag = RefFlag(rawValue: rawFlag) else { - throw ForyError.refError("invalid ref flag \(rawFlag)") - } - switch flag { - case .null: - return defaultValue - case .ref: - let refID = try context.buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: Value.self) - case .refValue: - let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil - let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - if let reservedRefID { - context.refReader.storeRef(value, at: reservedRefID) - } - return value - case .notNullValue: - return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - } - } + if writeTypeInfo { + try writeStaticTypeInfo(context) } + try writePayload(value, context) + } - private static func readPayloadAfterTypeInfo( - _ context: ReadContext, - readTypeInfo: Bool - ) throws -> Value { - if readTypeInfo { - let typeInfo = try Self.readTypeInfo(context) - return try withTypeInfo(typeInfo, context) { - try readPayload(context) - } + public static func read( + _ context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> Value { + switch refMode { + case .none: + return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + case .nullOnly: + let rawFlag = try context.buffer.readInt8() + switch rawFlag { + case RefFlag.null.rawValue: + return defaultValue + case RefFlag.notNullValue.rawValue: + return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + case RefFlag.refValue.rawValue: + if context.trackRef { + let reservedRefID = context.refReader.reserveRefID() + let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + context.refReader.storeRef(value, at: reservedRefID) + return value } - return try readPayload(context) + return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + case RefFlag.ref.rawValue: + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: Value.self) + default: + throw ForyError.refError("invalid ref flag \(rawFlag)") + } + case .tracking: + let rawFlag = try context.buffer.readInt8() + guard let flag = RefFlag(rawValue: rawFlag) else { + throw ForyError.refError("invalid ref flag \(rawFlag)") + } + switch flag { + case .null: + return defaultValue + case .ref: + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: Value.self) + case .refValue: + let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil + let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } + return value + case .notNullValue: + return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + } } + } + + private static func readPayloadAfterTypeInfo( + _ context: ReadContext, + readTypeInfo: Bool + ) throws -> Value { + if readTypeInfo { + let typeInfo = try Self.readTypeInfo(context) + return try withTypeInfo(typeInfo, context) { + try readPayload(context) + } + } + return try readPayload(context) + } } public enum SetFieldCodec: FieldCodec where ElementCodec.Value: Hashable { - public typealias Value = Set - - public static var typeId: TypeId { .set } - public static var defaultValue: Value { [] } - - public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - TypeMeta.FieldType( - typeID: TypeId.set.rawValue, - nullable: nullable, - trackRef: trackRef, - generics: [ - ElementCodec.fieldType( - nullable: ElementCodec.isNullableType, - trackRef: trackRef && ElementCodec.isRefType) - ] - ) - } - - public static func writePayload(_ value: Value, _ context: WriteContext) throws { - try writeCollectionPayload(Array(value), context, elementCodec: ElementCodec.self) - } - - public static func readPayload(_ context: ReadContext) throws -> Value { - let values = try readCollectionPayload(context, elementCodec: ElementCodec.self) - try reserveFieldArrayStorage(context, ElementCodec.self, count: values.count) - return Set(values) - } + public typealias Value = Set + + public static var typeId: TypeId { .set } + public static var defaultValue: Value { [] } + + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + TypeMeta.FieldType( + typeID: TypeId.set.rawValue, + nullable: nullable, + trackRef: trackRef, + generics: [ + ElementCodec.fieldType( + nullable: ElementCodec.isNullableType, + trackRef: trackRef && ElementCodec.isRefType) + ] + ) + } + + public static func writePayload(_ value: Value, _ context: WriteContext) throws { + try writeCollectionPayload(Array(value), context, elementCodec: ElementCodec.self) + } + + public static func readPayload(_ context: ReadContext) throws -> Value { + let values = try readCollectionPayload(context, elementCodec: ElementCodec.self) + try reserveFieldArrayStorage(context, ElementCodec.self, count: values.count) + return Set(values) + } } public enum MapFieldCodec: FieldCodec where KeyCodec.Value: Hashable { - public typealias Value = [KeyCodec.Value: ValueCodec.Value] - - private struct MapEntryWriteOptions { - var trackKeyRef: Bool - var trackValueRef: Bool - var keyDeclared: Bool - var valueDeclared: Bool - var keyDynamicType: Bool - var valueDynamicType: Bool - var keyIsNil: Bool - var valueIsNil: Bool - } - - public static var typeId: TypeId { .map } - public static var defaultValue: Value { [:] } - - public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - TypeMeta.FieldType( - typeID: TypeId.map.rawValue, - nullable: nullable, - trackRef: trackRef, - generics: [ - KeyCodec.fieldType( - nullable: KeyCodec.isNullableType, - trackRef: trackRef && KeyCodec.isRefType), - ValueCodec.fieldType( - nullable: ValueCodec.isNullableType, - trackRef: trackRef && ValueCodec.isRefType) - ] + public typealias Value = [KeyCodec.Value: ValueCodec.Value] + + private struct MapEntryWriteOptions { + var trackKeyRef: Bool + var trackValueRef: Bool + var keyDeclared: Bool + var valueDeclared: Bool + var keyDynamicType: Bool + var valueDynamicType: Bool + var keyIsNil: Bool + var valueIsNil: Bool + } + + public static var typeId: TypeId { .map } + public static var defaultValue: Value { [:] } + + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + TypeMeta.FieldType( + typeID: TypeId.map.rawValue, + nullable: nullable, + trackRef: trackRef, + generics: [ + KeyCodec.fieldType( + nullable: KeyCodec.isNullableType, + trackRef: trackRef && KeyCodec.isRefType), + ValueCodec.fieldType( + nullable: ValueCodec.isNullableType, + trackRef: trackRef && ValueCodec.isRefType), + ] + ) + } + + public static func writePayload(_ value: Value, _ context: WriteContext) throws { + context.buffer.writeVarUInt32(UInt32(value.count)) + if value.isEmpty { + return + } + + let trackKeyRef = context.trackRef && KeyCodec.isRefType + let trackValueRef = context.trackRef && ValueCodec.isRefType + let keyDeclared = !TypeId.needsTypeInfoForField(KeyCodec.typeId) + let valueDeclared = !TypeId.needsTypeInfoForField(ValueCodec.typeId) + let keyDynamicType = KeyCodec.typeId == .unknown + let valueDynamicType = ValueCodec.typeId == .unknown + let commonOptions = MapEntryWriteOptions( + trackKeyRef: trackKeyRef, + trackValueRef: trackValueRef, + keyDeclared: keyDeclared, + valueDeclared: valueDeclared, + keyDynamicType: keyDynamicType, + valueDynamicType: valueDynamicType, + keyIsNil: false, + valueIsNil: false + ) + + var iterator = value.makeIterator() + var pendingPair = iterator.next() + while let pair = pendingPair { + let keyIsNil = KeyCodec.isNone(pair.key) + let valueIsNil = ValueCodec.isNone(pair.value) + + if keyDynamicType || valueDynamicType || keyIsNil || valueIsNil { + var options = commonOptions + options.keyIsNil = keyIsNil + options.valueIsNil = valueIsNil + try writeMapEntry( + pair, + context, + options: options ) - } - - public static func writePayload(_ value: Value, _ context: WriteContext) throws { - context.buffer.writeVarUInt32(UInt32(value.count)) - if value.isEmpty { - return + pendingPair = iterator.next() + continue + } + + var header: UInt8 = 0 + if trackKeyRef { header |= MapHeader.trackingKeyRef } + if trackValueRef { header |= MapHeader.trackingValueRef } + if keyDeclared { header |= MapHeader.declaredKeyType } + if valueDeclared { header |= MapHeader.declaredValueType } + + context.buffer.writeUInt8(header) + let chunkSizeOffset = context.buffer.count + context.buffer.writeUInt8(0) + + if !keyDeclared { + try KeyCodec.writeStaticTypeInfo(context) + } + if !valueDeclared { + try ValueCodec.writeStaticTypeInfo(context) + } + + var chunkSize: UInt8 = 0 + while chunkSize < UInt8.max, let current = pendingPair { + if KeyCodec.isNone(current.key) || ValueCodec.isNone(current.value) { + break } - - let trackKeyRef = context.trackRef && KeyCodec.isRefType - let trackValueRef = context.trackRef && ValueCodec.isRefType - let keyDeclared = !TypeId.needsTypeInfoForField(KeyCodec.typeId) - let valueDeclared = !TypeId.needsTypeInfoForField(ValueCodec.typeId) - let keyDynamicType = KeyCodec.typeId == .unknown - let valueDynamicType = ValueCodec.typeId == .unknown - let commonOptions = MapEntryWriteOptions( - trackKeyRef: trackKeyRef, - trackValueRef: trackValueRef, - keyDeclared: keyDeclared, - valueDeclared: valueDeclared, - keyDynamicType: keyDynamicType, - valueDynamicType: valueDynamicType, - keyIsNil: false, - valueIsNil: false + try writeMapPayload( + current, + context, + trackKeyRef: trackKeyRef, + trackValueRef: trackValueRef ) - - var iterator = value.makeIterator() - var pendingPair = iterator.next() - while let pair = pendingPair { - let keyIsNil = KeyCodec.isNone(pair.key) - let valueIsNil = ValueCodec.isNone(pair.value) - - if keyDynamicType || valueDynamicType || keyIsNil || valueIsNil { - var options = commonOptions - options.keyIsNil = keyIsNil - options.valueIsNil = valueIsNil - try writeMapEntry( - pair, - context, - options: options - ) - pendingPair = iterator.next() - continue - } - - var header: UInt8 = 0 - if trackKeyRef { header |= MapHeader.trackingKeyRef } - if trackValueRef { header |= MapHeader.trackingValueRef } - if keyDeclared { header |= MapHeader.declaredKeyType } - if valueDeclared { header |= MapHeader.declaredValueType } - - context.buffer.writeUInt8(header) - let chunkSizeOffset = context.buffer.count - context.buffer.writeUInt8(0) - - if !keyDeclared { - try KeyCodec.writeStaticTypeInfo(context) - } - if !valueDeclared { - try ValueCodec.writeStaticTypeInfo(context) - } - - var chunkSize: UInt8 = 0 - while chunkSize < UInt8.max, let current = pendingPair { - if KeyCodec.isNone(current.key) || ValueCodec.isNone(current.value) { - break - } - try writeMapPayload( - current, - context, - trackKeyRef: trackKeyRef, - trackValueRef: trackValueRef - ) - chunkSize &+= 1 - pendingPair = iterator.next() - } - context.buffer.setByte(at: chunkSizeOffset, to: chunkSize) - } - } - - public static func readPayload(_ context: ReadContext) throws -> Value { - let totalLength = Int(try context.buffer.readVarUInt32()) - try context.ensureCollectionLength(totalLength, label: "map") - if totalLength == 0 { - try reserveFieldMapStorage(context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) - return [:] - } - - try reserveFieldMapStorage(context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) - try context.ensureRemainingBytes(totalLength, label: "map") - var map: Value = [:] - map.reserveCapacity(totalLength) - var readCount = 0 - while readCount < totalLength { - let header = try context.buffer.readUInt8() - // IMPORTANT: map readers must obey the sender-written key/value ref - // bits in this header. Local Swift field metadata must not - // override that decision while reading. Shared xlang tests - // intentionally deserialize one ref policy and then serialize - // another local payload. DO NOT REMOVE this comment. - let trackKeyRef = (header & MapHeader.trackingKeyRef) != 0 - let keyNull = (header & MapHeader.keyNull) != 0 - let keyDeclared = (header & MapHeader.declaredKeyType) != 0 - - let trackValueRef = (header & MapHeader.trackingValueRef) != 0 - let valueNull = (header & MapHeader.valueNull) != 0 - let valueDeclared = (header & MapHeader.declaredValueType) != 0 - - if keyNull && valueNull { - map[KeyCodec.defaultValue] = ValueCodec.defaultValue - readCount += 1 - continue - } - - if keyNull { - let value = try readMapValue( - context, - declared: valueDeclared, - trackRef: trackValueRef - ) - map[KeyCodec.defaultValue] = value - readCount += 1 - continue - } - - if valueNull { - let key = try readMapKey( - context, - declared: keyDeclared, - trackRef: trackKeyRef - ) - map[key] = ValueCodec.defaultValue - readCount += 1 - continue - } - - let chunkSize = Int(try context.buffer.readUInt8()) - if chunkSize > (totalLength - readCount) { - throw ForyError.invalidData("map chunk size exceeds remaining entries") - } - let keyTypeInfo = keyDeclared ? nil : try KeyCodec.readTypeInfo(context) - let valueTypeInfo = valueDeclared ? nil : try ValueCodec.readTypeInfo(context) - for _ in 0...Element, - _ context: WriteContext, - options: MapEntryWriteOptions - ) throws { - var header: UInt8 = 0 - if options.trackKeyRef { header |= MapHeader.trackingKeyRef } - if options.trackValueRef { header |= MapHeader.trackingValueRef } - if options.keyIsNil { - header |= MapHeader.keyNull - } else if !options.keyDynamicType && options.keyDeclared { - header |= MapHeader.declaredKeyType - } - if options.valueIsNil { - header |= MapHeader.valueNull - } else if !options.valueDynamicType && options.valueDeclared { - header |= MapHeader.declaredValueType - } - context.buffer.writeUInt8(header) - - if !options.keyIsNil { - if !options.keyDeclared { - try KeyCodec.writeStaticTypeInfo(context) - } - try KeyCodec.write( - pair.key, - context, - refMode: options.trackKeyRef ? .tracking : .none, - writeTypeInfo: false - ) - } - if !options.valueIsNil { - if !options.valueDeclared { - try ValueCodec.writeStaticTypeInfo(context) - } - try ValueCodec.write( - pair.value, - context, - refMode: options.trackValueRef ? .tracking : .none, - writeTypeInfo: false - ) - } - } - - private static func writeMapPayload( - _ pair: Dictionary.Element, - _ context: WriteContext, - trackKeyRef: Bool, - trackValueRef: Bool - ) throws { - try KeyCodec.write( - pair.key, + chunkSize &+= 1 + pendingPair = iterator.next() + } + context.buffer.setByte(at: chunkSizeOffset, to: chunkSize) + } + } + + public static func readPayload(_ context: ReadContext) throws -> Value { + let totalLength = Int(try context.buffer.readVarUInt32()) + try context.ensureCollectionLength(totalLength, label: "map") + if totalLength == 0 { + try reserveFieldMapStorage( + context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) + return [:] + } + + try reserveFieldMapStorage( + context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) + try context.ensureRemainingBytes(totalLength, label: "map") + var map: Value = [:] + map.reserveCapacity(totalLength) + var readCount = 0 + while readCount < totalLength { + let header = try context.buffer.readUInt8() + // IMPORTANT: map readers must obey the sender-written key/value ref + // bits in this header. Local Swift field metadata must not + // override that decision while reading. Shared xlang tests + // intentionally deserialize one ref policy and then serialize + // another local payload. DO NOT REMOVE this comment. + let trackKeyRef = (header & MapHeader.trackingKeyRef) != 0 + let keyNull = (header & MapHeader.keyNull) != 0 + let keyDeclared = (header & MapHeader.declaredKeyType) != 0 + + let trackValueRef = (header & MapHeader.trackingValueRef) != 0 + let valueNull = (header & MapHeader.valueNull) != 0 + let valueDeclared = (header & MapHeader.declaredValueType) != 0 + + if keyNull && valueNull { + map[KeyCodec.defaultValue] = ValueCodec.defaultValue + readCount += 1 + continue + } + + if keyNull { + let value = try readMapValue( + context, + declared: valueDeclared, + trackRef: trackValueRef + ) + map[KeyCodec.defaultValue] = value + readCount += 1 + continue + } + + if valueNull { + let key = try readMapKey( + context, + declared: keyDeclared, + trackRef: trackKeyRef + ) + map[key] = ValueCodec.defaultValue + readCount += 1 + continue + } + + let chunkSize = Int(try context.buffer.readUInt8()) + if chunkSize > (totalLength - readCount) { + throw ForyError.invalidData("map chunk size exceeds remaining entries") + } + let keyTypeInfo = keyDeclared ? nil : try KeyCodec.readTypeInfo(context) + let valueTypeInfo = valueDeclared ? nil : try ValueCodec.readTypeInfo(context) + for _ in 0.. KeyCodec.Value { - let typeInfo = declared ? nil : try KeyCodec.readTypeInfo(context) - return try KeyCodec.withTypeInfo(typeInfo, context) { - try KeyCodec.read(context, refMode: trackRef ? .tracking : .none, readTypeInfo: false) + readTypeInfo: false + ) } + map[key] = value + } + readCount += chunkSize } + return map + } - private static func readMapValue( - _ context: ReadContext, - declared: Bool, - trackRef: Bool - ) throws -> ValueCodec.Value { - let typeInfo = declared ? nil : try ValueCodec.readTypeInfo(context) - return try ValueCodec.withTypeInfo(typeInfo, context) { - try ValueCodec.read(context, refMode: trackRef ? .tracking : .none, readTypeInfo: false) - } + private static func writeMapEntry( + _ pair: Dictionary.Element, + _ context: WriteContext, + options: MapEntryWriteOptions + ) throws { + var header: UInt8 = 0 + if options.trackKeyRef { header |= MapHeader.trackingKeyRef } + if options.trackValueRef { header |= MapHeader.trackingValueRef } + if options.keyIsNil { + header |= MapHeader.keyNull + } else if !options.keyDynamicType && options.keyDeclared { + header |= MapHeader.declaredKeyType + } + if options.valueIsNil { + header |= MapHeader.valueNull + } else if !options.valueDynamicType && options.valueDeclared { + header |= MapHeader.declaredValueType + } + context.buffer.writeUInt8(header) + + if !options.keyIsNil { + if !options.keyDeclared { + try KeyCodec.writeStaticTypeInfo(context) + } + try KeyCodec.write( + pair.key, + context, + refMode: options.trackKeyRef ? .tracking : .none, + writeTypeInfo: false + ) + } + if !options.valueIsNil { + if !options.valueDeclared { + try ValueCodec.writeStaticTypeInfo(context) + } + try ValueCodec.write( + pair.value, + context, + refMode: options.trackValueRef ? .tracking : .none, + writeTypeInfo: false + ) + } + } + + private static func writeMapPayload( + _ pair: Dictionary.Element, + _ context: WriteContext, + trackKeyRef: Bool, + trackValueRef: Bool + ) throws { + try KeyCodec.write( + pair.key, + context, + refMode: trackKeyRef ? .tracking : .none, + writeTypeInfo: false + ) + try ValueCodec.write( + pair.value, + context, + refMode: trackValueRef ? .tracking : .none, + writeTypeInfo: false + ) + } + + private static func readMapKey( + _ context: ReadContext, + declared: Bool, + trackRef: Bool + ) throws -> KeyCodec.Value { + let typeInfo = declared ? nil : try KeyCodec.readTypeInfo(context) + return try KeyCodec.withTypeInfo(typeInfo, context) { + try KeyCodec.read(context, refMode: trackRef ? .tracking : .none, readTypeInfo: false) } + } + + private static func readMapValue( + _ context: ReadContext, + declared: Bool, + trackRef: Bool + ) throws -> ValueCodec.Value { + let typeInfo = declared ? nil : try ValueCodec.readTypeInfo(context) + return try ValueCodec.withTypeInfo(typeInfo, context) { + try ValueCodec.read(context, refMode: trackRef ? .tracking : .none, readTypeInfo: false) + } + } } @inline(__always) private func uncheckedPackedArrayCast(_ array: [From], to _: To.Type) -> [To] { - assert(From.self == To.self) - return unsafeBitCast(array, to: [To].self) + assert(From.self == To.self) + return unsafeBitCast(array, to: [To].self) } @inline(__always) private func uncheckedScalarCast(_ value: From, to _: To.Type) -> To { - assert(From.self == To.self) - return unsafeBitCast(value, to: To.self) + assert(From.self == To.self) + return unsafeBitCast(value, to: To.self) } private func packedArrayTypeID(for _: ElementCodec.Type) -> TypeId? { - if ElementCodec.isNullableType { - return nil - } - if ElementCodec.self == BoolCodec.self { - return .boolArray - } - if ElementCodec.self == Int8Codec.self { - return .int8Array - } - if ElementCodec.self == Int16Codec.self { - return .int16Array - } - if ElementCodec.self == Int32FixedCodec.self { - return .int32Array - } - if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == IntFixedCodec.self { - return .int64Array - } - if ElementCodec.self == UInt8Codec.self { - return .uint8Array - } - if ElementCodec.self == UInt16Codec.self { - return .uint16Array - } - if ElementCodec.self == UInt32FixedCodec.self { - return .uint32Array - } - if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UIntFixedCodec.self { - return .uint64Array - } - if ElementCodec.self == Float16Codec.self { - return .float16Array - } - if ElementCodec.self == BFloat16Codec.self { - return .bfloat16Array - } - if ElementCodec.self == FloatCodec.self { - return .float32Array - } - if ElementCodec.self == DoubleCodec.self { - return .float64Array - } + if ElementCodec.isNullableType { return nil + } + if ElementCodec.self == BoolCodec.self { + return .boolArray + } + if ElementCodec.self == Int8Codec.self { + return .int8Array + } + if ElementCodec.self == Int16Codec.self { + return .int16Array + } + if ElementCodec.self == Int32FixedCodec.self { + return .int32Array + } + if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == IntFixedCodec.self { + return .int64Array + } + if ElementCodec.self == UInt8Codec.self { + return .uint8Array + } + if ElementCodec.self == UInt16Codec.self { + return .uint16Array + } + if ElementCodec.self == UInt32FixedCodec.self { + return .uint32Array + } + if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UIntFixedCodec.self { + return .uint64Array + } + if ElementCodec.self == Float16Codec.self { + return .float16Array + } + if ElementCodec.self == BFloat16Codec.self { + return .bfloat16Array + } + if ElementCodec.self == FloatCodec.self { + return .float32Array + } + if ElementCodec.self == DoubleCodec.self { + return .float64Array + } + return nil } private func isCompatiblePackedArrayTypeID( - _ typeID: UInt32, - elementCodec _: ElementCodec.Type + _ typeID: UInt32, + elementCodec _: ElementCodec.Type ) -> Bool { - TypeId.listElementTypeID(ElementCodec.typeId.rawValue, matchesDenseArrayTypeID: typeID) + TypeId.listElementTypeID(ElementCodec.typeId.rawValue, matchesDenseArrayTypeID: typeID) } private func writePackedArrayPayload( - _ value: [ElementCodec.Value], - _ context: WriteContext, - elementCodec _: ElementCodec.Type + _ value: [ElementCodec.Value], + _ context: WriteContext, + elementCodec _: ElementCodec.Type ) throws -> Bool { - if ElementCodec.self == BoolCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Bool.self), context: context) - return true - } - if ElementCodec.self == Int8Codec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int8.self), context: context) - return true - } - if ElementCodec.self == Int16Codec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int16.self), context: context) - return true - } - if ElementCodec.self == Int32FixedCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int32.self), context: context) - return true - } - if ElementCodec.self == Int64FixedCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int64.self), context: context) - return true - } - if ElementCodec.self == IntFixedCodec.self { - writeIntArrayPayload(uncheckedPackedArrayCast(value, to: Int.self), context) - return true - } - if ElementCodec.self == UInt8Codec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt8.self), context: context) - return true - } - if ElementCodec.self == UInt16Codec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt16.self), context: context) - return true - } - if ElementCodec.self == UInt32FixedCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt32.self), context: context) - return true - } - if ElementCodec.self == UInt64FixedCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt64.self), context: context) - return true - } - if ElementCodec.self == UIntFixedCodec.self { - writeUIntArrayPayload(uncheckedPackedArrayCast(value, to: UInt.self), context) - return true - } - if ElementCodec.self == Float16Codec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Float16.self), context: context) - return true - } - if ElementCodec.self == BFloat16Codec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: BFloat16.self), context: context) - return true - } - if ElementCodec.self == FloatCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Float.self), context: context) - return true - } - if ElementCodec.self == DoubleCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Double.self), context: context) - return true - } - return false + if ElementCodec.self == BoolCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Bool.self), context: context) + return true + } + if ElementCodec.self == Int8Codec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int8.self), context: context) + return true + } + if ElementCodec.self == Int16Codec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int16.self), context: context) + return true + } + if ElementCodec.self == Int32FixedCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int32.self), context: context) + return true + } + if ElementCodec.self == Int64FixedCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int64.self), context: context) + return true + } + if ElementCodec.self == IntFixedCodec.self { + writeIntArrayPayload(uncheckedPackedArrayCast(value, to: Int.self), context) + return true + } + if ElementCodec.self == UInt8Codec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt8.self), context: context) + return true + } + if ElementCodec.self == UInt16Codec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt16.self), context: context) + return true + } + if ElementCodec.self == UInt32FixedCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt32.self), context: context) + return true + } + if ElementCodec.self == UInt64FixedCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt64.self), context: context) + return true + } + if ElementCodec.self == UIntFixedCodec.self { + writeUIntArrayPayload(uncheckedPackedArrayCast(value, to: UInt.self), context) + return true + } + if ElementCodec.self == Float16Codec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Float16.self), context: context) + return true + } + if ElementCodec.self == BFloat16Codec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: BFloat16.self), context: context) + return true + } + if ElementCodec.self == FloatCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Float.self), context: context) + return true + } + if ElementCodec.self == DoubleCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Double.self), context: context) + return true + } + return false } private func readPackedArrayPayload( - _ context: ReadContext, - elementCodec _: ElementCodec.Type + _ context: ReadContext, + elementCodec _: ElementCodec.Type ) throws -> [ElementCodec.Value]? { - if ElementCodec.self == BoolCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Bool], to: ElementCodec.Value.self) - } - if ElementCodec.self == Int8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int8], to: ElementCodec.Value.self) - } - if ElementCodec.self == Int16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int16], to: ElementCodec.Value.self) - } - if ElementCodec.self == Int32FixedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int32], to: ElementCodec.Value.self) - } - if ElementCodec.self == Int64FixedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Int64], to: ElementCodec.Value.self) - } - if ElementCodec.self == IntFixedCodec.self { - return uncheckedPackedArrayCast(try readIntArrayPayload(context), to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt8], to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt16], to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt32FixedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt32], to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt64FixedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [UInt64], to: ElementCodec.Value.self) - } - if ElementCodec.self == UIntFixedCodec.self { - return uncheckedPackedArrayCast(try readUIntArrayPayload(context), to: ElementCodec.Value.self) - } - if ElementCodec.self == Float16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Float16], to: ElementCodec.Value.self) - } - if ElementCodec.self == BFloat16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [BFloat16], to: ElementCodec.Value.self) - } - if ElementCodec.self == FloatCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Float], to: ElementCodec.Value.self) - } - if ElementCodec.self == DoubleCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context) as [Double], to: ElementCodec.Value.self) - } - return nil + if ElementCodec.self == BoolCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Bool], to: ElementCodec.Value.self) + } + if ElementCodec.self == Int8Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Int8], to: ElementCodec.Value.self) + } + if ElementCodec.self == Int16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Int16], to: ElementCodec.Value.self) + } + if ElementCodec.self == Int32FixedCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Int32], to: ElementCodec.Value.self) + } + if ElementCodec.self == Int64FixedCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Int64], to: ElementCodec.Value.self) + } + if ElementCodec.self == IntFixedCodec.self { + return uncheckedPackedArrayCast(try readIntArrayPayload(context), to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt8Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [UInt8], to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [UInt16], to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt32FixedCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [UInt32], to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt64FixedCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [UInt64], to: ElementCodec.Value.self) + } + if ElementCodec.self == UIntFixedCodec.self { + return uncheckedPackedArrayCast(try readUIntArrayPayload(context), to: ElementCodec.Value.self) + } + if ElementCodec.self == Float16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Float16], to: ElementCodec.Value.self) + } + if ElementCodec.self == BFloat16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [BFloat16], to: ElementCodec.Value.self) + } + if ElementCodec.self == FloatCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Float], to: ElementCodec.Value.self) + } + if ElementCodec.self == DoubleCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Double], to: ElementCodec.Value.self) + } + return nil } private func writeIntArrayPayload(_ value: [Int], _ context: WriteContext) { - context.buffer.writeVarUInt32(UInt32(value.count * 8)) - for item in value { - context.buffer.writeInt64(Int64(item)) - } + context.buffer.writeVarUInt32(UInt32(value.count * 8)) + for item in value { + context.buffer.writeInt64(Int64(item)) + } } private func writeUIntArrayPayload(_ value: [UInt], _ context: WriteContext) { - context.buffer.writeVarUInt32(UInt32(value.count * 8)) - for item in value { - context.buffer.writeUInt64(UInt64(item)) - } + context.buffer.writeVarUInt32(UInt32(value.count * 8)) + for item in value { + context.buffer.writeUInt64(UInt64(item)) + } } -private func readIntArrayPayload(_ context: ReadContext, reserveContainerStorage: Bool = false) throws -> [Int] { - let count = try readPackedArrayElementCount(context, width: 8, label: "int64_array") - if reserveContainerStorage { - try reserveSerializerArrayMemory(context, Int.self, count: count) - } - var values: [Int] = [] - values.reserveCapacity(count) - for _ in 0.. [Int] +{ + let count = try readPackedArrayElementCount(context, width: 8, label: "int64_array") + if reserveGraphStorage { + try reserveSerializerArrayMemory(context, Int.self, count: count) + } + var values: [Int] = [] + values.reserveCapacity(count) + for _ in 0.. [UInt] { - let count = try readPackedArrayElementCount(context, width: 8, label: "uint64_array") - if reserveContainerStorage { - try reserveSerializerArrayMemory(context, UInt.self, count: count) - } - var values: [UInt] = [] - values.reserveCapacity(count) - for _ in 0.. [UInt] +{ + let count = try readPackedArrayElementCount(context, width: 8, label: "uint64_array") + if reserveGraphStorage { + try reserveSerializerArrayMemory(context, UInt.self, count: count) + } + var values: [UInt] = [] + values.reserveCapacity(count) + for _ in 0..( - _ context: ReadContext, - refMode: RefMode, - elementCodec _: ElementCodec.Type + _ context: ReadContext, + refMode: RefMode, + elementCodec _: ElementCodec.Type ) throws -> [ElementCodec.Value] { - switch refMode { - case .none: - return try readCompatiblePackedArrayPayload(context, elementCodec: ElementCodec.self) - case .nullOnly, .tracking: - let rawFlag = try context.buffer.readInt8() - guard rawFlag != RefFlag.null.rawValue else { - return [] - } - if rawFlag == RefFlag.ref.rawValue { - let refID = try context.buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: [ElementCodec.Value].self) - } - let reservedRefID = - (rawFlag == RefFlag.refValue.rawValue && context.trackRef) - ? context.refReader.reserveRefID() - : nil - let value = try readCompatiblePackedArrayPayload(context, elementCodec: ElementCodec.self) - if let reservedRefID { - context.refReader.storeRef(value, at: reservedRefID) - } - return value - } + switch refMode { + case .none: + return try readCompatiblePackedArrayPayload(context, elementCodec: ElementCodec.self) + case .nullOnly, .tracking: + let rawFlag = try context.buffer.readInt8() + guard rawFlag != RefFlag.null.rawValue else { + return [] + } + if rawFlag == RefFlag.ref.rawValue { + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: [ElementCodec.Value].self) + } + let reservedRefID = + (rawFlag == RefFlag.refValue.rawValue && context.trackRef) + ? context.refReader.reserveRefID() + : nil + let value = try readCompatiblePackedArrayPayload(context, elementCodec: ElementCodec.self) + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } + return value + } } private func readCompatiblePackedArrayPayload( - _ context: ReadContext, - elementCodec _: ElementCodec.Type + _ context: ReadContext, + elementCodec _: ElementCodec.Type ) throws -> [ElementCodec.Value] { - if ElementCodec.self == BoolCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Bool], to: ElementCodec.Value.self) - } - if ElementCodec.self == Int8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Int8], to: ElementCodec.Value.self) - } - if ElementCodec.self == Int16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Int16], to: ElementCodec.Value.self) - } - if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Int32], to: ElementCodec.Value.self) - } - if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self || ElementCodec.self == Int64TaggedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Int64], to: ElementCodec.Value.self) - } - if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self || ElementCodec.self == IntTaggedCodec.self { - return uncheckedPackedArrayCast(try readIntArrayPayload(context, reserveContainerStorage: true), to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt8Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [UInt8], to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [UInt16], to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [UInt32], to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self || ElementCodec.self == UInt64TaggedCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [UInt64], to: ElementCodec.Value.self) - } - if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self || ElementCodec.self == UIntTaggedCodec.self { - return uncheckedPackedArrayCast(try readUIntArrayPayload(context, reserveContainerStorage: true), to: ElementCodec.Value.self) - } - if ElementCodec.self == Float16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Float16], to: ElementCodec.Value.self) - } - if ElementCodec.self == BFloat16Codec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [BFloat16], to: ElementCodec.Value.self) - } - if ElementCodec.self == FloatCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Float], to: ElementCodec.Value.self) - } - if ElementCodec.self == DoubleCodec.self { - return uncheckedPackedArrayCast(try readPrimitiveArray(context, reserveContainerStorage: true) as [Double], to: ElementCodec.Value.self) - } - throw ForyError.invalidData("unsupported compatible array-to-list field element codec \(ElementCodec.self)") + if ElementCodec.self == BoolCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Bool], + to: ElementCodec.Value.self) + } + if ElementCodec.self == Int8Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Int8], + to: ElementCodec.Value.self) + } + if ElementCodec.self == Int16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Int16], + to: ElementCodec.Value.self) + } + if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Int32], + to: ElementCodec.Value.self) + } + if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self + || ElementCodec.self == Int64TaggedCodec.self + { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Int64], + to: ElementCodec.Value.self) + } + if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self + || ElementCodec.self == IntTaggedCodec.self + { + return uncheckedPackedArrayCast( + try readIntArrayPayload(context, reserveGraphStorage: true), to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt8Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt8], + to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt16], + to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt32], + to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self + || ElementCodec.self == UInt64TaggedCodec.self + { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt64], + to: ElementCodec.Value.self) + } + if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self + || ElementCodec.self == UIntTaggedCodec.self + { + return uncheckedPackedArrayCast( + try readUIntArrayPayload(context, reserveGraphStorage: true), to: ElementCodec.Value.self) + } + if ElementCodec.self == Float16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Float16], + to: ElementCodec.Value.self) + } + if ElementCodec.self == BFloat16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [BFloat16], + to: ElementCodec.Value.self) + } + if ElementCodec.self == FloatCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Float], + to: ElementCodec.Value.self) + } + if ElementCodec.self == DoubleCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Double], + to: ElementCodec.Value.self) + } + throw ForyError.invalidData( + "unsupported compatible array-to-list field element codec \(ElementCodec.self)") } private func readCompatibleElementPayload( - _ context: ReadContext, - elementCodec _: ElementCodec.Type, - remoteElementTypeID: UInt32? + _ context: ReadContext, + elementCodec _: ElementCodec.Type, + remoteElementTypeID: UInt32? ) throws -> ElementCodec.Value { - guard let remoteElementTypeID, - remoteElementTypeID != ElementCodec.typeId.rawValue, - let remoteTypeID = TypeId(rawValue: remoteElementTypeID) - else { - return try ElementCodec.readPayload(context) - } - - if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { - switch remoteTypeID { - case .int32: - return uncheckedScalarCast(try context.buffer.readInt32() as Int32, to: ElementCodec.Value.self) - case .varint32: - return uncheckedScalarCast(try context.buffer.readVarInt32() as Int32, to: ElementCodec.Value.self) - default: - break - } - } - if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self || ElementCodec.self == Int64TaggedCodec.self { - switch remoteTypeID { - case .int64: - return uncheckedScalarCast(try context.buffer.readInt64() as Int64, to: ElementCodec.Value.self) - case .varint64: - return uncheckedScalarCast(try context.buffer.readVarInt64() as Int64, to: ElementCodec.Value.self) - case .taggedInt64: - return uncheckedScalarCast(try context.buffer.readTaggedInt64() as Int64, to: ElementCodec.Value.self) - default: - break - } - } - if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self || ElementCodec.self == IntTaggedCodec.self { - switch remoteTypeID { - case .int64: - return uncheckedScalarCast(Int(try context.buffer.readInt64()), to: ElementCodec.Value.self) - case .varint64: - return uncheckedScalarCast(Int(try context.buffer.readVarInt64()), to: ElementCodec.Value.self) - case .taggedInt64: - return uncheckedScalarCast(Int(try context.buffer.readTaggedInt64()), to: ElementCodec.Value.self) - default: - break - } - } - if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { - switch remoteTypeID { - case .uint32: - return uncheckedScalarCast(try context.buffer.readUInt32() as UInt32, to: ElementCodec.Value.self) - case .varUInt32: - return uncheckedScalarCast(try context.buffer.readVarUInt32() as UInt32, to: ElementCodec.Value.self) - default: - break - } - } - if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self || ElementCodec.self == UInt64TaggedCodec.self { - switch remoteTypeID { - case .uint64: - return uncheckedScalarCast(try context.buffer.readUInt64() as UInt64, to: ElementCodec.Value.self) - case .varUInt64: - return uncheckedScalarCast(try context.buffer.readVarUInt64() as UInt64, to: ElementCodec.Value.self) - case .taggedUInt64: - return uncheckedScalarCast(try context.buffer.readTaggedUInt64() as UInt64, to: ElementCodec.Value.self) - default: - break - } - } - if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self || ElementCodec.self == UIntTaggedCodec.self { - switch remoteTypeID { - case .uint64: - return uncheckedScalarCast(UInt(try context.buffer.readUInt64()), to: ElementCodec.Value.self) - case .varUInt64: - return uncheckedScalarCast(UInt(try context.buffer.readVarUInt64()), to: ElementCodec.Value.self) - case .taggedUInt64: - return uncheckedScalarCast(UInt(try context.buffer.readTaggedUInt64()), to: ElementCodec.Value.self) - default: - break - } - } - throw ForyError.typeMismatch(expected: ElementCodec.typeId.rawValue, actual: remoteElementTypeID) + guard let remoteElementTypeID, + remoteElementTypeID != ElementCodec.typeId.rawValue, + let remoteTypeID = TypeId(rawValue: remoteElementTypeID) + else { + return try ElementCodec.readPayload(context) + } + + if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { + switch remoteTypeID { + case .int32: + return uncheckedScalarCast( + try context.buffer.readInt32() as Int32, to: ElementCodec.Value.self) + case .varint32: + return uncheckedScalarCast( + try context.buffer.readVarInt32() as Int32, to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self + || ElementCodec.self == Int64TaggedCodec.self + { + switch remoteTypeID { + case .int64: + return uncheckedScalarCast( + try context.buffer.readInt64() as Int64, to: ElementCodec.Value.self) + case .varint64: + return uncheckedScalarCast( + try context.buffer.readVarInt64() as Int64, to: ElementCodec.Value.self) + case .taggedInt64: + return uncheckedScalarCast( + try context.buffer.readTaggedInt64() as Int64, to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self + || ElementCodec.self == IntTaggedCodec.self + { + switch remoteTypeID { + case .int64: + return uncheckedScalarCast(Int(try context.buffer.readInt64()), to: ElementCodec.Value.self) + case .varint64: + return uncheckedScalarCast( + Int(try context.buffer.readVarInt64()), to: ElementCodec.Value.self) + case .taggedInt64: + return uncheckedScalarCast( + Int(try context.buffer.readTaggedInt64()), to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { + switch remoteTypeID { + case .uint32: + return uncheckedScalarCast( + try context.buffer.readUInt32() as UInt32, to: ElementCodec.Value.self) + case .varUInt32: + return uncheckedScalarCast( + try context.buffer.readVarUInt32() as UInt32, to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self + || ElementCodec.self == UInt64TaggedCodec.self + { + switch remoteTypeID { + case .uint64: + return uncheckedScalarCast( + try context.buffer.readUInt64() as UInt64, to: ElementCodec.Value.self) + case .varUInt64: + return uncheckedScalarCast( + try context.buffer.readVarUInt64() as UInt64, to: ElementCodec.Value.self) + case .taggedUInt64: + return uncheckedScalarCast( + try context.buffer.readTaggedUInt64() as UInt64, to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self + || ElementCodec.self == UIntTaggedCodec.self + { + switch remoteTypeID { + case .uint64: + return uncheckedScalarCast(UInt(try context.buffer.readUInt64()), to: ElementCodec.Value.self) + case .varUInt64: + return uncheckedScalarCast( + UInt(try context.buffer.readVarUInt64()), to: ElementCodec.Value.self) + case .taggedUInt64: + return uncheckedScalarCast( + UInt(try context.buffer.readTaggedUInt64()), to: ElementCodec.Value.self) + default: + break + } + } + throw ForyError.typeMismatch(expected: ElementCodec.typeId.rawValue, actual: remoteElementTypeID) } private func readPackedArrayElementCount( - _ context: ReadContext, - width: Int, - label: String + _ context: ReadContext, + width: Int, + label: String ) throws -> Int { - let byteSize = Int(try context.buffer.readVarUInt32()) - try context.ensureRemainingBytes(byteSize, label: "primitive_array_bytes") - if byteSize % width != 0 { - throw ForyError.invalidData("\(label) byte size mismatch") - } - let count = byteSize / width - try context.ensureCollectionLength(count, label: label) - return count + let byteSize = Int(try context.buffer.readVarUInt32()) + try context.ensureRemainingBytes(byteSize, label: "primitive_array_bytes") + if byteSize % width != 0 { + throw ForyError.invalidData("\(label) byte size mismatch") + } + let count = byteSize / width + try context.ensureCollectionLength(count, label: label) + return count } private func writeCollectionPayload( - _ value: [ElementCodec.Value], - _ context: WriteContext, - elementCodec _: ElementCodec.Type + _ value: [ElementCodec.Value], + _ context: WriteContext, + elementCodec _: ElementCodec.Type ) throws { - let buffer = context.buffer - buffer.writeVarUInt32(UInt32(value.count)) - if value.isEmpty { - return - } - - let hasNull = ElementCodec.isNullableType && value.contains(where: ElementCodec.isNone) - let trackRef = context.trackRef && ElementCodec.isRefType - let declaredElementType = !TypeId.needsTypeInfoForField(ElementCodec.typeId) - let dynamicElementType = ElementCodec.typeId == .unknown - - var header: UInt8 = dynamicElementType ? 0 : CollectionHeader.sameType - if trackRef { - header |= CollectionHeader.trackingRef - } - if hasNull { - header |= CollectionHeader.hasNull - } - if declaredElementType { - header |= CollectionHeader.declaredElementType - } - - buffer.writeUInt8(header) - if !dynamicElementType && !declaredElementType { - try ElementCodec.writeStaticTypeInfo(context) - } - - if dynamicElementType { - let refMode = RefMode.from(nullable: hasNull, trackRef: trackRef) - for element in value { - try ElementCodec.write(element, context, refMode: refMode, writeTypeInfo: true) - } - return - } - - if trackRef { - for element in value { - try ElementCodec.write(element, context, refMode: .tracking, writeTypeInfo: false) - } - } else if hasNull { - for element in value { - if ElementCodec.isNone(element) { - buffer.writeInt8(RefFlag.null.rawValue) - } else { - buffer.writeInt8(RefFlag.notNullValue.rawValue) - try ElementCodec.writePayload(element, context) - } - } - } else { - for element in value { - try ElementCodec.writePayload(element, context) - } - } + let buffer = context.buffer + buffer.writeVarUInt32(UInt32(value.count)) + if value.isEmpty { + return + } + + let hasNull = ElementCodec.isNullableType && value.contains(where: ElementCodec.isNone) + let trackRef = context.trackRef && ElementCodec.isRefType + let declaredElementType = !TypeId.needsTypeInfoForField(ElementCodec.typeId) + let dynamicElementType = ElementCodec.typeId == .unknown + + var header: UInt8 = dynamicElementType ? 0 : CollectionHeader.sameType + if trackRef { + header |= CollectionHeader.trackingRef + } + if hasNull { + header |= CollectionHeader.hasNull + } + if declaredElementType { + header |= CollectionHeader.declaredElementType + } + + buffer.writeUInt8(header) + if !dynamicElementType && !declaredElementType { + try ElementCodec.writeStaticTypeInfo(context) + } + + if dynamicElementType { + let refMode = RefMode.from(nullable: hasNull, trackRef: trackRef) + for element in value { + try ElementCodec.write(element, context, refMode: refMode, writeTypeInfo: true) + } + return + } + + if trackRef { + for element in value { + try ElementCodec.write(element, context, refMode: .tracking, writeTypeInfo: false) + } + } else if hasNull { + for element in value { + if ElementCodec.isNone(element) { + buffer.writeInt8(RefFlag.null.rawValue) + } else { + buffer.writeInt8(RefFlag.notNullValue.rawValue) + try ElementCodec.writePayload(element, context) + } + } + } else { + for element in value { + try ElementCodec.writePayload(element, context) + } + } } private func readCollectionPayload( - _ context: ReadContext, - elementCodec _: ElementCodec.Type + _ context: ReadContext, + elementCodec _: ElementCodec.Type ) throws -> [ElementCodec.Value] { - let buffer = context.buffer - let length = Int(try buffer.readVarUInt32()) - try context.ensureCollectionLength(length, label: "array") - if length == 0 { - try reserveFieldArrayStorage(context, ElementCodec.self, count: length) - return [] - } - - let header = try buffer.readUInt8() - // IMPORTANT: collection readers must obey the ref/null bits written on the - // wire, not the local Swift element metadata that may imply a different - // ref policy. Shared xlang tests intentionally deserialize one ref policy - // and then serialize another local payload. DO NOT REMOVE this comment. - let trackRef = (header & CollectionHeader.trackingRef) != 0 - let hasNull = (header & CollectionHeader.hasNull) != 0 - let declared = (header & CollectionHeader.declaredElementType) != 0 - let sameType = (header & CollectionHeader.sameType) != 0 - - var result: [ElementCodec.Value] = [] + let buffer = context.buffer + let length = Int(try buffer.readVarUInt32()) + try context.ensureCollectionLength(length, label: "array") + if length == 0 { try reserveFieldArrayStorage(context, ElementCodec.self, count: length) - try context.ensureRemainingBytes(length, label: "array") - result.reserveCapacity(length) - - if !sameType { - let refMode = RefMode.from(nullable: hasNull, trackRef: trackRef) - for _ in 0..( - _ context: ReadContext, - refMode: RefMode, - elementCodec _: ElementCodec.Type, - remoteElementTypeID: UInt32 + _ context: ReadContext, + refMode: RefMode, + elementCodec _: ElementCodec.Type, + remoteElementTypeID: UInt32 ) throws -> [ElementCodec.Value] { - switch refMode { - case .none: - return try readListPayloadAsArrayPayload( - context, - elementCodec: ElementCodec.self, - remoteElementTypeID: remoteElementTypeID - ) - case .nullOnly, .tracking: - let rawFlag = try context.buffer.readInt8() - guard rawFlag != RefFlag.null.rawValue else { - return [] - } - if rawFlag == RefFlag.ref.rawValue { - let refID = try context.buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: [ElementCodec.Value].self) - } - let reservedRefID = - (rawFlag == RefFlag.refValue.rawValue && context.trackRef) - ? context.refReader.reserveRefID() - : nil - let value = try readListPayloadAsArrayPayload( - context, - elementCodec: ElementCodec.self, - remoteElementTypeID: remoteElementTypeID - ) - if let reservedRefID { - context.refReader.storeRef(value, at: reservedRefID) - } - return value - } + switch refMode { + case .none: + return try readListPayloadAsArrayPayload( + context, + elementCodec: ElementCodec.self, + remoteElementTypeID: remoteElementTypeID + ) + case .nullOnly, .tracking: + let rawFlag = try context.buffer.readInt8() + guard rawFlag != RefFlag.null.rawValue else { + return [] + } + if rawFlag == RefFlag.ref.rawValue { + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: [ElementCodec.Value].self) + } + let reservedRefID = + (rawFlag == RefFlag.refValue.rawValue && context.trackRef) + ? context.refReader.reserveRefID() + : nil + let value = try readListPayloadAsArrayPayload( + context, + elementCodec: ElementCodec.self, + remoteElementTypeID: remoteElementTypeID + ) + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } + return value + } } private func readListPayloadAsArrayPayload( - _ context: ReadContext, - elementCodec _: ElementCodec.Type, - remoteElementTypeID: UInt32 + _ context: ReadContext, + elementCodec _: ElementCodec.Type, + remoteElementTypeID: UInt32 ) throws -> [ElementCodec.Value] { - let buffer = context.buffer - let length = Int(try buffer.readVarUInt32()) - try context.ensureCollectionLength(length, label: "array") - if length == 0 { - try reserveFieldArrayStorage(context, ElementCodec.self, count: length) - return [] - } - - let header = try buffer.readUInt8() - let trackRef = (header & CollectionHeader.trackingRef) != 0 - let hasNull = (header & CollectionHeader.hasNull) != 0 - if hasNull { - throw ForyError.invalidData("compatible list-to-array field cannot read nullable elements") - } - let declared = (header & CollectionHeader.declaredElementType) != 0 - let sameType = (header & CollectionHeader.sameType) != 0 - - if !sameType { - throw ForyError.invalidData("compatible list-to-array field requires same-type elements") - } - - if trackRef { - throw ForyError.invalidData("compatible list-to-array field cannot read ref-tracked elements") - } - let elementTypeInfo: TypeInfo? - if declared { - elementTypeInfo = nil - } else { - throw ForyError.invalidData("compatible list-to-array field requires declared elements") - } - try context.ensureRemainingBytes(length, label: "array") - var result: [ElementCodec.Value] = [] + let buffer = context.buffer + let length = Int(try buffer.readVarUInt32()) + try context.ensureCollectionLength(length, label: "array") + if length == 0 { try reserveFieldArrayStorage(context, ElementCodec.self, count: length) - result.reserveCapacity(length) - return try ElementCodec.withTypeInfo(elementTypeInfo, context) { - for _ in 0.. Bool { - guard let resolved = TypeId(rawValue: typeID) else { - return true - } - return TypeId.needsTypeInfoForField(resolved) + private func needsTypeInfoForSkippedField(_ typeID: UInt32) -> Bool { + guard let resolved = TypeId(rawValue: typeID) else { + return true } + return TypeId.needsTypeInfoForField(resolved) + } - private func readSkippedFieldValue( - fieldType: TypeMeta.FieldType, - typeInfo: TypeInfo? = nil, - readTypeInfo: Bool - ) throws -> Any? { - let refMode = RefMode.from(nullable: fieldType.nullable, trackRef: fieldType.trackRef) - return try readSkippedValue( - fieldType: fieldType, - typeInfo: typeInfo, - refMode: refMode, - readTypeInfo: readTypeInfo + private func readSkippedFieldValue( + fieldType: TypeMeta.FieldType, + typeInfo: TypeInfo? = nil, + readTypeInfo: Bool + ) throws -> Any? { + let refMode = RefMode.from(nullable: fieldType.nullable, trackRef: fieldType.trackRef) + return try readSkippedValue( + fieldType: fieldType, + typeInfo: typeInfo, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + } + + private func readSkippedValue( + fieldType: TypeMeta.FieldType, + typeInfo: TypeInfo?, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> Any? { + switch refMode { + case .none: + return try readSkippedFieldPayload( + fieldType: fieldType, + typeInfo: typeInfo, + readTypeInfo: readTypeInfo + ) + case .nullOnly: + let flag = try buffer.readInt8() + if flag == RefFlag.null.rawValue { + return nil + } + guard flag == RefFlag.notNullValue.rawValue else { + throw ForyError.invalidData("unexpected nullOnly flag \(flag)") + } + return try readSkippedFieldPayload( + fieldType: fieldType, + typeInfo: typeInfo, + readTypeInfo: readTypeInfo + ) + case .tracking: + let rawFlag = try buffer.readInt8() + guard let flag = RefFlag(rawValue: rawFlag) else { + throw ForyError.invalidData("unexpected tracking flag \(rawFlag)") + } + + switch flag { + case .null: + return nil + case .ref: + let refID = try buffer.readVarUInt32() + return try refReader.readRefValue(refID) + case .refValue: + let refID = refReader.reserveRefID() + let value = try readSkippedFieldPayload( + fieldType: fieldType, + typeInfo: typeInfo, + readTypeInfo: readTypeInfo + ) + refReader.storeRef(value, at: refID) + return value + case .notNullValue: + return try readSkippedFieldPayload( + fieldType: fieldType, + typeInfo: typeInfo, + readTypeInfo: readTypeInfo ) + } } + } - private func readSkippedValue( - fieldType: TypeMeta.FieldType, - typeInfo: TypeInfo?, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> Any? { - switch refMode { - case .none: - return try readSkippedFieldPayload( - fieldType: fieldType, - typeInfo: typeInfo, - readTypeInfo: readTypeInfo - ) - case .nullOnly: - let flag = try buffer.readInt8() - if flag == RefFlag.null.rawValue { - return nil - } - guard flag == RefFlag.notNullValue.rawValue else { - throw ForyError.invalidData("unexpected nullOnly flag \(flag)") - } - return try readSkippedFieldPayload( - fieldType: fieldType, - typeInfo: typeInfo, - readTypeInfo: readTypeInfo - ) - case .tracking: - let rawFlag = try buffer.readInt8() - guard let flag = RefFlag(rawValue: rawFlag) else { - throw ForyError.invalidData("unexpected tracking flag \(rawFlag)") - } - - switch flag { - case .null: - return nil - case .ref: - let refID = try buffer.readVarUInt32() - return try refReader.readRefValue(refID) - case .refValue: - let refID = refReader.reserveRefID() - let value = try readSkippedFieldPayload( - fieldType: fieldType, - typeInfo: typeInfo, - readTypeInfo: readTypeInfo - ) - refReader.storeRef(value, at: refID) - return value - case .notNullValue: - return try readSkippedFieldPayload( - fieldType: fieldType, - typeInfo: typeInfo, - readTypeInfo: readTypeInfo - ) - } - } + private func readSkippedFieldPayload( + fieldType: TypeMeta.FieldType, + typeInfo: TypeInfo?, + readTypeInfo: Bool + ) throws -> Any { + if let typeInfo { + return try readAnyValue(typeInfo: typeInfo) + } + if readTypeInfo { + let typeInfo = try self.readTypeInfo() + return try readAnyValue(typeInfo: typeInfo) } - private func readSkippedFieldPayload( - fieldType: TypeMeta.FieldType, - typeInfo: TypeInfo?, - readTypeInfo: Bool - ) throws -> Any { - if let typeInfo { - return try readAnyValue(typeInfo: typeInfo) - } - if readTypeInfo { - let typeInfo = try self.readTypeInfo() - return try readAnyValue(typeInfo: typeInfo) - } + guard let resolvedTypeID = TypeId(rawValue: fieldType.typeID) else { + throw ForyError.invalidData("unknown compatible field type id \(fieldType.typeID)") + } - guard let resolvedTypeID = TypeId(rawValue: fieldType.typeID) else { - throw ForyError.invalidData("unknown compatible field type id \(fieldType.typeID)") - } + switch resolvedTypeID { + case .none: + return ForyAnyNullValue() + case .bool: + return try Bool.foryRead(self, refMode: .none, readTypeInfo: false) + case .int8: + return try Int8.foryRead(self, refMode: .none, readTypeInfo: false) + case .int16: + return try Int16.foryRead(self, refMode: .none, readTypeInfo: false) + case .int32: + return try buffer.readInt32() + case .varint32: + return try Int32.foryRead(self, refMode: .none, readTypeInfo: false) + case .int64: + return try buffer.readInt64() + case .varint64: + return try Int64.foryRead(self, refMode: .none, readTypeInfo: false) + case .taggedInt64: + return try buffer.readTaggedInt64() + case .uint8: + return try UInt8.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint16: + return try UInt16.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint32: + return try buffer.readUInt32() + case .varUInt32: + return try UInt32.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint64: + return try buffer.readUInt64() + case .varUInt64: + return try UInt64.foryRead(self, refMode: .none, readTypeInfo: false) + case .taggedUInt64: + return try buffer.readTaggedUInt64() + case .float16: + return try Float16.foryRead(self, refMode: .none, readTypeInfo: false) + case .bfloat16: + return try BFloat16.foryRead(self, refMode: .none, readTypeInfo: false) + case .float32: + return try Float.foryRead(self, refMode: .none, readTypeInfo: false) + case .float64: + return try Double.foryRead(self, refMode: .none, readTypeInfo: false) + case .string: + return try String.foryRead(self, refMode: .none, readTypeInfo: false) + case .duration: + return try Duration.foryRead(self, refMode: .none, readTypeInfo: false) + case .timestamp: + return try Date.foryRead(self, refMode: .none, readTypeInfo: false) + case .date: + return try LocalDate.foryRead(self, refMode: .none, readTypeInfo: false) + case .decimal: + return try Decimal.foryRead(self, refMode: .none, readTypeInfo: false) + case .binary, .uint8Array: + return try Data.foryRead(self, refMode: .none, readTypeInfo: false) + case .boolArray: + return try [Bool].foryRead(self, refMode: .none, readTypeInfo: false) + case .int8Array: + return try [Int8].foryRead(self, refMode: .none, readTypeInfo: false) + case .int16Array: + return try [Int16].foryRead(self, refMode: .none, readTypeInfo: false) + case .int32Array: + return try [Int32].foryRead(self, refMode: .none, readTypeInfo: false) + case .int64Array: + return try [Int64].foryRead(self, refMode: .none, readTypeInfo: false) + case .uint16Array: + return try [UInt16].foryRead(self, refMode: .none, readTypeInfo: false) + case .uint32Array: + return try [UInt32].foryRead(self, refMode: .none, readTypeInfo: false) + case .uint64Array: + return try [UInt64].foryRead(self, refMode: .none, readTypeInfo: false) + case .float16Array: + return try [Float16].foryRead(self, refMode: .none, readTypeInfo: false) + case .bfloat16Array: + return try [BFloat16].foryRead(self, refMode: .none, readTypeInfo: false) + case .float32Array: + return try [Float].foryRead(self, refMode: .none, readTypeInfo: false) + case .float64Array: + return try [Double].foryRead(self, refMode: .none, readTypeInfo: false) + case .array, .list: + return try readSkippedCollection(fieldType: fieldType) + case .set: + return try readSkippedSet(fieldType: fieldType) + case .map: + return try readSkippedMap(fieldType: fieldType) + case .union, .typedUnion, .namedUnion: + return try readSkippedUnion() + case .enumType, .namedEnum: + return try buffer.readVarUInt32() + default: + throw ForyError.invalidData("unsupported compatible field type id \(fieldType.typeID)") + } + } - switch resolvedTypeID { - case .none: - return ForyAnyNullValue() - case .bool: - return try Bool.foryRead(self, refMode: .none, readTypeInfo: false) - case .int8: - return try Int8.foryRead(self, refMode: .none, readTypeInfo: false) - case .int16: - return try Int16.foryRead(self, refMode: .none, readTypeInfo: false) - case .int32: - return try buffer.readInt32() - case .varint32: - return try Int32.foryRead(self, refMode: .none, readTypeInfo: false) - case .int64: - return try buffer.readInt64() - case .varint64: - return try Int64.foryRead(self, refMode: .none, readTypeInfo: false) - case .taggedInt64: - return try buffer.readTaggedInt64() - case .uint8: - return try UInt8.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint16: - return try UInt16.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint32: - return try buffer.readUInt32() - case .varUInt32: - return try UInt32.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint64: - return try buffer.readUInt64() - case .varUInt64: - return try UInt64.foryRead(self, refMode: .none, readTypeInfo: false) - case .taggedUInt64: - return try buffer.readTaggedUInt64() - case .float16: - return try Float16.foryRead(self, refMode: .none, readTypeInfo: false) - case .bfloat16: - return try BFloat16.foryRead(self, refMode: .none, readTypeInfo: false) - case .float32: - return try Float.foryRead(self, refMode: .none, readTypeInfo: false) - case .float64: - return try Double.foryRead(self, refMode: .none, readTypeInfo: false) - case .string: - return try String.foryRead(self, refMode: .none, readTypeInfo: false) - case .duration: - return try Duration.foryRead(self, refMode: .none, readTypeInfo: false) - case .timestamp: - return try Date.foryRead(self, refMode: .none, readTypeInfo: false) - case .date: - return try LocalDate.foryRead(self, refMode: .none, readTypeInfo: false) - case .decimal: - return try Decimal.foryRead(self, refMode: .none, readTypeInfo: false) - case .binary, .uint8Array: - return try Data.foryRead(self, refMode: .none, readTypeInfo: false) - case .boolArray: - return try [Bool].foryRead(self, refMode: .none, readTypeInfo: false) - case .int8Array: - return try [Int8].foryRead(self, refMode: .none, readTypeInfo: false) - case .int16Array: - return try [Int16].foryRead(self, refMode: .none, readTypeInfo: false) - case .int32Array: - return try [Int32].foryRead(self, refMode: .none, readTypeInfo: false) - case .int64Array: - return try [Int64].foryRead(self, refMode: .none, readTypeInfo: false) - case .uint16Array: - return try [UInt16].foryRead(self, refMode: .none, readTypeInfo: false) - case .uint32Array: - return try [UInt32].foryRead(self, refMode: .none, readTypeInfo: false) - case .uint64Array: - return try [UInt64].foryRead(self, refMode: .none, readTypeInfo: false) - case .float16Array: - return try [Float16].foryRead(self, refMode: .none, readTypeInfo: false) - case .bfloat16Array: - return try [BFloat16].foryRead(self, refMode: .none, readTypeInfo: false) - case .float32Array: - return try [Float].foryRead(self, refMode: .none, readTypeInfo: false) - case .float64Array: - return try [Double].foryRead(self, refMode: .none, readTypeInfo: false) - case .array, .list: - return try readSkippedCollection(fieldType: fieldType) - case .set: - return try readSkippedSet(fieldType: fieldType) - case .map: - return try readSkippedMap(fieldType: fieldType) - case .union, .typedUnion, .namedUnion: - return try readSkippedUnion() - case .enumType, .namedEnum: - return try buffer.readVarUInt32() - default: - throw ForyError.invalidData("unsupported compatible field type id \(fieldType.typeID)") - } + private func readSkippedCollection( + fieldType: TypeMeta.FieldType + ) throws -> [Any] { + let elementFieldType = + fieldType.generics.first + ?? TypeMeta.FieldType(typeID: TypeId.unknown.rawValue, nullable: true) + let length = Int(try buffer.readVarUInt32()) + try ensureCollectionLength(length, label: "compatible_collection") + if length == 0 { + return [] } - private func readSkippedCollection( - fieldType: TypeMeta.FieldType - ) throws -> [Any] { - let elementFieldType = - fieldType.generics.first - ?? TypeMeta.FieldType(typeID: TypeId.unknown.rawValue, nullable: true) - let length = Int(try buffer.readVarUInt32()) - try ensureCollectionLength(length, label: "compatible_collection") - if length == 0 { - return [] - } + let header = try buffer.readUInt8() + let trackRef = (header & 0b0000_0001) != 0 + let hasNull = (header & 0b0000_0010) != 0 + let declared = (header & 0b0000_0100) != 0 + let sameType = (header & 0b0000_1000) != 0 - let header = try buffer.readUInt8() - let trackRef = (header & 0b0000_0001) != 0 - let hasNull = (header & 0b0000_0010) != 0 - let declared = (header & 0b0000_0100) != 0 - let sameType = (header & 0b0000_1000) != 0 + var typeInfo: TypeInfo? + if sameType, !declared { + typeInfo = try self.readTypeInfo() + } - var typeInfo: TypeInfo? - if sameType, !declared { - typeInfo = try self.readTypeInfo() + for _ in 0.. Set { - _ = try readSkippedCollection(fieldType: fieldType) - return [] - } + return [] + } - private func readSkippedMap( - fieldType: TypeMeta.FieldType - ) throws -> [AnyHashable: Any] { - let keyType = - fieldType.generics.first - ?? TypeMeta.FieldType(typeID: TypeId.unknown.rawValue, nullable: true) - let valueType = - fieldType.generics.dropFirst().first - ?? TypeMeta.FieldType(typeID: TypeId.unknown.rawValue, nullable: true) + private func readSkippedSet( + fieldType: TypeMeta.FieldType + ) throws -> Set { + _ = try readSkippedCollection(fieldType: fieldType) + return [] + } - let totalLength = Int(try buffer.readVarUInt32()) - try ensureCollectionLength(totalLength, label: "compatible_map") - if totalLength == 0 { - return [:] - } + private func readSkippedMap( + fieldType: TypeMeta.FieldType + ) throws -> [AnyHashable: Any] { + let keyType = + fieldType.generics.first + ?? TypeMeta.FieldType(typeID: TypeId.unknown.rawValue, nullable: true) + let valueType = + fieldType.generics.dropFirst().first + ?? TypeMeta.FieldType(typeID: TypeId.unknown.rawValue, nullable: true) - var readCount = 0 - while readCount < totalLength { - let header = try buffer.readUInt8() - let trackKeyRef = (header & 0b0000_0001) != 0 - let keyNull = (header & 0b0000_0010) != 0 - let keyDeclared = (header & 0b0000_0100) != 0 + let totalLength = Int(try buffer.readVarUInt32()) + try ensureCollectionLength(totalLength, label: "compatible_map") + if totalLength == 0 { + return [:] + } - let trackValueRef = (header & 0b0000_1000) != 0 - let valueNull = (header & 0b0001_0000) != 0 - let valueDeclared = (header & 0b0010_0000) != 0 + var readCount = 0 + while readCount < totalLength { + let header = try buffer.readUInt8() + let trackKeyRef = (header & 0b0000_0001) != 0 + let keyNull = (header & 0b0000_0010) != 0 + let keyDeclared = (header & 0b0000_0100) != 0 - if keyNull && valueNull { - readCount += 1 - continue - } + let trackValueRef = (header & 0b0000_1000) != 0 + let valueNull = (header & 0b0001_0000) != 0 + let valueDeclared = (header & 0b0010_0000) != 0 - if keyNull { - let valueTypeInfo = valueDeclared ? nil : try self.readTypeInfo() - _ = try readSkippedValue( - fieldType: valueType, - typeInfo: valueTypeInfo, - refMode: trackValueRef ? .tracking : .none, - readTypeInfo: false - ) - readCount += 1 - continue - } + if keyNull && valueNull { + readCount += 1 + continue + } - if valueNull { - let keyTypeInfo = keyDeclared ? nil : try self.readTypeInfo() - _ = try readSkippedValue( - fieldType: keyType, - typeInfo: keyTypeInfo, - refMode: trackKeyRef ? .tracking : .none, - readTypeInfo: false - ) - readCount += 1 - continue - } + if keyNull { + let valueTypeInfo = valueDeclared ? nil : try self.readTypeInfo() + _ = try readSkippedValue( + fieldType: valueType, + typeInfo: valueTypeInfo, + refMode: trackValueRef ? .tracking : .none, + readTypeInfo: false + ) + readCount += 1 + continue + } - let chunkSize = Int(try buffer.readUInt8()) - if chunkSize <= 0 { - throw ForyError.invalidData("invalid map chunk size \(chunkSize)") - } - if chunkSize > (totalLength - readCount) { - throw ForyError.invalidData("map chunk size exceeds remaining entries") - } + if valueNull { + let keyTypeInfo = keyDeclared ? nil : try self.readTypeInfo() + _ = try readSkippedValue( + fieldType: keyType, + typeInfo: keyTypeInfo, + refMode: trackKeyRef ? .tracking : .none, + readTypeInfo: false + ) + readCount += 1 + continue + } - let keyTypeInfo = keyDeclared ? nil : try self.readTypeInfo() - let valueTypeInfo = valueDeclared ? nil : try self.readTypeInfo() + let chunkSize = Int(try buffer.readUInt8()) + if chunkSize <= 0 { + throw ForyError.invalidData("invalid map chunk size \(chunkSize)") + } + if chunkSize > (totalLength - readCount) { + throw ForyError.invalidData("map chunk size exceeds remaining entries") + } - for _ in 0.. Any { - _ = try buffer.readVarUInt32() - return try readAny(refMode: .tracking, readTypeInfo: true) ?? ForyAnyNullValue() - } + return [:] + } + + private func readSkippedUnion() throws -> Any { + _ = try buffer.readVarUInt32() + return try readAny(context: self, refMode: .tracking, readTypeInfo: true) + ?? ForyAnyNullValue() + } } diff --git a/swift/Sources/Fory/Fory.swift b/swift/Sources/Fory/Fory.swift index 70b5189f62..c473831405 100644 --- a/swift/Sources/Fory/Fory.swift +++ b/swift/Sources/Fory/Fory.swift @@ -22,7 +22,7 @@ public struct Config { public let compatible: Bool public let checkClassVersion: Bool public let maxDepth: Int - public let maxContainerMemoryBytes: Int64 + public let maxGraphMemoryBytes: Int64 public let maxTypeFields: Int public let maxTypeMetaBytes: Int public let maxSchemaVersionsPerType: Int @@ -33,15 +33,15 @@ public struct Config { compatible: Bool? = nil, checkClassVersion: Bool? = nil, maxDepth: Int = 5, - maxContainerMemoryBytes: Int64 = -1, + maxGraphMemoryBytes: Int64 = -1, maxTypeFields: Int = 512, maxTypeMetaBytes: Int = 4096, maxSchemaVersionsPerType: Int = 10, maxAverageSchemaVersionsPerType: Int = 3 ) { precondition( - maxContainerMemoryBytes == -1 || maxContainerMemoryBytes > 0, - "maxContainerMemoryBytes must be positive or -1 for auto") + maxGraphMemoryBytes == -1 || maxGraphMemoryBytes > 0, + "maxGraphMemoryBytes must be positive or -1 for auto") precondition(maxTypeFields > 0, "maxTypeFields must be positive") precondition(maxTypeMetaBytes > 0, "maxTypeMetaBytes must be positive") precondition(maxSchemaVersionsPerType > 0, "maxSchemaVersionsPerType must be positive") @@ -54,7 +54,7 @@ public struct Config { self.compatible = effectiveCompatible self.checkClassVersion = effectiveCheckClassVersion self.maxDepth = maxDepth - self.maxContainerMemoryBytes = maxContainerMemoryBytes + self.maxGraphMemoryBytes = maxGraphMemoryBytes self.maxTypeFields = maxTypeFields self.maxTypeMetaBytes = maxTypeMetaBytes self.maxSchemaVersionsPerType = maxSchemaVersionsPerType @@ -78,7 +78,7 @@ public final class Fory { compatible: Bool? = nil, checkClassVersion: Bool? = nil, maxDepth: Int = 5, - maxContainerMemoryBytes: Int64 = -1, + maxGraphMemoryBytes: Int64 = -1, maxTypeFields: Int = 512, maxTypeMetaBytes: Int = 4096, maxSchemaVersionsPerType: Int = 10, @@ -90,7 +90,7 @@ public final class Fory { compatible: compatible, checkClassVersion: checkClassVersion, maxDepth: maxDepth, - maxContainerMemoryBytes: maxContainerMemoryBytes, + maxGraphMemoryBytes: maxGraphMemoryBytes, maxTypeFields: maxTypeFields, maxTypeMetaBytes: maxTypeMetaBytes, maxSchemaVersionsPerType: maxSchemaVersionsPerType, @@ -146,7 +146,8 @@ public final class Fory { } } - public func deserialize(from buffer: ByteBuffer, as _: T.Type = T.self) throws -> T { + public func deserialize(from buffer: ByteBuffer, as _: T.Type = T.self) throws -> T + { try deserializeRoot( from: buffer ) { context in @@ -167,7 +168,7 @@ public final class Fory { data: data ) { context in try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), + readAny(context: context, refMode: refMode, readTypeInfo: true), to: Any.self ) } @@ -186,7 +187,7 @@ public final class Fory { data: data ) { context in try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), + readAny(context: context, refMode: refMode, readTypeInfo: true), to: AnyObject.self ) } @@ -201,12 +202,13 @@ public final class Fory { @_disfavoredOverload public func deserialize(_ data: Data, as _: (any Serializer).Type = (any Serializer).self) throws - -> any Serializer { + -> any Serializer + { try deserializeRoot( data: data ) { context in try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), + readAny(context: context, refMode: refMode, readTypeInfo: true), to: (any Serializer).self ) } @@ -224,7 +226,7 @@ public final class Fory { try deserializeRoot( data: data ) { context in - try context.readListOfAny(refMode: refMode, readTypeInfo: true) ?? [] + try readListOfAny(context: context, refMode: refMode, readTypeInfo: true) ?? [] } } @@ -238,11 +240,12 @@ public final class Fory { @_disfavoredOverload public func deserialize(_ data: Data, as _: [String: Any].Type = [String: Any].self) throws - -> [String: Any] { + -> [String: Any] + { try deserializeRoot( data: data ) { context in - try context.readMapStringToAny(refMode: refMode, readTypeInfo: true) ?? [:] + try readMapStringToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] } } @@ -256,11 +259,12 @@ public final class Fory { @_disfavoredOverload public func deserialize(_ data: Data, as _: [Int32: Any].Type = [Int32: Any].self) throws - -> [Int32: Any] { + -> [Int32: Any] + { try deserializeRoot( data: data ) { context in - try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: true) ?? [:] + try readMapInt32ToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] } } @@ -274,11 +278,12 @@ public final class Fory { @_disfavoredOverload public func deserialize(_ data: Data, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self) - throws -> [AnyHashable: Any] { + throws -> [AnyHashable: Any] + { try deserializeRoot( data: data ) { context in - try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: true) ?? [:] + try readMapAnyHashableToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] } } @@ -302,7 +307,7 @@ public final class Fory { from: buffer ) { context in try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), + readAny(context: context, refMode: refMode, readTypeInfo: true), to: Any.self ) } @@ -317,12 +322,13 @@ public final class Fory { @_disfavoredOverload public func deserialize(from buffer: ByteBuffer, as _: AnyObject.Type = AnyObject.self) throws - -> AnyObject { + -> AnyObject + { try deserializeRoot( from: buffer ) { context in try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), + readAny(context: context, refMode: refMode, readTypeInfo: true), to: AnyObject.self ) } @@ -344,7 +350,7 @@ public final class Fory { from: buffer ) { context in try castAnyDynamicValue( - context.readAny(refMode: refMode, readTypeInfo: true), + readAny(context: context, refMode: refMode, readTypeInfo: true), to: (any Serializer).self ) } @@ -355,7 +361,7 @@ public final class Fory { try deserializeRoot( from: buffer ) { context in - try context.readListOfAny(refMode: refMode, readTypeInfo: true) ?? [] + try readListOfAny(context: context, refMode: refMode, readTypeInfo: true) ?? [] } } @@ -369,11 +375,12 @@ public final class Fory { @_disfavoredOverload public func deserialize(from buffer: ByteBuffer, as _: [String: Any].Type = [String: Any].self) - throws -> [String: Any] { + throws -> [String: Any] + { try deserializeRoot( from: buffer ) { context in - try context.readMapStringToAny(refMode: refMode, readTypeInfo: true) ?? [:] + try readMapStringToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] } } @@ -395,11 +402,12 @@ public final class Fory { @_disfavoredOverload public func deserialize(from buffer: ByteBuffer, as _: [Int32: Any].Type = [Int32: Any].self) - throws -> [Int32: Any] { + throws -> [Int32: Any] + { try deserializeRoot( from: buffer ) { context in - try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: true) ?? [:] + try readMapInt32ToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] } } @@ -410,7 +418,7 @@ public final class Fory { try deserializeRoot( from: buffer ) { context in - try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: true) ?? [:] + try readMapAnyHashableToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] } } @@ -462,6 +470,7 @@ public final class Fory { private func readRootTypedValue( context: ReadContext ) throws -> T { + try reserveRootGraphOwner(T.self, context: context) return try T.foryRead( context, refMode: refMode, @@ -469,13 +478,26 @@ public final class Fory { ) } + @inline(__always) + private func reserveRootGraphOwner( + _: T.Type, + context: ReadContext + ) throws { + switch T.staticTypeId { + case .list, .set, .map: + try context.reserveGraphMemory(max(1, MemoryLayout.stride)) + default: + break + } + } + @inline(__always) func withReusableReadContext( data: Data, _ body: (ReadContext) throws -> R ) throws -> R { readContext.buffer.replace(with: data) - try readContext.initContainerMemoryBudgetKnown(rootBytes: data.count) + try readContext.initGraphMemoryBudgetKnown(rootBytes: data.count) defer { readContext.reset() } @@ -537,7 +559,7 @@ public final class Fory { ) throws -> R { try typeResolver.finishRegistration() readContext.buffer.swapState(with: buffer) - try readContext.initContainerMemoryBudgetKnown(rootBytes: readContext.buffer.remaining) + try readContext.initGraphMemoryBudgetKnown(rootBytes: readContext.buffer.remaining) defer { readContext.buffer.swapState(with: buffer) readContext.reset() diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index 088b623608..413bcf6b00 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -18,12 +18,11 @@ import Foundation private let typeMetaSizeMask = 0xFF -private let materializedAnyReferenceBytes = 4 public final class ReadContext { - static let knownContainerBudgetSlackBytes = 64 * 1024 - static let unknownContainerBudgetBytes = 128 * 1024 * 1024 - private static let maxKnownContainerRootBytes = (Int.max - knownContainerBudgetSlackBytes) / 8 + static let knownGraphBudgetSlackBytes = 64 * 1024 + static let unknownGraphBudgetBytes = 128 * 1024 * 1024 + private static let maxKnownGraphRootBytes = (Int.max - knownGraphBudgetSlackBytes) / 8 public let buffer: ByteBuffer let typeResolver: TypeResolver @@ -40,8 +39,8 @@ public final class ReadContext { private var typeInfoScopeStack: [(typeKey: UInt64, previousTypeInfo: TypeInfo?)] = [] private var lastTypeInfo = TypeInfo.uncached private let config: Config - private let maxContainerMemoryBytes: Int - private var remainingContainerMemoryBytes = Int.max + private let maxGraphMemoryBytes: Int + private var remainingGraphMemoryBytes = Int.max init( buffer: ByteBuffer, @@ -55,57 +54,57 @@ public final class ReadContext { self.checkClassVersion = config.checkClassVersion self.maxDepth = config.maxDepth self.config = config - self.maxContainerMemoryBytes = Int(config.maxContainerMemoryBytes) + self.maxGraphMemoryBytes = Int(config.maxGraphMemoryBytes) self.refReader = RefReader() } @inline(__always) - func initContainerMemoryBudgetKnown(rootBytes: Int) throws { - var limit = maxContainerMemoryBytes + func initGraphMemoryBudgetKnown(rootBytes: Int) throws { + var limit = maxGraphMemoryBytes if limit < 0 { - if rootBytes > Self.maxKnownContainerRootBytes { - try throwContainerMemoryOverflow() + if rootBytes > Self.maxKnownGraphRootBytes { + try throwGraphMemoryOverflow() } - limit = rootBytes * 8 + Self.knownContainerBudgetSlackBytes + limit = rootBytes * 8 + Self.knownGraphBudgetSlackBytes } - remainingContainerMemoryBytes = limit + remainingGraphMemoryBytes = limit } @inline(__always) - func reserveContainerMemory(_ bytes: Int) throws { + public func reserveGraphMemory(_ bytes: Int) throws { if bytes < 0 { - try throwContainerMemoryOverflow() + try throwGraphMemoryOverflow() } - if bytes > remainingContainerMemoryBytes { - try throwContainerMemoryExceeded(bytes: bytes) + if bytes > remainingGraphMemoryBytes { + try throwGraphMemoryExceeded(bytes: bytes) } - remainingContainerMemoryBytes -= bytes + remainingGraphMemoryBytes -= bytes } @inline(__always) - func reserveCountedContainerMemory( + func reserveCountedGraphMemory( count: Int, elementBytes: Int ) throws { if count < 0 || elementBytes < 0 { - try throwContainerMemoryOverflow() + try throwGraphMemoryOverflow() } if elementBytes != 0 && count > Int.max / elementBytes { - try throwContainerMemoryOverflow() + try throwGraphMemoryOverflow() } - try reserveContainerMemory(count * elementBytes) + try reserveGraphMemory(count * elementBytes) } @inline(never) - private func throwContainerMemoryOverflow() throws -> Never { - throw ForyError.invalidData("container memory estimate overflows") + private func throwGraphMemoryOverflow() throws -> Never { + throw ForyError.invalidData("graph memory estimate overflows") } @inline(never) - private func throwContainerMemoryExceeded(bytes: Int) throws -> Never { + private func throwGraphMemoryExceeded(bytes: Int) throws -> Never { let message = - "estimated container memory request \(bytes) bytes exceeds maxContainerMemoryBytes " + - "remaining budget \(remainingContainerMemoryBytes) bytes" + "estimated graph memory request \(bytes) bytes exceeds maxGraphMemoryBytes " + + "remaining budget \(remainingGraphMemoryBytes) bytes" throw ForyError.invalidData(message) } @@ -258,7 +257,8 @@ public final class ReadContext { "received name-registered type info for id-registered local type") } if namespace.value != localTypeInfo.namespace.value - || typeName.value != localTypeInfo.typeName.value { + || typeName.value != localTypeInfo.typeName.value + { let expectedTypeName = "\(localTypeInfo.namespace.value)::\(localTypeInfo.typeName.value)" let actualTypeName = "\(namespace.value)::\(typeName.value)" throw ForyError.invalidData( @@ -290,7 +290,8 @@ public final class ReadContext { if !checkClassVersion, compatibleTypeDefTypeInfos.isEmpty, !localTypeInfo.typeDefHasUserTypeFields, - let localTypeDefHeader = localTypeInfo.typeDefHeader { + let localTypeDefHeader = localTypeInfo.typeDefHeader + { let indexMarker = try buffer.readVarUInt32() if indexMarker == 0 { let headerStart = buffer.getCursor() @@ -413,7 +414,8 @@ public final class ReadContext { for: localTypeInfo, wireTypeID: wireTypeID) compatibleTypeDefTypeInfos.push(cachedTypeInfo) - return try validateCompatibleTypeInfo(cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) + return try validateCompatibleTypeInfo( + cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) } @inline(__always) @@ -424,7 +426,8 @@ public final class ReadContext { let buffer = self.buffer let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos if compatibleTypeDefTypeInfos.isEmpty, - let localTypeDefHeader = localTypeInfo.typeDefHeader { + let localTypeDefHeader = localTypeInfo.typeDefHeader + { let indexMarker = try buffer.readVarUInt32() if indexMarker != 0 { return try readCompatibleTypeInfo( @@ -536,7 +539,8 @@ public final class ReadContext { return false } guard let localTypeDefBytes = localTypeInfo.typeDefBytes, - end - start == localTypeDefBytes.count else { + end - start == localTypeDefBytes.count + else { return false } return buffer.matchesBytes(start: start, bytes: localTypeDefBytes) @@ -561,7 +565,8 @@ public final class ReadContext { wireTypeID: TypeId ) throws { if let localTypeMeta = localTypeInfo.typeMeta, - remoteTypeMeta === localTypeMeta { + remoteTypeMeta === localTypeMeta + { return } if remoteTypeMeta.registerByName { @@ -603,7 +608,8 @@ public final class ReadContext { registerByName: localTypeInfo.registerByName, compatible: compatible, evolving: localTypeInfo.evolving - ) { + ) + { throw ForyError.typeMismatch(expected: wireTypeID.rawValue, actual: remoteTypeID) } } @@ -691,7 +697,7 @@ public final class ReadContext { case .float64Array: value = try readPrimitiveArray(self) as [Double] case .array, .list: - value = try readListOfAny(refMode: .none) ?? [] + value = try readListOfAny(context: self, refMode: .none) ?? [] case .set: value = try Set.foryRead(self, refMode: .none, readTypeInfo: false) case .map: @@ -765,103 +771,3 @@ public final class ReadContext { metaStrings.reset() } } - -extension ReadContext { - public func readAny( - refMode: RefMode, - readTypeInfo: Bool = true - ) throws -> Any? { - try SerializableAny.foryRead(self, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() - } - - public func readListOfAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [Any]? { - let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveCountedContainerMemory( - count: wrapped.count, - elementBytes: materializedAnyReferenceBytes - ) - return wrapped.map { $0.anyValueForCollection() } - } - - public func readMapStringToAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [String: Any]? { - let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveCountedContainerMemory( - count: wrapped.count, - elementBytes: 2 * materializedAnyReferenceBytes - ) - var map: [String: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map - } - - public func readMapInt32ToAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [Int32: Any]? { - let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveCountedContainerMemory( - count: wrapped.count, - elementBytes: 2 * materializedAnyReferenceBytes - ) - var map: [Int32: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map - } - - public func readMapAnyHashableToAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [AnyHashable: Any]? { - let wrapped: [AnyHashable: SerializableAny]? = try [AnyHashable: SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveCountedContainerMemory( - count: wrapped.count, - elementBytes: 2 * materializedAnyReferenceBytes - ) - var map: [AnyHashable: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map - } -} diff --git a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift index 0c37178078..ff21a1401f 100644 --- a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift +++ b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift @@ -16,659 +16,690 @@ // under the License. func buildReadDataDecl( - isClass: Bool, - fields: [ParsedField], - sortedFields: [ParsedField], - accessPrefix: String + isClass: Bool, + fields: [ParsedField], + sortedFields: [ParsedField], + accessPrefix: String ) -> String { - if isClass { - return buildClassReadDataDecl(sortedFields: sortedFields, accessPrefix: accessPrefix) - } - if fields.isEmpty { - return buildEmptyStructReadDataDecl(accessPrefix: accessPrefix) - } - return buildStructReadDataDecl( - fields: fields, sortedFields: sortedFields, accessPrefix: accessPrefix) + if isClass { + return buildClassReadDataDecl(sortedFields: sortedFields, accessPrefix: accessPrefix) + } + if fields.isEmpty { + return buildEmptyStructReadDataDecl(accessPrefix: accessPrefix) + } + return buildStructReadDataDecl( + fields: fields, sortedFields: sortedFields, accessPrefix: accessPrefix) } func buildReadCompatibleDataDecl( - isClass: Bool, - fields: [ParsedField], - sortedFields: [ParsedField], - accessPrefix: String + isClass: Bool, + fields: [ParsedField], + sortedFields: [ParsedField], + accessPrefix: String ) -> String { - if isClass { - return buildClassReadCompatibleDataDecl(sortedFields: sortedFields, accessPrefix: accessPrefix) - } - if fields.isEmpty { - return buildEmptyStructReadCompatibleDataDecl(accessPrefix: accessPrefix) - } - return buildStructReadCompatibleDataDecl( - fields: fields, sortedFields: sortedFields, accessPrefix: accessPrefix) + if isClass { + return buildClassReadCompatibleDataDecl(sortedFields: sortedFields, accessPrefix: accessPrefix) + } + if fields.isEmpty { + return buildEmptyStructReadCompatibleDataDecl(accessPrefix: accessPrefix) + } + return buildStructReadCompatibleDataDecl( + fields: fields, sortedFields: sortedFields, accessPrefix: accessPrefix) } -func buildClassReadWrapperDecl(accessPrefix: String) -> String { - """ - @inline(__always) - \(accessPrefix)static func foryRead( - _ context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> Self { - let __buffer = context.buffer - let __reservedRefID: UInt32? - if refMode != .none { - let rawFlag = try __buffer.readInt8() - guard let flag = RefFlag(rawValue: rawFlag) else { - throw ForyError.refError("invalid ref flag \\(rawFlag)") - } +private func graphFieldBytesExpr(_ field: ParsedField) -> String { + if field.primitiveSize > 0 { + return "\(field.primitiveSize)" + } + return "(\(field.typeText).isRefType ? 4 : max(1, MemoryLayout<\(field.typeText)>.stride))" +} - switch flag { - case .null: - return Self.foryDefault() - case .ref: - let refID = try __buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: Self.self) - case .refValue: - __reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil - case .notNullValue: - __reservedRefID = nil - } - } else { - __reservedRefID = nil - } +private func classGraphOwnerBytesExpr(_ fields: [ParsedField]) -> String { + if fields.isEmpty { + return "1" + } + return "max(1, 1 + " + fields.map(graphFieldBytesExpr).joined(separator: " + ") + ")" +} - return try Self.foryReadPayload( - context, - readTypeInfo: readTypeInfo, - readData: { - try Self.__foryReadDataImpl(context, reservedRefID: __reservedRefID) - }, - readCompatibleData: { remoteTypeInfo in - try Self.__foryReadCompatibleDataImpl( - context, - remoteTypeInfo: remoteTypeInfo, - reservedRefID: __reservedRefID - ) - } - ) - } - """ +private func reserveClassGraphOwnerLine(fields: [ParsedField], indent: String) -> String { + "\(indent)try context.reserveGraphMemory(\(classGraphOwnerBytesExpr(fields)))" +} + +private func reserveValueGraphOwnerLine(indent: String) -> String { + "\(indent)try context.reserveGraphMemory(max(1, MemoryLayout.stride))" +} + +func buildClassReadWrapperDecl(accessPrefix: String) -> String { + """ + @inline(__always) + \(accessPrefix)static func foryRead( + _ context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> Self { + let __buffer = context.buffer + let __reservedRefID: UInt32? + if refMode != .none { + let rawFlag = try __buffer.readInt8() + guard let flag = RefFlag(rawValue: rawFlag) else { + throw ForyError.refError("invalid ref flag \\(rawFlag)") + } + + switch flag { + case .null: + return Self.foryDefault() + case .ref: + let refID = try __buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: Self.self) + case .refValue: + __reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil + case .notNullValue: + __reservedRefID = nil + } + } else { + __reservedRefID = nil + } + + return try Self.foryReadPayload( + context, + readTypeInfo: readTypeInfo, + readData: { + try Self.__foryReadDataImpl(context, reservedRefID: __reservedRefID) + }, + readCompatibleData: { remoteTypeInfo in + try Self.__foryReadCompatibleDataImpl( + context, + remoteTypeInfo: remoteTypeInfo, + reservedRefID: __reservedRefID + ) + } + ) + } + """ } private func buildClassReadDataDecl( - sortedFields: [ParsedField], - accessPrefix: String + sortedFields: [ParsedField], + accessPrefix: String ) -> String { - let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) - let schemaAssignBody = buildClassAssignBody( - sortedFields: sortedFields, primitiveFastFields: primitiveFastFields, compatibleAligned: false) - - return """ - @inline(__always) - private static func __foryReadDataImpl(_ context: ReadContext, reservedRefID: UInt32?) throws -> Self { - let __buffer = context.buffer - \(schemaHashCheckExpr()) - let value = Self.init() - if let reservedRefID { - context.refReader.storeRef(value, at: reservedRefID) - } - \(schemaAssignBody) - return value - } + let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) + let schemaAssignBody = buildClassAssignBody( + sortedFields: sortedFields, primitiveFastFields: primitiveFastFields, compatibleAligned: false) - @inline(__always) - \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { - try Self.__foryReadDataImpl(context, reservedRefID: nil) + return """ + @inline(__always) + private static func __foryReadDataImpl(_ context: ReadContext, reservedRefID: UInt32?) throws -> Self { + let __buffer = context.buffer + \(schemaHashCheckExpr()) + \(reserveClassGraphOwnerLine(fields: sortedFields, indent: " ")) + let value = Self.init() + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) } - """ -} + \(schemaAssignBody) + return value + } -private func buildEmptyStructReadDataDecl(accessPrefix: String) -> String { - """ @inline(__always) \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { - let __buffer = context.buffer - \(schemaHashCheckExpr()) - return Self() + try Self.__foryReadDataImpl(context, reservedRefID: nil) } """ } -private func buildStructReadDataDecl( - fields: [ParsedField], - sortedFields: [ParsedField], - accessPrefix: String -) -> String { - let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) - let schemaReadBody = buildStructReadBody( - sortedFields: sortedFields, - primitiveFastFields: primitiveFastFields, - compatibleAligned: false - ) - let ctorArgs = buildCtorArgs(fields) - - return """ - @inline(__always) - \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { - let __buffer = context.buffer - \(schemaHashCheckExpr()) - \(schemaReadBody) - return Self( - \(ctorArgs) - ) - } - """ +private func buildEmptyStructReadDataDecl(accessPrefix: String) -> String { + """ + @inline(__always) + \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { + let __buffer = context.buffer + \(schemaHashCheckExpr()) + \(reserveValueGraphOwnerLine(indent: " ")) + return Self() + } + """ } -private func buildClassReadCompatibleDataDecl( - sortedFields: [ParsedField], - accessPrefix: String +private func buildStructReadDataDecl( + fields: [ParsedField], + sortedFields: [ParsedField], + accessPrefix: String ) -> String { - let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) - let schemaAssignBody = buildClassAssignBody( - sortedFields: sortedFields, primitiveFastFields: primitiveFastFields, compatibleAligned: false) - let compatibleAlignedAssignBody = buildClassAssignBody( - sortedFields: sortedFields, - primitiveFastFields: primitiveFastFields, - compatibleAligned: true - ) - let compatibleCases = buildCompatibleReadCases( - sortedFields: sortedFields, indent: " " - ) { sortedIndex, field, valueExpr in - "case \(sortedIndex): value.\(field.name) = \(valueExpr)" + let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) + let schemaReadBody = buildStructReadBody( + sortedFields: sortedFields, + primitiveFastFields: primitiveFastFields, + compatibleAligned: false + ) + let ctorArgs = buildCtorArgs(fields) + + return """ + @inline(__always) + \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { + let __buffer = context.buffer + \(schemaHashCheckExpr()) + \(reserveValueGraphOwnerLine(indent: " ")) + \(schemaReadBody) + return Self( + \(ctorArgs) + ) } - let bufferBinding = - (schemaAssignBody.contains("__buffer") || compatibleAlignedAssignBody.contains("__buffer") - || compatibleCases.contains("__buffer")) ? "let __buffer = context.buffer\n " : "" - let localFieldsBinding = - compatibleCases.contains("__foryLocalFields") - ? "let __foryLocalFields = remoteTypeInfo.typeMeta?.fields ?? Self.foryFieldsInfo(trackRef: context.trackRef)\n " : "" - - return """ - @inline(never) - private static func __foryReadCompatibleDataImpl( - _ context: ReadContext, - remoteTypeInfo: TypeInfo, - reservedRefID: UInt32? - ) throws -> Self { - \(bufferBinding)guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { - throw ForyError.invalidData("compatible type metadata is required") - } - let value = Self.init() - if let reservedRefID { - context.refReader.storeRef(value, at: reservedRefID) - } - if let localTypeMeta = remoteTypeInfo.typeMeta, - let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, - typeMeta.headerHash == localHeaderHash, - typeMeta.fields == localTypeMeta.fields { - if !remoteTypeInfo.typeDefHasUserTypeFields { - \(schemaAssignBody) - return value - } - \(compatibleAlignedAssignBody) - return value - } - \(localFieldsBinding)for remoteField in typeMeta.fields { - switch Int(remoteField.fieldID ?? -1) { - \(compatibleCases) - case -1: - try context.skipFieldValue(remoteField.fieldType) - default: - throw ForyError.invalidData("invalid compatible matched id \\(remoteField.fieldID ?? -2)") - } - } - return value - } - - @inline(never) - \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { - try Self.__foryReadCompatibleDataImpl(context, remoteTypeInfo: remoteTypeInfo, reservedRefID: nil) - } - """ + """ } -private func buildEmptyStructReadCompatibleDataDecl(accessPrefix: String) -> String { - """ +private func buildClassReadCompatibleDataDecl( + sortedFields: [ParsedField], + accessPrefix: String +) -> String { + let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) + let schemaAssignBody = buildClassAssignBody( + sortedFields: sortedFields, primitiveFastFields: primitiveFastFields, compatibleAligned: false) + let compatibleAlignedAssignBody = buildClassAssignBody( + sortedFields: sortedFields, + primitiveFastFields: primitiveFastFields, + compatibleAligned: true + ) + let compatibleCases = buildCompatibleReadCases( + sortedFields: sortedFields, indent: " " + ) { sortedIndex, field, valueExpr in + "case \(sortedIndex): value.\(field.name) = \(valueExpr)" + } + let bufferBinding = + (schemaAssignBody.contains("__buffer") || compatibleAlignedAssignBody.contains("__buffer") + || compatibleCases.contains("__buffer")) ? "let __buffer = context.buffer\n " : "" + let localFieldsBinding = + compatibleCases.contains("__foryLocalFields") + ? "let __foryLocalFields = remoteTypeInfo.typeMeta?.fields ?? Self.foryFieldsInfo(trackRef: context.trackRef)\n " + : "" + + return """ @inline(never) - \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { - guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { + private static func __foryReadCompatibleDataImpl( + _ context: ReadContext, + remoteTypeInfo: TypeInfo, + reservedRefID: UInt32? + ) throws -> Self { + \(bufferBinding)guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { throw ForyError.invalidData("compatible type metadata is required") } + \(reserveClassGraphOwnerLine(fields: sortedFields, indent: " ")) + let value = Self.init() + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } if let localTypeMeta = remoteTypeInfo.typeMeta, let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, typeMeta.headerHash == localHeaderHash, typeMeta.fields == localTypeMeta.fields { - return Self() + if !remoteTypeInfo.typeDefHasUserTypeFields { + \(schemaAssignBody) + return value + } + \(compatibleAlignedAssignBody) + return value } - for remoteField in typeMeta.fields { - try context.skipFieldValue(remoteField.fieldType) + \(localFieldsBinding)for remoteField in typeMeta.fields { + switch Int(remoteField.fieldID ?? -1) { + \(compatibleCases) + case -1: + try context.skipFieldValue(remoteField.fieldType) + default: + throw ForyError.invalidData("invalid compatible matched id \\(remoteField.fieldID ?? -2)") + } } - return Self() + return value + } + + @inline(never) + \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { + try Self.__foryReadCompatibleDataImpl(context, remoteTypeInfo: remoteTypeInfo, reservedRefID: nil) } """ } +private func buildEmptyStructReadCompatibleDataDecl(accessPrefix: String) -> String { + """ + @inline(never) + \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { + guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { + throw ForyError.invalidData("compatible type metadata is required") + } + \(reserveValueGraphOwnerLine(indent: " ")) + if let localTypeMeta = remoteTypeInfo.typeMeta, + let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, + typeMeta.headerHash == localHeaderHash, + typeMeta.fields == localTypeMeta.fields { + return Self() + } + for remoteField in typeMeta.fields { + try context.skipFieldValue(remoteField.fieldType) + } + return Self() + } + """ +} + private func buildStructReadCompatibleDataDecl( - fields: [ParsedField], - sortedFields: [ParsedField], - accessPrefix: String + fields: [ParsedField], + sortedFields: [ParsedField], + accessPrefix: String ) -> String { - let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) - let schemaReadBody = buildStructReadBody( - sortedFields: sortedFields, - primitiveFastFields: primitiveFastFields, - compatibleAligned: false - ) - let compatibleAlignedReadBody = buildStructReadBody( - sortedFields: sortedFields, - primitiveFastFields: primitiveFastFields, - compatibleAligned: true - ) - let ctorArgs = buildCtorArgs(fields) - let compatibleDefaults = buildStructCompatibleDefaults(fields) - let compatibleCases = buildCompatibleReadCases( - sortedFields: sortedFields, indent: " " - ) { sortedIndex, field, valueExpr in - "case \(sortedIndex): __\(field.name) = \(valueExpr)" - } - let changedFallbackDecl = buildStructChangedFallbackDecl( - defaults: compatibleDefaults, - cases: compatibleCases, - ctorArgs: ctorArgs - ) - let bufferBinding = - (schemaReadBody.contains("__buffer") || compatibleAlignedReadBody.contains("__buffer")) - ? "let __buffer = context.buffer\n " : "" - - return """ - \(changedFallbackDecl) + let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) + let schemaReadBody = buildStructReadBody( + sortedFields: sortedFields, + primitiveFastFields: primitiveFastFields, + compatibleAligned: false + ) + let compatibleAlignedReadBody = buildStructReadBody( + sortedFields: sortedFields, + primitiveFastFields: primitiveFastFields, + compatibleAligned: true + ) + let ctorArgs = buildCtorArgs(fields) + let compatibleDefaults = buildStructCompatibleDefaults(fields) + let compatibleCases = buildCompatibleReadCases( + sortedFields: sortedFields, indent: " " + ) { sortedIndex, field, valueExpr in + "case \(sortedIndex): __\(field.name) = \(valueExpr)" + } + let changedFallbackDecl = buildStructChangedFallbackDecl( + defaults: compatibleDefaults, + cases: compatibleCases, + ctorArgs: ctorArgs + ) + let bufferBinding = + (schemaReadBody.contains("__buffer") || compatibleAlignedReadBody.contains("__buffer")) + ? "let __buffer = context.buffer\n " : "" + + return """ + \(changedFallbackDecl) - @inline(never) - \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { - \(bufferBinding)guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { - throw ForyError.invalidData("compatible type metadata is required") - } - if let localTypeMeta = remoteTypeInfo.typeMeta, - let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, - typeMeta.headerHash == localHeaderHash, - typeMeta.fields == localTypeMeta.fields { - if !remoteTypeInfo.typeDefHasUserTypeFields { - \(schemaReadBody) - return Self( - \(ctorArgs) - ) - } - \(compatibleAlignedReadBody) + @inline(never) + \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { + \(bufferBinding)guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { + throw ForyError.invalidData("compatible type metadata is required") + } + \(reserveValueGraphOwnerLine(indent: " ")) + if let localTypeMeta = remoteTypeInfo.typeMeta, + let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, + typeMeta.headerHash == localHeaderHash, + typeMeta.fields == localTypeMeta.fields { + if !remoteTypeInfo.typeDefHasUserTypeFields { + \(schemaReadBody) return Self( \(ctorArgs) ) } - return try Self.__foryReadChangedData( - context, - typeMeta: typeMeta + \(compatibleAlignedReadBody) + return Self( + \(ctorArgs) ) } - """ + return try Self.__foryReadChangedData( + context, + typeMeta: typeMeta + ) + } + """ } private func buildStructChangedFallbackDecl( - defaults: String, - cases: String, - ctorArgs: String + defaults: String, + cases: String, + ctorArgs: String ) -> String { - let bufferBinding = cases.contains("__buffer") ? "let __buffer = context.buffer\n " : "" - let localFieldsBinding = - cases.contains("__foryLocalFields") - ? "let __foryLocalFields = Self.foryFieldsInfo(trackRef: context.trackRef)\n " : "" - return """ - @inline(never) - private static func __foryReadChangedData( - _ context: ReadContext, - typeMeta: TypeMeta - ) throws -> Self { - \(bufferBinding) - \(defaults) - \(localFieldsBinding)for remoteField in typeMeta.fields { - switch Int(remoteField.fieldID ?? -1) { - \(cases) - case -1: - try context.skipFieldValue(remoteField.fieldType) - default: - throw ForyError.invalidData("invalid compatible matched id \\(remoteField.fieldID ?? -2)") - } + let bufferBinding = cases.contains("__buffer") ? "let __buffer = context.buffer\n " : "" + let localFieldsBinding = + cases.contains("__foryLocalFields") + ? "let __foryLocalFields = Self.foryFieldsInfo(trackRef: context.trackRef)\n " : "" + return """ + @inline(never) + private static func __foryReadChangedData( + _ context: ReadContext, + typeMeta: TypeMeta + ) throws -> Self { + \(bufferBinding) + \(defaults) + \(localFieldsBinding)for remoteField in typeMeta.fields { + switch Int(remoteField.fieldID ?? -1) { + \(cases) + case -1: + try context.skipFieldValue(remoteField.fieldType) + default: + throw ForyError.invalidData("invalid compatible matched id \\(remoteField.fieldID ?? -2)") } - return Self( - \(ctorArgs) - ) } - """ + return Self( + \(ctorArgs) + ) + } + """ } private func buildClassAssignBody( - sortedFields: [ParsedField], - primitiveFastFields: [ParsedField], - compatibleAligned: Bool + sortedFields: [ParsedField], + primitiveFastFields: [ParsedField], + compatibleAligned: Bool ) -> String { - let remainingAssignLines = sortedFields.dropFirst(primitiveFastFields.count).map { field -> String in - let valueExpr: String - if compatibleAligned { - valueExpr = compatibleSchemaReadFieldExpr(field) - } else { - valueExpr = readFieldExpr( - field, - refModeExpr: fieldRefModeExpression(field), - readTypeInfoExpr: "false" - ) - } - return "value.\(field.name) = \(valueExpr)" - } - - var sections: [String] = [] - if let primitiveReadBlock = buildPrimitiveFastClassReadBlock(primitiveFastFields) { - sections.append(primitiveReadBlock) + let remainingAssignLines = sortedFields.dropFirst(primitiveFastFields.count).map { + field -> String in + let valueExpr: String + if compatibleAligned { + valueExpr = compatibleSchemaReadFieldExpr(field) + } else { + valueExpr = readFieldExpr( + field, + refModeExpr: fieldRefModeExpression(field), + readTypeInfoExpr: "false" + ) } - if !remainingAssignLines.isEmpty { - sections.append(remainingAssignLines.joined(separator: "\n ")) - } - if sections.isEmpty { - sections.append("_ = context") - } - return sections.joined(separator: "\n ") + return "value.\(field.name) = \(valueExpr)" + } + + var sections: [String] = [] + if let primitiveReadBlock = buildPrimitiveFastClassReadBlock(primitiveFastFields) { + sections.append(primitiveReadBlock) + } + if !remainingAssignLines.isEmpty { + sections.append(remainingAssignLines.joined(separator: "\n ")) + } + if sections.isEmpty { + sections.append("_ = context") + } + return sections.joined(separator: "\n ") } private func buildStructReadBody( - sortedFields: [ParsedField], - primitiveFastFields: [ParsedField], - compatibleAligned: Bool + sortedFields: [ParsedField], + primitiveFastFields: [ParsedField], + compatibleAligned: Bool ) -> String { - let remainingReadLines = sortedFields.dropFirst(primitiveFastFields.count).map { field -> String in - let valueExpr = - compatibleAligned ? compatibleSchemaReadFieldExpr(field) : schemaReadFieldExpr(field) - return "let __\(field.name) = \(valueExpr)" - } - - var sections: [String] = [] - if let primitiveDeclarations = buildPrimitiveFastStructReadDeclarations(primitiveFastFields) { - sections.append(primitiveDeclarations) - } - if let primitiveReadBlock = buildPrimitiveFastStructReadBlock(primitiveFastFields) { - sections.append(primitiveReadBlock) - } - if !remainingReadLines.isEmpty { - sections.append(remainingReadLines.joined(separator: "\n ")) - } - return sections.joined(separator: "\n ") + let remainingReadLines = sortedFields.dropFirst(primitiveFastFields.count).map { + field -> String in + let valueExpr = + compatibleAligned ? compatibleSchemaReadFieldExpr(field) : schemaReadFieldExpr(field) + return "let __\(field.name) = \(valueExpr)" + } + + var sections: [String] = [] + if let primitiveDeclarations = buildPrimitiveFastStructReadDeclarations(primitiveFastFields) { + sections.append(primitiveDeclarations) + } + if let primitiveReadBlock = buildPrimitiveFastStructReadBlock(primitiveFastFields) { + sections.append(primitiveReadBlock) + } + if !remainingReadLines.isEmpty { + sections.append(remainingReadLines.joined(separator: "\n ")) + } + return sections.joined(separator: "\n ") } private func buildCtorArgs(_ fields: [ParsedField]) -> String { - fields - .sorted(by: { $0.originalIndex < $1.originalIndex }) - .map { "\($0.name): __\($0.name)" } - .joined(separator: ",\n ") + fields + .sorted(by: { $0.originalIndex < $1.originalIndex }) + .map { "\($0.name): __\($0.name)" } + .joined(separator: ",\n ") } private func buildStructCompatibleDefaults(_ fields: [ParsedField]) -> String { - fields - .sorted(by: { $0.originalIndex < $1.originalIndex }) - .map(compatibleDefaultDecl) - .joined(separator: "\n ") + fields + .sorted(by: { $0.originalIndex < $1.originalIndex }) + .map(compatibleDefaultDecl) + .joined(separator: "\n ") } private func schemaHashCheckExpr(indent: String = " ") -> String { - """ - \(indent)if context.checkClassVersion { - \(indent) let __schemaHash = UInt32(bitPattern: try __buffer.readInt32()) - \(indent) let __expectedHash = Self.__forySchemaHash(context.trackRef) - \(indent) if __schemaHash != __expectedHash { - \(indent) throw ForyError.invalidData("class version hash mismatch: expected \\(__expectedHash), got \\(__schemaHash)") - \(indent) } - \(indent)} - """ + """ + \(indent)if context.checkClassVersion { + \(indent) let __schemaHash = UInt32(bitPattern: try __buffer.readInt32()) + \(indent) let __expectedHash = Self.__forySchemaHash(context.trackRef) + \(indent) if __schemaHash != __expectedHash { + \(indent) throw ForyError.invalidData("class version hash mismatch: expected \\(__expectedHash), got \\(__schemaHash)") + \(indent) } + \(indent)} + """ } private func buildCompatibleReadCases( - sortedFields: [ParsedField], - indent: String, - assignCase: (Int, ParsedField, String) -> String + sortedFields: [ParsedField], + indent: String, + assignCase: (Int, ParsedField, String) -> String ) -> String { - sortedFields.enumerated().map { sortedIndex, field -> String in - let directValueExpr = compatibleSchemaReadFieldExpr(field) - let compatibleValueExpr = readFieldExpr( - field, - refModeExpr: - "RefMode.from(nullable: remoteField.fieldType.nullable, trackRef: remoteField.fieldType.trackRef)", - readTypeInfoExpr: - "TypeId.needsTypeInfoForField(TypeId(rawValue: remoteField.fieldType.typeID) ?? .unknown)" - ) - let compatibleCaseExpr = compatibleScalarReadExpr( - field, - sortedIndex: sortedIndex, - compatibleValueExpr: compatibleValueExpr - ) - return [ - assignCase(sortedIndex * 2, field, directValueExpr), - assignCase(sortedIndex * 2 + 1, field, compatibleCaseExpr) - ].joined(separator: "\n\(indent)") - }.joined(separator: "\n\(indent)") + sortedFields.enumerated().map { sortedIndex, field -> String in + let directValueExpr = compatibleSchemaReadFieldExpr(field) + let compatibleValueExpr = readFieldExpr( + field, + refModeExpr: + "RefMode.from(nullable: remoteField.fieldType.nullable, trackRef: remoteField.fieldType.trackRef)", + readTypeInfoExpr: + "TypeId.needsTypeInfoForField(TypeId(rawValue: remoteField.fieldType.typeID) ?? .unknown)" + ) + let compatibleCaseExpr = compatibleScalarReadExpr( + field, + sortedIndex: sortedIndex, + compatibleValueExpr: compatibleValueExpr + ) + return [ + assignCase(sortedIndex * 2, field, directValueExpr), + assignCase(sortedIndex * 2 + 1, field, compatibleCaseExpr), + ].joined(separator: "\n\(indent)") + }.joined(separator: "\n\(indent)") } private func compatibleScalarReadExpr( - _ field: ParsedField, - sortedIndex: Int, - compatibleValueExpr: String + _ field: ParsedField, + sortedIndex: Int, + compatibleValueExpr: String ) -> String { - guard - field.dynamicAnyCodec == nil, - let helperTarget = compatibleScalarReaderTarget(field) - else { - return compatibleValueExpr - } - let helperName = - field.isOptional - ? "foryReadCompatibleOptional\(helperTarget)Field" - : "foryReadCompatible\(helperTarget)Field" - return """ - try \(helperName)( - context, - remoteField: remoteField, - localField: __foryLocalFields[\(sortedIndex)] - ) - """ + guard + field.dynamicAnyCodec == nil, + let helperTarget = compatibleScalarReaderTarget(field) + else { + return compatibleValueExpr + } + let helperName = + field.isOptional + ? "foryReadCompatibleOptional\(helperTarget)Field" + : "foryReadCompatible\(helperTarget)Field" + return """ + try \(helperName)( + context, + remoteField: remoteField, + localField: __foryLocalFields[\(sortedIndex)] + ) + """ } private func compatibleScalarReaderTarget(_ field: ParsedField) -> String? { - guard compatibleScalarTypeID(field.typeID) else { - return nil - } - switch compatibleScalarPayloadType(field.typeText) { - case "Bool": - return "Bool" - case "Int8": - return "Int8" - case "Int16": - return "Int16" - case "Int32": - return "Int32" - case "Int64": - return "Int64" - case "Int": - return "Int" - case "UInt8": - return "UInt8" - case "UInt16": - return "UInt16" - case "UInt32": - return "UInt32" - case "UInt64": - return "UInt64" - case "UInt": - return "UInt" - case "Float16": - return "Float16" - case "BFloat16": - return "BFloat16" - case "Float": - return "Float" - case "Double": - return "Double" - case "String": - return "String" - case "Decimal": - return "Decimal" - default: - return nil - } + guard compatibleScalarTypeID(field.typeID) else { + return nil + } + switch compatibleScalarPayloadType(field.typeText) { + case "Bool": + return "Bool" + case "Int8": + return "Int8" + case "Int16": + return "Int16" + case "Int32": + return "Int32" + case "Int64": + return "Int64" + case "Int": + return "Int" + case "UInt8": + return "UInt8" + case "UInt16": + return "UInt16" + case "UInt32": + return "UInt32" + case "UInt64": + return "UInt64" + case "UInt": + return "UInt" + case "Float16": + return "Float16" + case "BFloat16": + return "BFloat16" + case "Float": + return "Float" + case "Double": + return "Double" + case "String": + return "String" + case "Decimal": + return "Decimal" + default: + return nil + } } private func compatibleScalarPayloadType(_ typeText: String) -> String { - var type = trimType(typeText) - if type.hasSuffix("?") { - type.removeLast() - } else if type.hasPrefix("Optional<"), type.hasSuffix(">") { - type = String(type.dropFirst("Optional<".count).dropLast()) - } - for prefix in ["Swift.", "Foundation.", "Fory."] where type.hasPrefix(prefix) { - return String(type.dropFirst(prefix.count)) - } - return type + var type = trimType(typeText) + if type.hasSuffix("?") { + type.removeLast() + } else if type.hasPrefix("Optional<"), type.hasSuffix(">") { + type = String(type.dropFirst("Optional<".count).dropLast()) + } + for prefix in ["Swift.", "Foundation.", "Fory."] where type.hasPrefix(prefix) { + return String(type.dropFirst(prefix.count)) + } + return type } private func compatibleScalarTypeID(_ typeID: UInt32) -> Bool { - switch typeID { - case 1...15, 17...21, 40: - return true - default: - return false - } + switch typeID { + case 1...15, 17...21, 40: + return true + default: + return false + } } private func swiftStringLiteral(_ value: String) -> String { - let escaped = - value - .replacingOccurrences(of: "\\", with: "\\\\") - .replacingOccurrences(of: "\"", with: "\\\"") - return "\"\(escaped)\"" + let escaped = + value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "\"", with: "\\\"") + return "\"\(escaped)\"" } private func readFieldExpr( - _ field: ParsedField, - refModeExpr: String, - readTypeInfoExpr: String + _ field: ParsedField, + refModeExpr: String, + readTypeInfoExpr: String ) -> String { - if let dynamicAnyCodec = field.dynamicAnyCodec { - return dynamicAnyReadExpr( - field: field, - dynamicAnyCodec: dynamicAnyCodec, - refModeExpr: refModeExpr + if let dynamicAnyCodec = field.dynamicAnyCodec { + return dynamicAnyReadExpr( + field: field, + dynamicAnyCodec: dynamicAnyCodec, + refModeExpr: refModeExpr + ) + } + if let codecType = field.customCodecType { + let fieldCodec = field.isOptional ? "OptionalFieldCodec<\(codecType)>" : codecType + if readTypeInfoExpr.contains("remoteField.fieldType") { + return """ + try \(fieldCodec).readCompatibleField( + context, + remoteFieldType: remoteField.fieldType, + refMode: \(refModeExpr) ) + """ } - if let codecType = field.customCodecType { - let fieldCodec = field.isOptional ? "OptionalFieldCodec<\(codecType)>" : codecType - if readTypeInfoExpr.contains("remoteField.fieldType") { - return """ - try \(fieldCodec).readCompatibleField( - context, - remoteFieldType: remoteField.fieldType, - refMode: \(refModeExpr) - ) - """ - } - return "try \(fieldCodec).read(context, refMode: \(refModeExpr), readTypeInfo: false)" - } - return - "try \(field.typeText).foryRead(context, refMode: \(refModeExpr), readTypeInfo: \(readTypeInfoExpr))" + return "try \(fieldCodec).read(context, refMode: \(refModeExpr), readTypeInfo: false)" + } + return + "try \(field.typeText).foryRead(context, refMode: \(refModeExpr), readTypeInfo: \(readTypeInfoExpr))" } private func schemaReadFieldExpr(_ field: ParsedField) -> String { - if fieldNeedsGeneralSchemaRead(field) { - return readFieldExpr( - field, - refModeExpr: fieldRefModeExpression(field), - readTypeInfoExpr: "false" - ) - } - if let primitiveExpr = primitiveSchemaReadExpr(field) { - return primitiveExpr - } - return "try \(field.typeText).foryReadData(context)" + if fieldNeedsGeneralSchemaRead(field) { + return readFieldExpr( + field, + refModeExpr: fieldRefModeExpression(field), + readTypeInfoExpr: "false" + ) + } + if let primitiveExpr = primitiveSchemaReadExpr(field) { + return primitiveExpr + } + return "try \(field.typeText).foryReadData(context)" } private func compatibleSchemaReadFieldExpr(_ field: ParsedField) -> String { - if fieldNeedsGeneralCompatibleRead(field) { - return readFieldExpr( - field, - refModeExpr: fieldRefModeExpression(field), - readTypeInfoExpr: "TypeId.needsTypeInfoForField(\(field.typeText).staticTypeId)" - ) - } - if let primitiveExpr = primitiveSchemaReadExpr(field) { - return primitiveExpr - } - return "try \(field.typeText).foryReadData(context)" + if fieldNeedsGeneralCompatibleRead(field) { + return readFieldExpr( + field, + refModeExpr: fieldRefModeExpression(field), + readTypeInfoExpr: "TypeId.needsTypeInfoForField(\(field.typeText).staticTypeId)" + ) + } + if let primitiveExpr = primitiveSchemaReadExpr(field) { + return primitiveExpr + } + return "try \(field.typeText).foryReadData(context)" } private func primitiveSchemaReadExpr(_ field: ParsedField) -> String? { - let type = trimType(field.typeText) - switch type { - case "Bool": - return "try __buffer.readUInt8() != 0" - case "Int8": - return "try __buffer.readInt8()" - case "Int16": - return "try __buffer.readInt16()" - case "Int32": - return "try __buffer.readVarInt32()" - case "Int64": - return "try __buffer.readVarInt64()" - case "Int": - return "Int(try __buffer.readVarInt64())" - case "UInt8": - return "try __buffer.readUInt8()" - case "UInt16": - return "try __buffer.readUInt16()" - case "UInt32": - return "try __buffer.readVarUInt32()" - case "UInt64": - return "try __buffer.readVarUInt64()" - case "UInt": - return "UInt(try __buffer.readVarUInt64())" - case "Float": - return "try __buffer.readFloat32()" - case "Double": - return "try __buffer.readFloat64()" - default: - return nil - } + let type = trimType(field.typeText) + switch type { + case "Bool": + return "try __buffer.readUInt8() != 0" + case "Int8": + return "try __buffer.readInt8()" + case "Int16": + return "try __buffer.readInt16()" + case "Int32": + return "try __buffer.readVarInt32()" + case "Int64": + return "try __buffer.readVarInt64()" + case "Int": + return "Int(try __buffer.readVarInt64())" + case "UInt8": + return "try __buffer.readUInt8()" + case "UInt16": + return "try __buffer.readUInt16()" + case "UInt32": + return "try __buffer.readVarUInt32()" + case "UInt64": + return "try __buffer.readVarUInt64()" + case "UInt": + return "UInt(try __buffer.readVarUInt64())" + case "Float": + return "try __buffer.readFloat32()" + case "Double": + return "try __buffer.readFloat64()" + default: + return nil + } } private func dynamicAnyReadExpr( - field: ParsedField, - dynamicAnyCodec: DynamicAnyCodecKind, - refModeExpr: String + field: ParsedField, + dynamicAnyCodec: DynamicAnyCodecKind, + refModeExpr: String ) -> String { - let metatypeExpr = "(\(field.typeText)).self" - let method = dynamicAnyReadMethodName(dynamicAnyCodec) - let readTypeInfoExpr = - dynamicAnyReadsTypeInfo(dynamicAnyCodec) - ? ", readTypeInfo: true" - : "" - return - "try castAnyDynamicValue(context.\(method)(refMode: \(refModeExpr)\(readTypeInfoExpr)), to: \(metatypeExpr))" + let metatypeExpr = "(\(field.typeText)).self" + let method = dynamicAnyReadMethodName(dynamicAnyCodec) + let readTypeInfoExpr = + dynamicAnyReadsTypeInfo(dynamicAnyCodec) + ? ", readTypeInfo: true" + : "" + return + "try castAnyDynamicValue(\(method)(context: context, refMode: \(refModeExpr)\(readTypeInfoExpr)), to: \(metatypeExpr))" } private func compatibleDefaultDecl(_ field: ParsedField) -> String { - let explicitType = - (field.dynamicAnyCodec != nil || field.customCodecType != nil) ? ": \(field.typeText)" : "" - return "var __\(field.name)\(explicitType) = \(fieldDefaultExpr(field))" + let explicitType = + (field.dynamicAnyCodec != nil || field.customCodecType != nil) ? ": \(field.typeText)" : "" + return "var __\(field.name)\(explicitType) = \(fieldDefaultExpr(field))" } private func fieldNeedsGeneralSchemaRead(_ field: ParsedField) -> Bool { - field.dynamicAnyCodec != nil || field.customCodecType != nil || field.isOptional - || field.typeID == 27 + field.dynamicAnyCodec != nil || field.customCodecType != nil || field.isOptional + || field.typeID == 27 } private func fieldNeedsGeneralCompatibleRead(_ field: ParsedField) -> Bool { - fieldNeedsGeneralSchemaRead(field) || compatibleFieldNeedsTypeInfo(field) + fieldNeedsGeneralSchemaRead(field) || compatibleFieldNeedsTypeInfo(field) } diff --git a/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift b/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift deleted file mode 100644 index a9b9f15081..0000000000 --- a/swift/Tests/ForyTests/ContainerMemoryBudgetTests.swift +++ /dev/null @@ -1,307 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -import Foundation -import Testing -@testable import Fory - -@ForyStruct -private final class BudgetNode { - var id: Int32 = 0 - - required init() {} - - init(id: Int32) { - self.id = id - } -} - -@ForyStruct -private struct BudgetSiblings { - var left: [BudgetNode] = [] - var right: [BudgetNode] = [] -} - -@ForyStruct -private struct BudgetDenseHolder: Equatable { - var text: String = "" - var data: Data = Data() - @ArrayField(element: .int32()) - var dense: [Int32] = [] -} - -private func makeBudgetFory(maxContainerMemoryBytes: Int64 = -1) -> Fory { - let fory = Fory(config: .init( - trackRef: false, - compatible: false, - maxContainerMemoryBytes: maxContainerMemoryBytes - )) - fory.register(BudgetNode.self, id: 9801) - fory.register(BudgetSiblings.self, id: 9802) - fory.register(BudgetDenseHolder.self, id: 9803) - return fory -} - -private let testReferenceBytes = 4 - -private func elementBytes(_ type: Element.Type) -> Int { - type.isRefType ? testReferenceBytes : max(1, MemoryLayout.stride) -} - -private func arrayBudget(_ type: Element.Type, count: Int) -> Int { - count * elementBytes(type) -} - -private func mapBudget( - key: Key.Type, - value: Value.Type, - count: Int -) -> Int { - count * (elementBytes(key) + elementBytes(value)) -} - -private func expectInvalidData(_ body: () throws -> Void) { - do { - try body() - Issue.record("expected invalid data") - } catch ForyError.invalidData { - } catch { - Issue.record("expected invalid data, got \(error)") - } -} - -@Test -func knownLengthAutoBudgetUsesInputBytes() throws { - let expected = 17 * 8 + ReadContext.knownContainerBudgetSlackBytes - let config = Config(trackRef: false, compatible: false) - let context = ReadContext( - buffer: ByteBuffer(), - typeResolver: TypeResolver(config: config), - config: config - ) - - try context.initContainerMemoryBudgetKnown(rootBytes: 17) - try context.reserveContainerMemory(expected) - expectInvalidData { - try context.reserveContainerMemory(testReferenceBytes) - } -} - -@Test -func byteBufferRootUsesKnownLengthAutoBudget() throws { - let count = 6 - let value = Array(repeating: [String](), count: count) - let bytes = try makeBudgetFory().serialize(value) - let buffer = ByteBuffer(data: bytes) - - let decoded: [[String]] = try makeBudgetFory().deserialize(from: buffer) - #expect(decoded.count == count) -} - -@Test -func explicitConfigOverridesAutoBudget() throws { - let values = (0..<16).map(Int32.init) - let bytes = try makeBudgetFory().serialize(values) - let required = arrayBudget(Int32.self, count: values.count) - - expectInvalidData { - let _: [Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required - 1)).deserialize(bytes) - } - let decoded: [Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required)).deserialize(bytes) - #expect(decoded == values) -} - -@Test -func siblingContainersShareOneBudget() throws { - let value = BudgetSiblings( - left: (0..<16).map { BudgetNode(id: Int32($0)) }, - right: (16..<32).map { BudgetNode(id: Int32($0)) } - ) - let bytes = try makeBudgetFory().serialize(value) - let oneList = arrayBudget(BudgetNode.self, count: 16) - - expectInvalidData { - let _: BudgetSiblings = try makeBudgetFory(maxContainerMemoryBytes: Int64(oneList)).deserialize(bytes) - } - let decoded: BudgetSiblings = try makeBudgetFory(maxContainerMemoryBytes: Int64(oneList * 2)).deserialize(bytes) - #expect(decoded.left.count == 16) - #expect(decoded.right.count == 16) -} - -@Test -func mapBudgetIsCharged() throws { - let value: [String: Int32] = ["a": 1, "b": 2, "c": 3] - let bytes = try makeBudgetFory().serialize(value) - let required = mapBudget(key: String.self, value: Int32.self, count: value.count) - - expectInvalidData { - let _: [String: Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required - 1)).deserialize(bytes) - } - let decoded: [String: Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(required)).deserialize(bytes) - #expect(decoded == value) -} - -@Test -func referenceAndInlineValueArraysAreCharged() throws { - let nodes = (0..<4).map { BudgetNode(id: Int32($0)) } - let nodeBytes = try makeBudgetFory().serialize(nodes) - let nodeBudget = arrayBudget(BudgetNode.self, count: nodes.count) - expectInvalidData { - let _: [BudgetNode] = try makeBudgetFory(maxContainerMemoryBytes: Int64(nodeBudget - 1)).deserialize(nodeBytes) - } - let decodedNodes: [BudgetNode] = try makeBudgetFory(maxContainerMemoryBytes: Int64(nodeBudget)).deserialize(nodeBytes) - #expect(decodedNodes.count == nodes.count) - - let ints: [Int32] = [1, 2, 3, 4] - let intBytes = try makeBudgetFory().serialize(ints) - let intBudget = arrayBudget(Int32.self, count: ints.count) - expectInvalidData { - let _: [Int32] = try makeBudgetFory(maxContainerMemoryBytes: Int64(intBudget - 1)).deserialize(intBytes) - } - #expect(try makeBudgetFory(maxContainerMemoryBytes: Int64(intBudget)).deserialize(intBytes) as [Int32] == ints) -} - -@Test -func stringBinaryAndPrimitiveDenseArrayOwnersAreSkipped() throws { - let value = BudgetDenseHolder( - text: "budget", - data: Data([1, 2, 3]), - dense: [1, 2, 3] - ) - let bytes = try makeBudgetFory().serialize(value) - - let decoded: BudgetDenseHolder = try makeBudgetFory(maxContainerMemoryBytes: 1).deserialize(bytes) - #expect(decoded == value) -} - -@Test -func dynamicAnyEmptyMapHasNoDynamicStorage() throws { - let value = [:] as [AnyHashable: Any] - let bytes = try makeBudgetFory().serialize(value as Any) - - let decoded: Any = try makeBudgetFory(maxContainerMemoryBytes: 1) - .deserialize(bytes) - #expect((decoded as? [String: Any])?.isEmpty == true) -} - -@Test -func publicAnyArrayBudget() throws { - let value: [Any] = [Int32(1), Int32(2), Int32(3)] - let bytes = try makeBudgetFory().serialize(value) - let wrappedBudget = arrayBudget(SerializableAny.self, count: value.count) - let finalBudget = value.count * testReferenceBytes - - expectInvalidData { - let _: [Any] = try makeBudgetFory(maxContainerMemoryBytes: Int64(wrappedBudget)) - .deserialize(bytes, as: [Any].self) - } - let decoded = try makeBudgetFory(maxContainerMemoryBytes: Int64(wrappedBudget + finalBudget)) - .deserialize(bytes, as: [Any].self) - #expect(decoded.count == value.count) -} - -@Test -func publicAnyMapBudget() throws { - let stringMap: [String: Any] = ["a": Int32(1), "b": Int32(2), "c": Int32(3)] - let stringBytes = try makeBudgetFory().serialize(stringMap) - let stringWrapped = mapBudget( - key: String.self, - value: SerializableAny.self, - count: stringMap.count - ) - let stringFinal = stringMap.count * 2 * testReferenceBytes - expectInvalidData { - let _: [String: Any] = try makeBudgetFory(maxContainerMemoryBytes: Int64(stringWrapped)) - .deserialize(stringBytes, as: [String: Any].self) - } - let decodedString = try makeBudgetFory(maxContainerMemoryBytes: Int64(stringWrapped + stringFinal)) - .deserialize(stringBytes, as: [String: Any].self) - #expect(decodedString.count == stringMap.count) - - let intMap: [Int32: Any] = [1: Int32(10), 2: Int32(20), 3: Int32(30)] - let intBytes = try makeBudgetFory().serialize(intMap) - let intWrapped = mapBudget( - key: Int32.self, - value: SerializableAny.self, - count: intMap.count - ) - let intFinal = intMap.count * 2 * testReferenceBytes - expectInvalidData { - let _: [Int32: Any] = try makeBudgetFory(maxContainerMemoryBytes: Int64(intWrapped)) - .deserialize(intBytes, as: [Int32: Any].self) - } - let decodedInt = try makeBudgetFory(maxContainerMemoryBytes: Int64(intWrapped + intFinal)) - .deserialize(intBytes, as: [Int32: Any].self) - #expect(decodedInt.count == intMap.count) - - let anyHashableMap: [AnyHashable: Any] = [ - AnyHashable("a"): Int32(1), - AnyHashable(Int32(2)): Int32(2), - AnyHashable(true): Int32(3) - ] - let anyHashableBytes = try makeBudgetFory().serialize(anyHashableMap) - let anyHashableWrapped = mapBudget( - key: AnyHashable.self, - value: SerializableAny.self, - count: anyHashableMap.count - ) - let anyHashableFinal = anyHashableMap.count * 2 * testReferenceBytes - expectInvalidData { - let _: [AnyHashable: Any] = try makeBudgetFory( - maxContainerMemoryBytes: Int64(anyHashableWrapped) - ).deserialize(anyHashableBytes, as: [AnyHashable: Any].self) - } - let decodedAnyHashable = try makeBudgetFory( - maxContainerMemoryBytes: Int64(anyHashableWrapped + anyHashableFinal) - ).deserialize(anyHashableBytes, as: [AnyHashable: Any].self) - #expect(decodedAnyHashable.count == anyHashableMap.count) -} - -@Test -func dynamicAnyArrayBudget() throws { - let list: [Any] = [Int32(1), "two", Int32(3)] - let value: Any = list - let bytes = try makeBudgetFory().serialize(value) - let count = list.count - let wrappedBudget = arrayBudget(SerializableAny.self, count: count) - let finalBudget = count * testReferenceBytes - - expectInvalidData { - let _: Any = try makeBudgetFory(maxContainerMemoryBytes: Int64(wrappedBudget)) - .deserialize(bytes, as: Any.self) - } - let decoded = try makeBudgetFory(maxContainerMemoryBytes: Int64(wrappedBudget + finalBudget)) - .deserialize(bytes, as: Any.self) - #expect((decoded as? [Any])?.count == count) -} - -@Test -func byteAvailabilityCheckStillRejectsLargeLength() throws { - let buffer = ByteBuffer() - buffer.writeVarUInt32(64) - buffer.writeUInt8(CollectionHeader.sameType | CollectionHeader.declaredElementType) - let config = Config(trackRef: false, compatible: false) - let context = ReadContext( - buffer: buffer, - typeResolver: TypeResolver(config: config), - config: config - ) - - expectInvalidData { - let _: [String] = try [String].foryReadData(context) - } -} diff --git a/swift/Tests/ForyTests/ForySwiftTests.swift b/swift/Tests/ForyTests/ForySwiftTests.swift index c81b70b30f..303f0dfc6c 100644 --- a/swift/Tests/ForyTests/ForySwiftTests.swift +++ b/swift/Tests/ForyTests/ForySwiftTests.swift @@ -326,7 +326,7 @@ func floatingSpecialsRoundTrip() throws { -.infinity, .leastNonzeroMagnitude, .greatestFiniteMagnitude, - Float(bitPattern: 0x7FC0_1234) + Float(bitPattern: 0x7FC0_1234), ] for value in floatValues { let decoded: Float = try fory.deserialize(try fory.serialize(value)) @@ -340,7 +340,7 @@ func floatingSpecialsRoundTrip() throws { -.infinity, .leastNonzeroMagnitude, .greatestFiniteMagnitude, - Double(bitPattern: 0x7FF8_0000_0000_1234) + Double(bitPattern: 0x7FF8_0000_0000_1234), ] for value in doubleValues { let decoded: Double = try fory.deserialize(try fory.serialize(value)) @@ -354,7 +354,7 @@ func floatingSpecialsRoundTrip() throws { .init(bitPattern: 0xFC00), .init(bitPattern: 0x0001), .init(bitPattern: 0x7BFF), - .init(bitPattern: 0x7E11) + .init(bitPattern: 0x7E11), ] for value in float16Values { let decoded: Float16 = try fory.deserialize(try fory.serialize(value)) @@ -367,7 +367,7 @@ func floatingSpecialsRoundTrip() throws { .init(rawValue: 0x7F80), .init(rawValue: 0xFF80), .init(rawValue: 0x0001), - .init(rawValue: 0x7FC1) + .init(rawValue: 0x7FC1), ] for value in bfloat16Values { let decoded: BFloat16 = try fory.deserialize(try fory.serialize(value)) @@ -382,7 +382,7 @@ func namedInitializerBuildsConfig() { #expect(defaultConfig.config.compatible == true) #expect(defaultConfig.config.checkClassVersion == false) #expect(defaultConfig.config.maxDepth == 5) - #expect(defaultConfig.config.maxContainerMemoryBytes == -1) + #expect(defaultConfig.config.maxGraphMemoryBytes == -1) #expect(defaultConfig.config.maxTypeFields == 512) #expect(defaultConfig.config.maxTypeMetaBytes == 4096) #expect(defaultConfig.config.maxSchemaVersionsPerType == 10) @@ -392,7 +392,7 @@ func namedInitializerBuildsConfig() { ref: true, compatible: true, maxDepth: 7, - maxContainerMemoryBytes: 65_536, + maxGraphMemoryBytes: 65_536, maxTypeFields: 31, maxTypeMetaBytes: 1234, maxSchemaVersionsPerType: 12, @@ -402,7 +402,7 @@ func namedInitializerBuildsConfig() { #expect(explicitConfig.config.compatible == true) #expect(explicitConfig.config.checkClassVersion == false) #expect(explicitConfig.config.maxDepth == 7) - #expect(explicitConfig.config.maxContainerMemoryBytes == 65_536) + #expect(explicitConfig.config.maxGraphMemoryBytes == 65_536) #expect(explicitConfig.config.maxTypeFields == 31) #expect(explicitConfig.config.maxTypeMetaBytes == 1234) #expect(explicitConfig.config.maxSchemaVersionsPerType == 12) @@ -413,7 +413,7 @@ func namedInitializerBuildsConfig() { trackRef: false, compatible: true, maxDepth: 9, - maxContainerMemoryBytes: 131_072, + maxGraphMemoryBytes: 131_072, maxTypeFields: 41, maxTypeMetaBytes: 2048, maxSchemaVersionsPerType: 14, @@ -423,7 +423,7 @@ func namedInitializerBuildsConfig() { #expect(configInit.config.compatible == true) #expect(configInit.config.checkClassVersion == false) #expect(configInit.config.maxDepth == 9) - #expect(configInit.config.maxContainerMemoryBytes == 131_072) + #expect(configInit.config.maxGraphMemoryBytes == 131_072) #expect(configInit.config.maxTypeFields == 41) #expect(configInit.config.maxTypeMetaBytes == 2048) #expect(configInit.config.maxSchemaVersionsPerType == 14) @@ -544,7 +544,7 @@ func typeMetaFieldLimitRejectsLargeStruct() throws { registerByName: false, fields: [ TypeMeta.FieldInfo(fieldID: nil, fieldName: "first", fieldType: fieldType), - TypeMeta.FieldInfo(fieldID: nil, fieldName: "second", fieldType: fieldType) + TypeMeta.FieldInfo(fieldID: nil, fieldName: "second", fieldType: fieldType), ] ) let encoded = try meta.encode() @@ -772,22 +772,24 @@ func failedSchemaDoesNotConsumeLimit() throws { } #expect(throws: (any Error).self) { - try cache(remoteTypeMeta( - fieldName: "id", - fieldType: TypeMeta.FieldType( - typeID: TypeId.map.rawValue, - nullable: false, - generics: [ - TypeMeta.FieldType(typeID: TypeId.string.rawValue, nullable: false), - TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) - ] - ) - )) + try cache( + remoteTypeMeta( + fieldName: "id", + fieldType: TypeMeta.FieldType( + typeID: TypeId.map.rawValue, + nullable: false, + generics: [ + TypeMeta.FieldType(typeID: TypeId.string.rawValue, nullable: false), + TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false), + ] + ) + )) } - try cache(remoteTypeMeta( - fieldName: "remoteA", - fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) - )) + try cache( + remoteTypeMeta( + fieldName: "remoteA", + fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + )) } @Test @@ -1148,7 +1150,7 @@ func macroDynamicAnyObjectAndAnySerializerFieldsRoundTrip() throws { items: [Int32(11), Address(street: "Nested", zip: 10002)], map: [ "age": Int64(19), - "address": Address(street: "Mapped", zip: 10003) + "address": Address(street: "Mapped", zip: 10003), ] ) let serializerData = try fory.serialize(serializerHolder) @@ -1201,13 +1203,13 @@ func macroAnyFieldsRoundTrip() throws { "count": Int64(3), "name": "map", "address": Address(street: "AnyMap", zip: 11003), - "empty": NSNull() + "empty": NSNull(), ], int32Map: [ 1: Int32(-9), 2: "v2", 3: Address(street: "AnyIntMap", zip: 11004), - 4: NSNull() + 4: NSNull(), ] ) let data = try fory.serialize(value) @@ -1288,7 +1290,8 @@ func macroFieldOrderFollowsForyRules() throws { let second = try buffer.readVarInt64() let third = try buffer.readVarInt32() - let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) + let tailContext = ReadContext( + buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) let fourth = try String.foryReadData(tailContext) #expect(first == value.shortValue) @@ -1316,7 +1319,8 @@ func macroTaggedFieldsKeepGroupedPayloadOrder() throws { _ = try buffer.readInt32() #expect(try buffer.readVarInt32() == value.intValue) - let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) + let tailContext = ReadContext( + buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) #expect(try String.foryReadData(tailContext) == value.textTail) } @@ -1326,7 +1330,7 @@ func macroNonPrimitiveFieldsSortByFieldIdentifier() throws { #expect( fields.map(\.fieldName) == [ - "intValue", "mapValue", "stringValue", "addressValue", "binaryValue" + "intValue", "mapValue", "stringValue", "addressValue", "binaryValue", ]) #expect(fields.map(\.fieldID) == [nil, 10, 20, nil, nil]) } @@ -1393,7 +1397,7 @@ func macroReducedPrecisionFieldsUseXlangTypeIDs() { TypeId.float16.rawValue, TypeId.bfloat16.rawValue, TypeId.bfloat16Array.rawValue, - TypeId.float16Array.rawValue + TypeId.float16Array.rawValue, ]) } @@ -1468,7 +1472,7 @@ func compatibleNestedStructArrayRoundTrip() throws { let value = CompatibleNestedArrayHolder( items: [ CompatibleNestedItem(id: 1, name: "alpha"), - CompatibleNestedItem(id: 2, name: "beta") + CompatibleNestedItem(id: 2, name: "beta"), ] ) let bytes = try writer.serialize(value) @@ -1490,7 +1494,7 @@ func compatibleNestedStructOptionalArrayRoundTrip() throws { items: [ CompatibleNestedItem(id: 1, name: "alpha"), nil, - CompatibleNestedItem(id: 2, name: "beta") + CompatibleNestedItem(id: 2, name: "beta"), ] ) let bytes = try writer.serialize(value) @@ -1511,7 +1515,7 @@ func compatibleNestedStructMapRoundTrip() throws { let value = CompatibleNestedMapHolder( items: [ 1: CompatibleNestedItem(id: 10, name: "first"), - 2: CompatibleNestedItem(id: 20, name: "second") + 2: CompatibleNestedItem(id: 20, name: "second"), ] ) let bytes = try writer.serialize(value) @@ -1541,7 +1545,7 @@ func pvlVarInt64AndVarUInt64Extremes() throws { 72_057_594_037_927_935, 72_057_594_037_927_936, UInt64(Int64.max), - UInt64.max + UInt64.max, ] let intValues: [Int64] = [ Int64.min, @@ -1558,7 +1562,7 @@ func pvlVarInt64AndVarUInt64Extremes() throws { 1_000_000, 1_000_000_000_000, Int64.max - 1, - Int64.max + Int64.max, ] let writeBuffer = ByteBuffer() @@ -1644,7 +1648,7 @@ func typeMetaRoundTripByName() throws { nullable: true, generics: [ .init(typeID: TypeId.string.rawValue, nullable: false), - .init(typeID: TypeId.varint32.rawValue, nullable: true) + .init(typeID: TypeId.varint32.rawValue, nullable: true), ] ) ), @@ -1652,7 +1656,7 @@ func typeMetaRoundTripByName() throws { fieldID: 7, fieldName: "ignored_for_tag_mode", fieldType: .init(typeID: TypeId.varint32.rawValue, nullable: false) - ) + ), ] let meta = try TypeMeta( diff --git a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift new file mode 100644 index 0000000000..35181745ff --- /dev/null +++ b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift @@ -0,0 +1,359 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import Foundation +import Testing + +@testable import Fory + +@ForyStruct +private final class BudgetNode { + var id: Int32 = 0 + + required init() {} + + init(id: Int32) { + self.id = id + } +} + +@ForyStruct +private struct BudgetSiblings { + var left: [BudgetNode] = [] + var right: [BudgetNode] = [] +} + +@ForyStruct +private struct BudgetDenseHolder: Equatable { + var text: String = "" + var data: Data = Data() + @ArrayField(element: .int32()) + var dense: [Int32] = [] +} + +private func makeBudgetFory(maxGraphMemoryBytes: Int64 = -1) -> Fory { + let fory = Fory( + config: .init( + trackRef: false, + compatible: false, + maxGraphMemoryBytes: maxGraphMemoryBytes + )) + fory.register(BudgetNode.self, id: 9801) + fory.register(BudgetSiblings.self, id: 9802) + fory.register(BudgetDenseHolder.self, id: 9803) + return fory +} + +private let testReferenceBytes = 4 +private let budgetNodeGraphBytes = 1 + 4 + +private func elementBytes(_ type: Element.Type) -> Int { + type.isRefType ? testReferenceBytes : max(1, MemoryLayout.stride) +} + +private func ownerBytes(_ type: T.Type) -> Int { + max(1, MemoryLayout.stride) +} + +private func arrayBudget(_ type: Element.Type, count: Int) -> Int { + count * elementBytes(type) +} + +private func rootArrayBudget( + _ type: Element.Type, + count: Int, + elementOwnerBytes: Int = 0 +) -> Int { + ownerBytes([Element].self) + arrayBudget(type, count: count) + count * elementOwnerBytes +} + +private func mapBudget( + key: Key.Type, + value: Value.Type, + count: Int +) -> Int { + count * (elementBytes(key) + elementBytes(value)) +} + +private func rootMapBudget( + key: Key.Type, + value: Value.Type, + count: Int +) -> Int { + ownerBytes(Dictionary.self) + mapBudget(key: key, value: value, count: count) +} + +private func expectInvalidData(_ body: () throws -> Void) { + do { + try body() + Issue.record("expected invalid data") + } catch ForyError.invalidData { + } catch { + Issue.record("expected invalid data, got \(error)") + } +} + +@Test +func knownLengthAutoBudgetUsesInputBytes() throws { + let expected = 17 * 8 + ReadContext.knownGraphBudgetSlackBytes + let config = Config(trackRef: false, compatible: false) + let context = ReadContext( + buffer: ByteBuffer(), + typeResolver: TypeResolver(config: config), + config: config + ) + + try context.initGraphMemoryBudgetKnown(rootBytes: 17) + try context.reserveGraphMemory(expected) + expectInvalidData { + try context.reserveGraphMemory(testReferenceBytes) + } +} + +@Test +func byteBufferRootUsesKnownLengthAutoBudget() throws { + let count = 6 + let value = Array(repeating: [String](), count: count) + let bytes = try makeBudgetFory().serialize(value) + let buffer = ByteBuffer(data: bytes) + + let decoded: [[String]] = try makeBudgetFory().deserialize(from: buffer) + #expect(decoded.count == count) +} + +@Test +func explicitConfigOverridesAutoBudget() throws { + let values = (0..<16).map { "value-\($0)" } + let bytes = try makeBudgetFory().serialize(values) + let required = rootArrayBudget(String.self, count: values.count) + + expectInvalidData { + let _: [String] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)).deserialize( + bytes) + } + let decoded: [String] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)).deserialize( + bytes) + #expect(decoded == values) +} + +@Test +func siblingContainersShareOneBudget() throws { + let value = BudgetSiblings( + left: (0..<16).map { BudgetNode(id: Int32($0)) }, + right: (16..<32).map { BudgetNode(id: Int32($0)) } + ) + let bytes = try makeBudgetFory().serialize(value) + let oneList = arrayBudget(BudgetNode.self, count: 16) + 16 * budgetNodeGraphBytes + let required = ownerBytes(BudgetSiblings.self) + oneList * 2 + + expectInvalidData { + let _: BudgetSiblings = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: BudgetSiblings = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded.left.count == 16) + #expect(decoded.right.count == 16) +} + +@Test +func mapBudgetIsCharged() throws { + let value: [String: Int32] = ["a": 1, "b": 2, "c": 3] + let bytes = try makeBudgetFory().serialize(value) + let required = rootMapBudget(key: String.self, value: Int32.self, count: value.count) + + expectInvalidData { + let _: [String: Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: [String: Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded == value) +} + +@Test +func referenceAndInlineValueArraysAreCharged() throws { + let nodes = (0..<4).map { BudgetNode(id: Int32($0)) } + let nodeBytes = try makeBudgetFory().serialize(nodes) + let nodeBudget = rootArrayBudget( + BudgetNode.self, + count: nodes.count, + elementOwnerBytes: budgetNodeGraphBytes + ) + expectInvalidData { + let _: [BudgetNode] = try makeBudgetFory(maxGraphMemoryBytes: Int64(nodeBudget - 1)) + .deserialize(nodeBytes) + } + let decodedNodes: [BudgetNode] = try makeBudgetFory(maxGraphMemoryBytes: Int64(nodeBudget)) + .deserialize(nodeBytes) + #expect(decodedNodes.count == nodes.count) + + let ints: [Int32] = [1, 2, 3, 4] + let intBytes = try makeBudgetFory().serialize(ints) + let intBudget = rootArrayBudget(Int32.self, count: ints.count) + expectInvalidData { + let _: [Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(intBudget - 1)) + .deserialize(intBytes) + } + #expect(try makeBudgetFory(maxGraphMemoryBytes: Int64(intBudget)).deserialize(intBytes) == ints) +} + +@Test +func stringBinaryAndPrimitiveDenseArrayOwnersAreSkipped() throws { + let value = BudgetDenseHolder( + text: "budget", + data: Data([1, 2, 3]), + dense: [1, 2, 3] + ) + let bytes = try makeBudgetFory().serialize(value) + let required = ownerBytes(BudgetDenseHolder.self) + + expectInvalidData { + let _: BudgetDenseHolder = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: BudgetDenseHolder = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded == value) +} + +@Test +func dynamicAnyEmptyMapOwnerSelf() throws { + let value = [:] as [AnyHashable: Any] + let bytes = try makeBudgetFory().serialize(value as Any) + let required = + ownerBytes(Dictionary.self) + + ownerBytes(Dictionary.self) + + expectInvalidData { + let _: Any = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: Any = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect((decoded as? [String: Any])?.isEmpty == true) +} + +@Test +func publicAnyArrayBudget() throws { + let value: [Any] = [Int32(1), Int32(2), Int32(3)] + let bytes = try makeBudgetFory().serialize(value) + let wrappedBudget = arrayBudget(SerializableAny.self, count: value.count) + let finalBudget = ownerBytes([Any].self) + value.count * testReferenceBytes + + expectInvalidData { + let _: [Any] = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget)) + .deserialize(bytes, as: [Any].self) + } + let decoded = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget + finalBudget)) + .deserialize(bytes, as: [Any].self) + #expect(decoded.count == value.count) +} + +@Test +func publicAnyMapBudget() throws { + let stringMap: [String: Any] = ["a": Int32(1), "b": Int32(2), "c": Int32(3)] + let stringBytes = try makeBudgetFory().serialize(stringMap) + let stringWrapped = mapBudget( + key: String.self, + value: SerializableAny.self, + count: stringMap.count + ) + let stringFinal = + ownerBytes(Dictionary.self) + stringMap.count * 2 * testReferenceBytes + expectInvalidData { + let _: [String: Any] = try makeBudgetFory(maxGraphMemoryBytes: Int64(stringWrapped)) + .deserialize(stringBytes, as: [String: Any].self) + } + let decodedString = try makeBudgetFory(maxGraphMemoryBytes: Int64(stringWrapped + stringFinal)) + .deserialize(stringBytes, as: [String: Any].self) + #expect(decodedString.count == stringMap.count) + + let intMap: [Int32: Any] = [1: Int32(10), 2: Int32(20), 3: Int32(30)] + let intBytes = try makeBudgetFory().serialize(intMap) + let intWrapped = mapBudget( + key: Int32.self, + value: SerializableAny.self, + count: intMap.count + ) + let intFinal = ownerBytes(Dictionary.self) + intMap.count * 2 * testReferenceBytes + expectInvalidData { + let _: [Int32: Any] = try makeBudgetFory(maxGraphMemoryBytes: Int64(intWrapped)) + .deserialize(intBytes, as: [Int32: Any].self) + } + let decodedInt = try makeBudgetFory(maxGraphMemoryBytes: Int64(intWrapped + intFinal)) + .deserialize(intBytes, as: [Int32: Any].self) + #expect(decodedInt.count == intMap.count) + + let anyHashableMap: [AnyHashable: Any] = [ + AnyHashable("a"): Int32(1), + AnyHashable(Int32(2)): Int32(2), + AnyHashable(true): Int32(3), + ] + let anyHashableBytes = try makeBudgetFory().serialize(anyHashableMap) + let anyHashableWrapped = mapBudget( + key: AnyHashable.self, + value: SerializableAny.self, + count: anyHashableMap.count + ) + let anyHashableFinal = + ownerBytes(Dictionary.self) + anyHashableMap.count * 2 * testReferenceBytes + expectInvalidData { + let _: [AnyHashable: Any] = try makeBudgetFory( + maxGraphMemoryBytes: Int64(anyHashableWrapped) + ).deserialize(anyHashableBytes, as: [AnyHashable: Any].self) + } + let decodedAnyHashable = try makeBudgetFory( + maxGraphMemoryBytes: Int64(anyHashableWrapped + anyHashableFinal) + ).deserialize(anyHashableBytes, as: [AnyHashable: Any].self) + #expect(decodedAnyHashable.count == anyHashableMap.count) +} + +@Test +func dynamicAnyArrayBudget() throws { + let list: [Any] = [Int32(1), "two", Int32(3)] + let value: Any = list + let bytes = try makeBudgetFory().serialize(value) + let count = list.count + let wrappedBudget = arrayBudget(SerializableAny.self, count: count) + let finalBudget = ownerBytes([Any].self) + count * testReferenceBytes + + expectInvalidData { + let _: Any = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget)) + .deserialize(bytes, as: Any.self) + } + let decoded = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget + finalBudget)) + .deserialize(bytes, as: Any.self) + #expect((decoded as? [Any])?.count == count) +} + +@Test +func byteAvailabilityCheckStillRejectsLargeLength() throws { + let buffer = ByteBuffer() + buffer.writeVarUInt32(64) + buffer.writeUInt8(CollectionHeader.sameType | CollectionHeader.declaredElementType) + let config = Config(trackRef: false, compatible: false) + let context = ReadContext( + buffer: buffer, + typeResolver: TypeResolver(config: config), + config: config + ) + + expectInvalidData { + let _: [String] = try [String].foryReadData(context) + } +} From 8c5beb902bbdfb39383b50a680d7bc9f7897afba Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Wed, 1 Jul 2026 01:26:16 +0800 Subject: [PATCH 15/54] perf: trim graph memory budget hot paths --- cpp/fory/serialization/context.h | 58 ++++++++ cpp/fory/serialization/fory.h | 137 ++++++++++++++++-- cpp/fory/serialization/serializer_traits.h | 88 ++++++++++-- go/fory/fory.go | 160 +++++++++++++++------ go/fory/reader.go | 38 +++++ go/fory/stream.go | 23 +-- 6 files changed, 429 insertions(+), 75 deletions(-) diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index 12fc33ccd8..ee4a6e9dbc 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -525,6 +525,35 @@ class ReadContext { return true; } + FORY_ALWAYS_INLINE bool init_graph_budget_known(size_t root_bytes, + size_t reserve_bytes) { + const int64_t configured = config_->max_graph_memory_bytes; + if (FORY_PREDICT_FALSE(configured > 0)) { + if constexpr (sizeof(size_t) < sizeof(uint64_t)) { + if (FORY_PREDICT_FALSE( + static_cast(configured) > + static_cast(std::numeric_limits::max()))) { + return set_graph_memory_error( + "max_graph_memory_bytes does not fit size_t"); + } + } + return init_graph_budget_limit(static_cast(configured), + reserve_bytes); + } + if constexpr (sizeof(size_t) <= sizeof(uint32_t)) { + constexpr size_t max_root_bytes = + (std::numeric_limits::max() - kKnownGraphBudgetSlackBytes) / + kKnownGraphBudgetMultiplier; + if (FORY_PREDICT_FALSE(root_bytes > max_root_bytes)) { + return set_graph_memory_error( + "root input size overflows automatic graph memory budget"); + } + } + return init_graph_budget_limit(root_bytes * kKnownGraphBudgetMultiplier + + kKnownGraphBudgetSlackBytes, + reserve_bytes); + } + FORY_ALWAYS_INLINE bool init_graph_budget_unknown() { const int64_t configured = config_->max_graph_memory_bytes; if (FORY_PREDICT_FALSE(configured > 0)) { @@ -535,6 +564,23 @@ class ReadContext { return true; } + FORY_ALWAYS_INLINE bool init_graph_budget_unknown(size_t reserve_bytes) { + const int64_t configured = config_->max_graph_memory_bytes; + if (FORY_PREDICT_FALSE(configured > 0)) { + if constexpr (sizeof(size_t) < sizeof(uint64_t)) { + if (FORY_PREDICT_FALSE( + static_cast(configured) > + static_cast(std::numeric_limits::max()))) { + return set_graph_memory_error( + "max_graph_memory_bytes does not fit size_t"); + } + } + return init_graph_budget_limit(static_cast(configured), + reserve_bytes); + } + return init_graph_budget_limit(kUnknownGraphBudgetBytes, reserve_bytes); + } + FORY_ALWAYS_INLINE void defer_graph_budget_known(size_t root_bytes) { pending_graph_root_bytes_ = root_bytes; graph_budget_state_ = kGraphBudgetPendingKnown; @@ -558,6 +604,16 @@ class ReadContext { return true; } + FORY_ALWAYS_INLINE bool init_graph_budget_limit(size_t limit, + size_t reserve_bytes) { + graph_budget_state_ = kGraphBudgetReady; + if (FORY_PREDICT_FALSE(reserve_bytes > limit)) { + return set_graph_memory_exceeded(reserve_bytes, limit); + } + remaining_graph_memory_bytes_ = limit - reserve_bytes; + return true; + } + template FORY_ALWAYS_INLINE bool reserve_counted_graph_memory(uint32_t length) { constexpr size_t kMaxLength = @@ -773,6 +829,8 @@ class ReadContext { meta::MetaStringTable meta_string_table_; fory::flat_hash_map remote_schema_versions_by_type_; size_t total_accepted_schema_versions_ = 0; + + friend class Fory; }; /// Implementation of DynDepthGuard destructor diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index bdd86d25b4..084df9d4d0 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -40,6 +40,7 @@ #include "fory/util/result.h" #include "fory/util/stream.h" #include +#include #include #include #include @@ -513,7 +514,19 @@ class BaseFory { /// Protected constructor - only derived classes can instantiate. explicit BaseFory(const Config &config, std::shared_ptr resolver) - : config_(config), type_resolver_(std::move(resolver)) {} + : config_(config), type_resolver_(std::move(resolver)) { + auto_graph_memory_budget_ = config_.max_graph_memory_bytes <= 0; + explicit_graph_memory_bytes_ = std::numeric_limits::max(); + if (config_.max_graph_memory_bytes > 0) { + if constexpr (sizeof(size_t) < sizeof(uint64_t)) { + graph_budget_limit_fits_size_t_ = + static_cast(config_.max_graph_memory_bytes) <= + static_cast(std::numeric_limits::max()); + } + explicit_graph_memory_bytes_ = + static_cast(config_.max_graph_memory_bytes); + } + } // Non-copyable BaseFory(const BaseFory &) = delete; @@ -525,6 +538,9 @@ class BaseFory { Config config_; std::shared_ptr type_resolver_; + size_t explicit_graph_memory_bytes_ = std::numeric_limits::max(); + bool auto_graph_memory_budget_ = true; + bool graph_budget_limit_fits_size_t_ = true; mutable std::mutex registration_mutex_; mutable bool registration_locked_{false}; }; @@ -824,6 +840,92 @@ class Fory : public BaseFory { ", local xlang=" + std::string(config_.xlang ? "true" : "false")); } + FORY_NOINLINE Error root_graph_config_too_large() const { + return Error::invalid_data("max_graph_memory_bytes does not fit size_t"); + } + + FORY_NOINLINE Error root_graph_budget_overflow() const { + return Error::invalid_data( + "root input size overflows automatic graph memory budget"); + } + + FORY_NOINLINE Error root_graph_budget_exceeded(size_t bytes, + size_t limit) const { + return Error::invalid_data( + "estimated graph memory request " + std::to_string(bytes) + + " bytes exceeds max_graph_memory_bytes remaining budget " + + std::to_string(limit) + " bytes"); + } + + template + FORY_ALWAYS_INLINE bool reserve_root_graph_self(size_t root_bytes) { + if constexpr (unknown_root) { + if constexpr (root_owner_bytes <= ReadContext::kUnknownGraphBudgetBytes) { + if (FORY_PREDICT_TRUE(auto_graph_memory_budget_)) { + return true; + } + } + } else if constexpr (root_owner_bytes <= + ReadContext::kKnownGraphBudgetSlackBytes) { + if (FORY_PREDICT_TRUE(auto_graph_memory_budget_)) { + return true; + } + } + if constexpr (sizeof(size_t) < sizeof(uint64_t)) { + if (FORY_PREDICT_FALSE(!graph_budget_limit_fits_size_t_)) { + read_ctx_->set_error(root_graph_config_too_large()); + return false; + } + } + if (FORY_PREDICT_FALSE(root_owner_bytes > explicit_graph_memory_bytes_)) { + if constexpr (sizeof(size_t) < sizeof(uint64_t)) { + if (FORY_PREDICT_FALSE(!graph_budget_limit_fits_size_t_)) { + read_ctx_->set_error(root_graph_config_too_large()); + return false; + } + } + if (FORY_PREDICT_FALSE(config_.max_graph_memory_bytes > 0)) { + read_ctx_->set_error(root_graph_budget_exceeded( + root_owner_bytes, explicit_graph_memory_bytes_)); + return false; + } + } + if constexpr (unknown_root) { + if constexpr (root_owner_bytes > ReadContext::kUnknownGraphBudgetBytes) { + if (FORY_PREDICT_FALSE(config_.max_graph_memory_bytes > 0)) { + return true; + } + read_ctx_->set_error(root_graph_budget_exceeded( + root_owner_bytes, ReadContext::kUnknownGraphBudgetBytes)); + return false; + } + } else if constexpr (root_owner_bytes > + ReadContext::kKnownGraphBudgetSlackBytes) { + if (FORY_PREDICT_FALSE(config_.max_graph_memory_bytes > 0)) { + return true; + } + if constexpr (sizeof(size_t) <= sizeof(uint32_t)) { + constexpr size_t max_root_bytes = + (std::numeric_limits::max() - + ReadContext::kKnownGraphBudgetSlackBytes) / + ReadContext::kKnownGraphBudgetMultiplier; + if (FORY_PREDICT_FALSE(root_bytes > max_root_bytes)) { + read_ctx_->set_error(root_graph_budget_overflow()); + return false; + } + } + const size_t limit = + root_bytes * ReadContext::kKnownGraphBudgetMultiplier + + ReadContext::kKnownGraphBudgetSlackBytes; + if (FORY_PREDICT_FALSE(root_owner_bytes > limit)) { + read_ctx_->set_error( + root_graph_budget_exceeded(root_owner_bytes, limit)); + return false; + } + } + return true; + } + /// Core serialization implementation. /// TypeMeta is written inline using streaming protocol (no deferred writing). template @@ -890,16 +992,33 @@ class Fory : public BaseFory { read_ctx_->attach(buffer); if constexpr (needs_graph_budget_v) { - if constexpr (unknown_root) { - read_ctx_->defer_graph_budget_unknown(); - } else { - read_ctx_->defer_graph_budget_known(root_bytes); - } constexpr size_t root_owner_bytes = graph_value_owner_self_bytes(); + constexpr bool has_child_budget = has_graph_budget_children_v; if constexpr (root_owner_bytes != 0) { - if (FORY_PREDICT_FALSE( - !read_ctx_->reserve_graph_memory(root_owner_bytes))) { - return Unexpected(read_ctx_->take_error()); + if constexpr (has_child_budget) { + if constexpr (unknown_root) { + if (FORY_PREDICT_FALSE( + !read_ctx_->init_graph_budget_unknown(root_owner_bytes))) { + return Unexpected(read_ctx_->take_error()); + } + } else { + if (FORY_PREDICT_FALSE(!read_ctx_->init_graph_budget_known( + root_bytes, root_owner_bytes))) { + return Unexpected(read_ctx_->take_error()); + } + } + } else { + if (FORY_PREDICT_FALSE( + (!reserve_root_graph_self( + root_bytes)))) { + return Unexpected(read_ctx_->take_error()); + } + } + } else if constexpr (has_child_budget) { + if constexpr (unknown_root) { + read_ctx_->defer_graph_budget_unknown(); + } else { + read_ctx_->defer_graph_budget_known(root_bytes); } } } diff --git a/cpp/fory/serialization/serializer_traits.h b/cpp/fory/serialization/serializer_traits.h index 59aa361d25..2ee4ae5723 100644 --- a/cpp/fory/serialization/serializer_traits.h +++ b/cpp/fory/serialization/serializer_traits.h @@ -298,15 +298,6 @@ struct needs_graph_budget, void> std::remove_reference_t>>::value || ...)> {}; -template -constexpr bool struct_needs_graph_budget_impl(std::index_sequence) { - return ( - needs_graph_budget< - std::remove_cv_t>>>>::value || - ...); -} - template struct needs_graph_budget>> : std::true_type {}; @@ -337,6 +328,85 @@ template constexpr size_t graph_value_owner_self_bytes() { } } +template +struct has_graph_budget_children : std::false_type {}; + +template +struct has_graph_budget_children, void> + : std::bool_constant>> { +}; + +template +struct has_graph_budget_children, void> + : std::true_type {}; + +template +struct has_graph_budget_children< + T, std::enable_if_t || is_deque_v || is_forward_list_v || + is_set_like_v || is_map_like_v>> + : std::true_type {}; + +template +struct has_graph_budget_children, void> + : has_graph_budget_children>> { +}; + +template +struct has_graph_budget_children, void> + : has_graph_budget_children>> { +}; + +template +struct has_graph_budget_children, void> + : std::bool_constant<(graph_value_owner_self_bytes() != 0) || + has_graph_budget_children>>::value> {}; + +template +struct has_graph_budget_children, void> + : std::bool_constant<(graph_value_owner_self_bytes() != 0) || + has_graph_budget_children>>::value> {}; + +template +struct has_graph_budget_children, void> + : std::bool_constant<(has_graph_budget_children>>::value || + ...)> {}; + +template +struct has_graph_budget_children, void> + : std::bool_constant<(has_graph_budget_children>>::value || + ...)> {}; + +template +constexpr bool struct_has_graph_children_impl(std::index_sequence) { + return ( + has_graph_budget_children< + std::remove_cv_t>>>>::value || + ...); +} + +template +struct has_graph_budget_children>> { +private: + using Value = std::remove_cv_t>; + using FieldInfo = + decltype(::fory::meta::fory_field_info(std::declval())); + using Ptrs = typename FieldInfo::PtrsType; + +public: + static constexpr bool value = struct_has_graph_children_impl( + std::make_index_sequence>{}); +}; + +template +inline constexpr bool has_graph_budget_children_v = has_graph_budget_children< + std::remove_cv_t>>::value; + template FORY_ALWAYS_INLINE bool reserve_allocated_value_owner(Context &ctx) { constexpr size_t bytes = graph_value_owner_self_bytes(); diff --git a/go/fory/fory.go b/go/fory/fory.go index 26811dcfaf..5b4d168878 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -196,6 +196,11 @@ type Fory struct { // Resolvers shared between contexts typeResolver *TypeResolver refResolver *RefResolver + + rootGraphType reflect.Type + rootGraphBytes int64 + rootGraphHasChildren bool + rootGraphSkipType reflect.Type } // New creates a new Fory instance with the given options @@ -570,9 +575,9 @@ func (f *Fory) Serialize(value any) ([]byte, error) { func (f *Fory) Deserialize(data []byte, v any) error { defer f.resetReadState() f.readCtx.SetData(data) - f.readCtx.initGraphMemoryBudget(len(data), false) - if f.readCtx.HasError() { - return f.readCtx.TakeError() + target := reflect.ValueOf(v).Elem() + if err := f.initRootGraphBudget(target, len(data), false); err != nil { + return err } readHeader(f.readCtx) @@ -581,11 +586,7 @@ func (f *Fory) Deserialize(data []byte, v any) error { } // Deserialize the value - TypeMeta is read inline using streaming protocol - target := reflect.ValueOf(v).Elem() - if err := f.reserveRootGraphOwner(target); err != nil { - return err - } - f.readCtx.ReadValue(target, RefModeTracking, true) + f.readRootValue(target) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -669,10 +670,10 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { // Temporarily swap buffer origBuffer := f.readCtx.buffer f.readCtx.buffer = buf - f.readCtx.initGraphMemoryBudget(buf.readableBytes(), false) - if f.readCtx.HasError() { + target := reflect.ValueOf(v).Elem() + if err := f.initRootGraphBudget(target, buf.readableBytes(), false); err != nil { f.readCtx.buffer = origBuffer - return f.readCtx.TakeError() + return err } readHeader(f.readCtx) @@ -682,12 +683,7 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { } // Deserialize the value - TypeMeta is read inline using streaming protocol - target := reflect.ValueOf(v).Elem() - if err := f.reserveRootGraphOwner(target); err != nil { - f.readCtx.buffer = origBuffer - return err - } - f.readCtx.ReadValue(target, RefModeTracking, true) + f.readRootValue(target) if f.readCtx.HasError() { f.readCtx.buffer = origBuffer return f.readCtx.TakeError() @@ -773,21 +769,11 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers f.readCtx.buffer = nil f.readCtx.outOfBandBuffers = nil }() - f.readCtx.initGraphMemoryBudget(buffer.readableBytes(), false) - if f.readCtx.HasError() { - return f.readCtx.TakeError() - } // Set up out-of-band buffers if provided if buffers != nil { f.readCtx.outOfBandBuffers = buffers } - // ReadData and validate header - readHeader(f.readCtx) - if f.readCtx.HasError() { - return f.readCtx.TakeError() - } - // v must be a pointer so we can deserialize into it if v == nil { return fmt.Errorf("v cannot be nil") @@ -800,12 +786,19 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers return fmt.Errorf("v must be a non-nil pointer") } - // Deserialize the value - TypeMeta is read inline using streaming protocol target := rv.Elem() - if err := f.reserveRootGraphOwner(target); err != nil { + if err := f.initRootGraphBudget(target, buffer.readableBytes(), false); err != nil { return err } - f.readCtx.ReadValue(target, RefModeTracking, true) + + // ReadData and validate header + readHeader(f.readCtx) + if f.readCtx.HasError() { + return f.readCtx.TakeError() + } + + // Deserialize the value - TypeMeta is read inline using streaming protocol + f.readRootValue(target) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -1054,9 +1047,16 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { // Reuse context, reset and set new data f.readCtx.Reset() f.readCtx.SetData(data) - f.readCtx.initGraphMemoryBudget(len(data), false) - if f.readCtx.HasError() { - return f.readCtx.TakeError() + + var targetVal reflect.Value + switch any(target).(type) { + case *bool, *int8, *int16, *int32, *int64, *int, *float32, *float64, *string, + *[]byte, *[]int8, *[]int16, *[]int32, *[]int64, *[]int, *[]float32, *[]float64, *[]bool: + default: + targetVal = reflect.ValueOf(target).Elem() + if err := f.initRootGraphBudget(targetVal, len(data), false); err != nil { + return err + } } // ReadData and validate header @@ -1196,11 +1196,10 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { return f.readCtx.CheckError() default: // Slow path: use serializer-based deserialization - targetVal := reflect.ValueOf(target).Elem() - targetType := targetVal.Type() - if err := f.reserveRootGraphOwner(targetVal); err != nil { - return err + if !targetVal.IsValid() { + targetVal = reflect.ValueOf(target).Elem() } + targetType := targetVal.Type() // Get serializer for the target type serializer, err := f.typeResolver.getSerializerByType(targetType, false) @@ -1214,16 +1213,95 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { } } -func (f *Fory) reserveRootGraphOwner(target reflect.Value) error { - if !target.IsValid() || target.Kind() != reflect.Struct { +func (f *Fory) initRootGraphBudget(target reflect.Value, rootInputBytes int, unknownLengthInput bool) error { + if target.IsValid() && target.Type() == f.rootGraphSkipType { + return nil + } + return f.initRootGraphBudgetSlow(target, rootInputBytes, unknownLengthInput) +} + +func (f *Fory) readRootValue(target reflect.Value) { + if target.IsValid() { + targetType := target.Type() + if targetType.Kind() == reflect.Struct && !f.typeResolver.IsUnionType(targetType) { + f.readCtx.ReadStruct(target) + return + } + } + f.readCtx.ReadValue(target, RefModeTracking, true) +} + +//go:noinline +func (f *Fory) initRootGraphBudgetSlow(target reflect.Value, rootInputBytes int, unknownLengthInput bool) error { + bytes, hasChildren, isStruct := f.rootGraphInfo(target) + if !isStruct { + f.readCtx.initGraphMemoryBudget(rootInputBytes, unknownLengthInput) + if f.readCtx.HasError() { + return f.readCtx.TakeError() + } return nil } + if hasChildren { + f.readCtx.initGraphMemoryBudget(rootInputBytes, unknownLengthInput) + if f.readCtx.HasError() { + return f.readCtx.TakeError() + } + if bytes != 0 && !f.readCtx.ReserveGraphMemory(bytes) { + return f.readCtx.TakeError() + } + return nil + } + if f.config.MaxGraphMemoryBytes <= 0 && bytes <= knownRootBudgetSlackBytes { + f.rootGraphSkipType = target.Type() + return nil + } + return f.checkRootGraphSelf(bytes, rootInputBytes, unknownLengthInput) +} + +func (f *Fory) rootGraphInfo(target reflect.Value) (int64, bool, bool) { + if !target.IsValid() || target.Kind() != reflect.Struct { + return 0, false, false + } targetType := target.Type() if targetType == dateReflectType || targetType == timeReflectType { + return 0, false, true + } + if targetType == f.rootGraphType { + return f.rootGraphBytes, f.rootGraphHasChildren, true + } + bytes := structGraphBytes(targetType) + hasChildren := typeHasGraphChildren(targetType) + f.rootGraphType = targetType + f.rootGraphBytes = bytes + f.rootGraphHasChildren = hasChildren + return bytes, hasChildren, true +} + +func (f *Fory) checkRootGraphSelf(bytes int64, rootInputBytes int, unknownLengthInput bool) error { + if bytes <= 0 { return nil } - if !reserveStructGraph(f.readCtx, targetType) { - return f.readCtx.TakeError() + limit := f.config.MaxGraphMemoryBytes + if limit <= 0 { + if unknownLengthInput { + limit = streamRootBudgetBytes + } else { + if rootInputBytes < 0 { + return DeserializationErrorf("root input size must be non-negative: %d", rootInputBytes) + } + if int64(rootInputBytes) > (MaxInt64-knownRootBudgetSlackBytes)/knownRootBudgetMultiplier { + return DeserializationErrorf("root input size %d overflows automatic graph memory budget", rootInputBytes) + } + if bytes <= knownRootBudgetSlackBytes { + return nil + } + limit = int64(rootInputBytes)*knownRootBudgetMultiplier + knownRootBudgetSlackBytes + } + } + if bytes > limit { + return DeserializationErrorf( + "estimated graph memory request %d bytes exceeds maxGraphMemoryBytes remaining budget %d bytes out of effective limit %d bytes", + bytes, limit, limit) } return nil } diff --git a/go/fory/reader.go b/go/fory/reader.go index 8f7c582bfc..5579659b96 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -90,6 +90,44 @@ func reserveStructGraph(ctx *ReadContext, type_ reflect.Type) bool { return ctx.ReserveGraphMemory(bytes) } +func typeHasGraphChildren(type_ reflect.Type) bool { + for type_.Kind() == reflect.Ptr { + elem := type_.Elem() + if structGraphBytes(elem) != 0 { + return true + } + type_ = elem + } + switch type_.Kind() { + case reflect.Struct: + if type_ == dateReflectType || type_ == timeReflectType { + return false + } + for i := 0; i < type_.NumField(); i++ { + if typeHasGraphChildren(type_.Field(i).Type) { + return true + } + } + return false + case reflect.Slice: + elem := type_.Elem() + switch elem.Kind() { + case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return false + default: + return true + } + case reflect.Array: + return typeHasGraphChildren(type_.Elem()) + case reflect.Map, reflect.Interface: + return true + default: + return false + } +} + // IsXlang returns whether cross-language serialization mode is enabled func (c *ReadContext) IsXlang() bool { return c.xlang diff --git a/go/fory/stream.go b/go/fory/stream.go index 40cb642ae8..a032c7a64f 100644 --- a/go/fory/stream.go +++ b/go/fory/stream.go @@ -96,9 +96,8 @@ func (is *InputStream) Shrink() { func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = is.buffer - f.readCtx.initGraphMemoryBudget(0, true) - if f.readCtx.HasError() { - err := f.readCtx.TakeError() + target := reflect.ValueOf(v).Elem() + if err := f.initRootGraphBudget(target, 0, true); err != nil { f.readCtx.buffer = origBuffer f.resetReadState() return err @@ -113,11 +112,7 @@ func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { return f.readCtx.TakeError() } - target := reflect.ValueOf(v).Elem() - if err := f.reserveRootGraphOwner(target); err != nil { - return err - } - f.readCtx.ReadValue(target, RefModeTracking, true) + f.readRootValue(target) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -133,9 +128,9 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { defer f.resetReadState() // Always reset to enforce stateless semantics. f.readCtx.buffer.ResetWithReader(r, 0) - f.readCtx.initGraphMemoryBudget(0, true) - if f.readCtx.HasError() { - return f.readCtx.TakeError() + target := reflect.ValueOf(v).Elem() + if err := f.initRootGraphBudget(target, 0, true); err != nil { + return err } readHeader(f.readCtx) @@ -143,11 +138,7 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { return f.readCtx.TakeError() } - target := reflect.ValueOf(v).Elem() - if err := f.reserveRootGraphOwner(target); err != nil { - return err - } - f.readCtx.ReadValue(target, RefModeTracking, true) + f.readRootValue(target) if f.readCtx.HasError() { return f.readCtx.TakeError() } From ec743cbb1e531fda0e4dd0351bed949bf0b965ec Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Wed, 1 Jul 2026 02:26:39 +0800 Subject: [PATCH 16/54] feat: use fixed root graph memory budget --- .agents/languages/cpp.md | 8 +- .agents/languages/dart.md | 8 +- .agents/languages/go.md | 9 +- .agents/languages/java.md | 8 +- .agents/languages/javascript.md | 7 +- .agents/languages/python.md | 8 +- .agents/languages/rust.md | 8 +- AGENTS.md | 2 +- cpp/fory/meta/field_info.h | 4 +- .../serialization/collection_serializer.h | 34 +- cpp/fory/serialization/config.h | 6 +- cpp/fory/serialization/context.cc | 48 +-- cpp/fory/serialization/context.h | 108 +----- cpp/fory/serialization/fory.h | 145 +------- .../serialization/graph_memory_budget_test.cc | 61 ++-- cpp/fory/serialization/map_serializer.h | 25 +- csharp/src/Fory/Config.cs | 17 +- csharp/src/Fory/Fory.cs | 6 +- csharp/src/Fory/ReadContext.cs | 44 +-- .../Fory.Tests/GraphMemoryBudgetTests.cs | 24 +- dart/packages/fory/lib/src/config.dart | 8 +- .../fory/lib/src/context/read_context.dart | 11 +- dart/packages/fory/lib/src/fory.dart | 8 - .../fory/test/graph_memory_budget_test.dart | 27 +- docs/guide/cpp/configuration.md | 32 +- docs/guide/csharp/configuration.md | 28 +- docs/guide/dart/configuration.md | 34 +- docs/guide/go/configuration.md | 43 ++- docs/guide/java/configuration.md | 8 +- docs/guide/javascript/configuration.md | 37 +- docs/guide/python/configuration.md | 39 +-- docs/guide/rust/configuration.md | 40 +-- docs/guide/swift/configuration.md | 8 +- docs/security/deserialization.md | 15 +- .../xlang_implementation_guide.md | 18 +- go/fory/array.go | 7 +- go/fory/codegen/decoder.go | 58 ++-- go/fory/fory.go | 80 ++--- go/fory/graph_memory_budget_test.go | 36 +- go/fory/map.go | 10 +- go/fory/map_primitive.go | 10 +- go/fory/reader.go | 85 ++--- go/fory/set.go | 12 +- go/fory/slice.go | 14 +- go/fory/slice_dyn.go | 14 +- go/fory/slice_primitive.go | 10 +- go/fory/slice_primitive_list.go | 22 +- go/fory/stream.go | 8 +- go/fory/tests/structs_fory_gen.go | 322 ++++++++++++++++-- .../src/main/java/org/apache/fory/Fory.java | 33 +- .../java/org/apache/fory/config/Config.java | 2 +- .../org/apache/fory/config/ForyBuilder.java | 10 +- .../org/apache/fory/context/ReadContext.java | 26 +- .../java/org/apache/fory/ForyTestBase.java | 2 +- .../fory/io/MemoryBufferObjectInputTest.java | 2 +- .../fory/io/MemoryBufferObjectOutputTest.java | 2 +- .../fory/resolver/ClassResolverTest.java | 6 +- .../fory/serializer/ArraySerializersTest.java | 14 +- .../serializer/CompatibleSerializerTest.java | 2 +- .../serializer/ExceptionSerializersTest.java | 4 +- .../serializer/GraphMemoryBudgetTest.java | 68 ++-- .../serializer/PrimitiveSerializersTest.java | 4 +- .../ChildContainerSerializersTest.java | 2 +- .../collection/CollectionSerializersTest.java | 2 +- javascript/packages/core/lib/context.ts | 10 +- javascript/packages/core/lib/fory.ts | 9 +- javascript/test/graphMemoryBudget.test.ts | 28 +- python/pyfory/_fory.py | 10 +- python/pyfory/context.pxi | 39 +-- python/pyfory/context.py | 28 +- python/pyfory/serialization.pyx | 30 +- .../pyfory/tests/test_graph_memory_budget.py | 50 +-- rust/fory-core/src/config.rs | 5 +- rust/fory-core/src/context.rs | 68 +--- rust/fory-core/src/fory.rs | 12 +- rust/fory-core/src/serializer/codec.rs | 22 +- rust/fory-core/src/serializer/collection.rs | 20 +- rust/fory-core/src/serializer/map.rs | 18 +- rust/tests/tests/test_graph_memory_budget.rs | 106 +++--- swift/Sources/Fory/AnySerializer.swift | 13 +- .../Sources/Fory/CollectionSerializers.swift | 25 +- swift/Sources/Fory/FieldCodecs.swift | 24 +- swift/Sources/Fory/Fory.swift | 11 +- swift/Sources/Fory/ReadContext.swift | 32 +- swift/Tests/ForyTests/ForySwiftTests.swift | 2 +- .../ForyTests/GraphMemoryBudgetTests.swift | 24 +- 86 files changed, 1202 insertions(+), 1187 deletions(-) diff --git a/.agents/languages/cpp.md b/.agents/languages/cpp.md index be74b77b22..564b7f1885 100644 --- a/.agents/languages/cpp.md +++ b/.agents/languages/cpp.md @@ -18,11 +18,13 @@ Load this file when changing `cpp/`, Cython build plumbing, or C++ xlang behavio - For cross-language feature ports, match protocol behavior but use idiomatic C++ ownership and layering instead of mirroring Java structure literally. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. - Root deserialization graph budgets are owned by `ReadContext` and initialized by the root - `Fory::deserialize` overload. Keep `max_graph_memory_bytes` as `-1 / auto` or a positive explicit - limit; known byte roots use `inputBytes * 8 + 64 KiB`, while stream roots use fixed `128 MiB`. + `Fory::deserialize` overload. Keep `max_graph_memory_bytes` as a fixed-default graph limit: + unset/default is `128 MiB`, positive explicit values override it, and explicit non-positive + values intentionally disable graph-memory enforcement. Byte and stream roots use the same + configured/default budget behavior. Reserve estimated shallow graph-owner memory before allocation while preserving existing byte-availability checks and their non-empty metadata ordering. `ReadContext` may expose only raw - byte reservation and generic counted-byte arithmetic; collection, map, array, struct, and object + byte reservation; collection, map, array, struct, and object formulas belong in serializer owners. Skip dedicated string, binary, primitive scalar, primitive vector, and primitive dense-array leaf owners; `std::vector` charges rounded packed-bit storage. General `std::vector` for non-primitive `T` is inline value storage and must be diff --git a/.agents/languages/dart.md b/.agents/languages/dart.md index 5a80cb98e5..d7a5dea99e 100644 --- a/.agents/languages/dart.md +++ b/.agents/languages/dart.md @@ -15,10 +15,10 @@ Load this file when changing `dart/`. - Dart 64-bit carriers are optimized for each platform. Do not replace native extension-type wrappers with allocation-heavy classes or route web/native hot paths through `BigInt` unless the user approves a representation change. - In `Buffer`, cursor, serializer, and generated-code hot paths, prefer direct byte/local integer operations and conditional import/export files over callbacks, records, holder objects, wrapper round-trips, or runtime platform branches. - Root deserialization graph memory budgets are owned by `ReadContext`; - `maxGraphMemoryBytes` defaults to `-1 / auto`, positive explicit values win, - and Dart auto uses `buffer.readableBytes * 8 + 64 KiB` because roots are - memory-backed. `ReadContext` may expose only raw byte reservation and generic - counted-byte arithmetic; list, set, map, array, struct, and object formulas + `maxGraphMemoryBytes` defaults to fixed `128 MiB`, positive explicit values override it, and + explicit non-positive values intentionally disable graph-memory enforcement. Do not derive the + budget from `buffer.readableBytes`. `ReadContext` may expose only raw byte reservation; list, set, + map, array, struct, and object formulas belong in serializer owners. Reserve Dart list/set/object-array reference slots plus nonzero owner self cost, map key/value slots plus nonzero owner self cost, compatible list-to-array inline storage, compatible array-to-list diff --git a/.agents/languages/go.md b/.agents/languages/go.md index fcffe384af..87442b91af 100644 --- a/.agents/languages/go.md +++ b/.agents/languages/go.md @@ -8,10 +8,11 @@ Load this file when changing `go/fory/` or Go xlang behavior. - Changes under `go/` must pass formatting and tests. - The Go implementation focuses on reflection-based and codegen-based serialization. - Root deserialization graph memory budgets are owned by `ReadContext`. - `WithMaxGraphMemoryBytes` defaults to `-1 / auto`; byte-slice roots use - `inputBytes * 8 + 64 KiB`, and `DeserializeFromReader`/`DeserializeFromStream` - use fixed `128 MiB`. `ReadContext` may expose only raw byte reservation and - generic counted-byte arithmetic; slice, map, array, struct, and object + `WithMaxGraphMemoryBytes` uses a fixed `128 MiB` default; positive explicit + values override it, and explicit non-positive values intentionally disable + graph-memory enforcement. Byte-slice and stream roots use the same + configured/default budget behavior. `ReadContext` may expose only raw byte + reservation; slice, map, array, struct, and object formulas belong in handwritten or generated serializer owners. Reserve Go slices as `len * elemBytes`, maps as `len * (keyBytes + valueBytes)`, map-backed sets, and LIST-encoded inline/value slices in the owner that diff --git a/.agents/languages/java.md b/.agents/languages/java.md index 2150b74d8d..a3f9aeacb6 100644 --- a/.agents/languages/java.md +++ b/.agents/languages/java.md @@ -16,9 +16,11 @@ Load this file when changing anything under `java/` or when Java drives a cross- - `WriteContext`, `ReadContext`, and `CopyContext` must stay explicit. Do not reintroduce `ThreadLocal` or ambient runtime-context patterns. - Java root deserialization graph memory budgeting belongs to `ReadContext` and is initialized by `Fory` root APIs. Public config is `maxGraphMemoryBytes` - with `-1` auto, positive explicit override, known-length auto - `inputBytes * 8 + 64 KiB`, and stream/unknown auto `128 MiB`. `ReadContext` - may expose only raw byte reservation and generic counted-byte arithmetic; + with fixed `128 MiB` default. Positive explicit values override the default; + explicit non-positive values intentionally disable graph-memory enforcement. + Byte-array, memory-buffer, and stream roots use the same configured/default + budget behavior. `ReadContext` + may expose only raw byte reservation; collection, map, array, struct, and object formulas belong in the concrete serializer or generated serializer owner. Java collection, map, and object-array owners reserve nonzero shallow self cost plus reference storage; diff --git a/.agents/languages/javascript.md b/.agents/languages/javascript.md index c97fce2be9..56a14c630a 100644 --- a/.agents/languages/javascript.md +++ b/.agents/languages/javascript.md @@ -15,9 +15,10 @@ Load this file when changing `javascript/`. - Keep `TypeInfo` as schema metadata. Compatibility-sensitive decisions belong on `TypeResolver` or explicit operations, not as retained resolver state on metadata objects. - Normalize optional boolean config values at config construction; do not carry `null` through runtime paths when it means `false`. - JavaScript root deserialization graph memory budgeting belongs to `ReadContext`. - `maxGraphMemoryBytes` uses `-1` auto, positive explicit limits, and known - `Uint8Array` root length as `inputBytes * 8 + 64 KiB`. `ReadContext` may expose only raw - byte reservation and generic counted-byte arithmetic; generated and dynamic + `maxGraphMemoryBytes` uses a fixed `128 MiB` default, positive explicit limits override it, and + explicit non-positive values intentionally disable graph-memory enforcement. Do not derive the + budget from the `Uint8Array` root length. `ReadContext` may expose only raw + byte reservation; generated and dynamic list/set/map/array/struct/object readers must reserve before allocation while preserving existing byte checks. Lists/sets/object arrays reserve nonzero owner self cost plus 4-byte reference slots, maps reserve nonzero owner self cost plus key/value reference storage, object/struct readers diff --git a/.agents/languages/python.md b/.agents/languages/python.md index 0964eb4459..08db459019 100644 --- a/.agents/languages/python.md +++ b/.agents/languages/python.md @@ -14,9 +14,11 @@ Load this file when changing `python/`, Cython serialization, or Python xlang be - Use explicit Cython fields and methods for fixed hot-path shapes. Avoid `__getattr__`, generic `object` fields, public bridge internals, or `Fory` backreferences where ownership can stay explicit. - Keep Python and Cython context/ref-tracking branch conditions and stack mutations semantically aligned unless a documented intentional difference exists. - Root deserialization graph memory budgets are owned by pure-Python and Cython `ReadContext`. - Keep `max_graph_memory_bytes` public on `pyfory.Fory`/`Config`; `-1` uses known-length - `inputBytes * 8 + 64 KiB` or fixed `128 MiB` for stream roots. `ReadContext` may expose only raw - byte reservation and generic counted-byte arithmetic; collection, dict, array, struct, and object + Keep `max_graph_memory_bytes` public on `pyfory.Fory`/`Config`; the default effective limit is + fixed `128 MiB`, positive explicit values override it, and explicit non-positive values + intentionally disable graph-memory enforcement. Byte and stream roots use the same + configured/default budget behavior. `ReadContext` may expose only raw + byte reservation; collection, dict, array, struct, and object formulas belong in the pure-Python or Cython serializer owner. Lists, tuples, sets, and object-dtype ndarray item storage reserve nonzero owner self cost plus `count * PyObject*`; dicts reserve nonzero owner self cost plus `entryCount * 2 * PyObject*`. Python object owners reserve a diff --git a/.agents/languages/rust.md b/.agents/languages/rust.md index 92db836eda..2c14970133 100644 --- a/.agents/languages/rust.md +++ b/.agents/languages/rust.md @@ -19,9 +19,11 @@ Load this file when changing `rust/` or Rust xlang behavior. - For panic-safety in hot paths, preserve TLS context reuse. Add scoped guards or owned fallbacks rather than per-call context allocation, and reset reused contexts at entry and successful exit. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Keep recursive matched-field shape classification owned by `fory-core/src/meta/type_meta.rs`; collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. - Root deserialization graph memory budget state belongs to `ReadContext` and is initialized by the - root `Fory` read methods before the header is consumed. Rust roots are byte-slice/`Reader` backed, - so auto budget uses `inputBytes * 8 + 64 KiB`; do not add dynamic bytes-read accounting. - `ReadContext` may expose only raw byte reservation and generic counted-byte arithmetic; `Vec`, + root `Fory` read methods before the header is consumed. Use the fixed `128 MiB` default unless a + positive explicit value overrides it or an explicit non-positive value intentionally disables + graph-memory enforcement; do not derive the budget from root input size or add dynamic bytes-read + accounting. + `ReadContext` may expose only raw byte reservation; `Vec`, collection, map, array, struct, object, and derive codec formulas belong in their serializer owners. - Rust `Vec` stores inline element storage, so general LIST paths reserve diff --git a/AGENTS.md b/AGENTS.md index 81ffbd39ae..81ac87c2dd 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -33,7 +33,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Check the spec before implementation. For wire behavior and xlang mapping, use the specs as the source of truth and never copy one runtime's bug into another runtime just to make tests pass. - Do not make assumptions about runtime behavior, ownership, registration, metadata construction, protocol semantics, or test coverage. Read the current code, owning docs/specs, and relevant tests before making a design judgment or implementation decision. If the evidence is incomplete, inspect more or state the uncertainty explicitly instead of filling gaps from memory or analogy with another runtime. - For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve backing/output capacity from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. Root graph memory reservation is accounting only and may happen before that byte check, but it must not replace the byte check. -- Root deserialization graph memory budgets estimate lower-bound shallow memory for materialized graph owners, not exact heap accounting and not raw element counts. Positive `maxGraphMemoryBytes` wins; `-1` auto uses `inputBytes * 8 + 64 KiB` for known-length roots and fixed `128 MiB` for true stream or unknown-length roots. Read context/read state owns only raw byte accounting plus generic counted-byte arithmetic, such as reserving `bytes` or `count * elementBytes` with overflow checks; it must not expose collection, map, array, struct, or object semantic reservation APIs. Concrete serializers and generated serializers own formulas for their allocation path. Reserve self storage exactly once at the owner that stores or allocates the value: reference/object runtimes reserve parent owner self cost plus reference storage and every referenced heap owner reserves its own shallow self cost when materialized; inline/value runtimes reserve value storage in the root, struct/product, collection, map, set, array, smart-pointer, or box owner that actually owns that storage, and nested value serializers do not unconditionally reserve their own `sizeof(T)` again. Collection, map, set, and reference-array owners reserve nonzero shallow self cost when independently materialized, plus backing/reference/inline storage. Struct, record, POJO, tuple/product, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Reference fields use 4 bytes when reference size is not cheap or reliable to query; primitive/value fields use their storage width. Parents do not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive array, and primitive dense-array leaf owners. Do not guess allocator, bucket-table, node, debug, or map-entry overhead unless it is a documented lower-bound owner allocation. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. +- Root deserialization graph memory budgets estimate lower-bound shallow memory for materialized graph owners, not exact heap accounting, input byte accounting, or raw element counts. `maxGraphMemoryBytes` defaults to fixed `128 MiB`; positive values override the default; explicit non-positive values intentionally disable graph-memory enforcement and must be documented as deserialization DoS risk for compact inputs that materialize large graphs. Do not derive this budget from root input size, and do not split known-length and stream root behavior. Read context/read state owns only raw byte reservation with `reserveGraphMemory(bytes)`; it must not expose counted arithmetic helpers or collection, map, array, struct, or object semantic reservation APIs. Concrete serializers and generated serializers own counted formulas and overflow checks for their allocation path. Reserve self storage exactly once at the owner that stores or allocates the value: reference/object runtimes reserve parent owner self cost plus reference storage and every referenced heap owner reserves its own shallow self cost when materialized; inline/value runtimes reserve value storage in the root, struct/product, collection, map, set, array, smart-pointer, or box owner that actually owns that storage, and nested value serializers do not unconditionally reserve their own `sizeof(T)` again. Collection, map, set, and reference-array owners reserve nonzero shallow self cost when independently materialized, plus backing/reference/inline storage. Struct, record, POJO, tuple/product, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Reference fields use 4 bytes when reference size is not cheap or reliable to query; primitive/value fields use their storage width. Parents do not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive array, and primitive dense-array leaf owners. Do not guess allocator, bucket-table, node, debug, or map-entry overhead unless it is a documented lower-bound owner allocation. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. - For remote TypeDef/TypeMeta reads, the checked metadata cache is the only owner of remote "already validated" state. Cache hit means the header was previously parsed, body/hash-validated, policy-checked, and published by that cache, so the hot path must skip the body and use cached metadata without extra validation, hashing, limit checks, exact-local checks, allocation, or policy work. A known expected local TypeDef/TypeMeta header/hash match is a local-schema hit, not a remote cache miss: it may skip the body and use the local TypeInfo/TypeMeta without schema-version counting or cache publish. Cache miss is the only path that parses and validates non-local metadata, enforces limits, performs exact-local byte comparison when needed, and publishes remote metadata to the cache. Do not add nullable accepted-header fields, sentinel headers, per-TypeInfo markers, pending metadata state, parallel header-low/header-high slots, or parallel acceptance state for this decision. If a runtime needs a metadata hit hint, cache the concrete checked metadata owner object, such as the TypeInfo, TypeDef, or TypeMeta used by that runtime, and compare its validated header identity directly. - When a user corrects a non-obvious invariant, encode it in the nearest source comment before continuing, and also update `AGENTS.md`, `.agents/**`, docs, or specs when the rule is reusable beyond one file. Do not rely only on chat history, task notes, commit messages, or benchmark logs for corrections that protect security, protocol behavior, ownership, naming, or hot-path performance. - Reject semantic hacks. Do not bypass broken semantics by deleting cases, simplifying callers, adding coercion hooks, or using workaround fallbacks; fix the underlying bug and prove it with focused tests. diff --git a/cpp/fory/meta/field_info.h b/cpp/fory/meta/field_info.h index 0da5d22b07..4e066c72a8 100644 --- a/cpp/fory/meta/field_info.h +++ b/cpp/fory/meta/field_info.h @@ -740,8 +740,8 @@ constexpr auto concat_tuples_from_tuple(const Tuple &tuple) { static inline constexpr size_t Size = 0; \ static inline constexpr std::string_view Name = #type; \ static inline constexpr std::array Names = {}; \ - static constexpr bool has_config = false; \ - static inline constexpr auto entries = std::tuple{}; \ + [[maybe_unused]] static constexpr bool has_config = false; \ + [[maybe_unused]] static inline constexpr auto entries = std::tuple{}; \ [[maybe_unused]] static constexpr size_t field_count = 0; \ using PtrsType = decltype(std::tuple{}); \ static constexpr PtrsType ptrs() { return {}; } \ diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index 9eee2711cd..bf1e9cc46b 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -390,6 +390,28 @@ constexpr size_t collection_element_memory_bytes() { return sizeof(Elem); } +template +inline bool reserve_collection_storage(ReadContext &ctx, uint32_t length) { + constexpr size_t kMaxLength = + static_cast(std::numeric_limits::max()); + if constexpr (elem_bytes <= + std::numeric_limits::max() / kMaxLength) { + return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); + } else { + if (FORY_PREDICT_FALSE( + elem_bytes != 0 && + static_cast(length) > + std::numeric_limits::max() / elem_bytes)) { + ctx.set_error(Error::invalid_data( + "graph memory estimate overflows: length=" + + std::to_string(length) + " elementBytes=" + + std::to_string(elem_bytes))); + return false; + } + return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); + } +} + template inline bool reserve_collection(Container &result, ReadContext &ctx, uint32_t length) { @@ -400,7 +422,7 @@ inline bool reserve_collection(Container &result, ReadContext &ctx, } constexpr size_t elem_bytes = collection_element_memory_bytes(); if (FORY_PREDICT_FALSE( - (!ctx.template reserve_counted_graph_memory(length)))) { + (!reserve_collection_storage(ctx, length)))) { return false; } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { @@ -418,7 +440,15 @@ inline bool reserve_collection(std::vector &result, if (FORY_PREDICT_FALSE(ctx.has_error())) { return false; } - const size_t packed_bytes = (static_cast(length) + 7) / 8; + const size_t length_bytes = static_cast(length); + if (FORY_PREDICT_FALSE(length_bytes > + std::numeric_limits::max() - 7)) { + ctx.set_error(Error::invalid_data( + "graph memory estimate overflows: length=" + + std::to_string(length) + " elementBytes=1")); + return false; + } + const size_t packed_bytes = (length_bytes + 7) / 8; if (FORY_PREDICT_FALSE(!ctx.reserve_graph_memory(packed_bytes))) { return false; } diff --git a/cpp/fory/serialization/config.h b/cpp/fory/serialization/config.h index fc46221506..910f72c37a 100644 --- a/cpp/fory/serialization/config.h +++ b/cpp/fory/serialization/config.h @@ -52,9 +52,9 @@ struct Config { /// When enabled, avoids duplicating shared objects and handles cycles. bool track_ref = true; - /// Maximum estimated graph memory accepted during one root - /// deserialization. `-1` selects the automatic input-shaped limit. - int64_t max_graph_memory_bytes = -1; + /// Maximum estimated graph memory accepted during one root deserialization. + /// Positive values are byte limits; non-positive values disable enforcement. + int64_t max_graph_memory_bytes = 128LL * 1024LL * 1024LL; /// Maximum accepted field count in one received struct TypeMeta. uint32_t max_type_fields = 512; diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index 0af96a03c1..0c3045efc6 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -739,54 +739,11 @@ const TypeInfo *ReadContext::read_any_type_info(Error &error) { return result.value(); } -bool ReadContext::reserve_counted_graph_checked(uint32_t length, - size_t elem_bytes) { - if (FORY_PREDICT_FALSE(elem_bytes != 0 && - static_cast(length) > - std::numeric_limits::max() / elem_bytes)) { - return set_graph_memory_overflow(length, elem_bytes); - } - return reserve_graph_memory(static_cast(length) * elem_bytes); -} - -bool ReadContext::init_explicit_graph_budget(int64_t configured) { - const uint64_t limit = static_cast(configured); - if constexpr (sizeof(size_t) < sizeof(uint64_t)) { - if (FORY_PREDICT_FALSE(limit > static_cast( - std::numeric_limits::max()))) { - return set_graph_memory_error( - "max_graph_memory_bytes does not fit size_t"); - } - } - remaining_graph_memory_bytes_ = static_cast(limit); - graph_budget_state_ = kGraphBudgetReady; - return true; -} - -bool ReadContext::materialize_graph_budget() { - switch (graph_budget_state_) { - case kGraphBudgetPendingKnown: - return init_graph_budget_known(pending_graph_root_bytes_); - case kGraphBudgetPendingUnknown: - return init_graph_budget_unknown(); - default: - return true; - } -} - bool ReadContext::set_graph_memory_error(const std::string &message) { set_error(Error::invalid_data(message)); return false; } -bool ReadContext::set_graph_memory_overflow(uint32_t length, - size_t elem_bytes) { - set_error(Error::invalid_data( - "graph memory estimate overflows: length=" + std::to_string(length) + - " elementBytes=" + std::to_string(elem_bytes))); - return false; -} - bool ReadContext::set_graph_memory_exceeded(size_t bytes, size_t remaining) { set_error(Error::invalid_data( "estimated graph memory request " + std::to_string(bytes) + @@ -803,9 +760,8 @@ void ReadContext::reset() { } reading_type_infos_.clear(); current_dyn_depth_ = 0; - // Root deserialization initializes the graph budget before reading the - // header; direct ReadContext users start with the unlimited sentinel fields. - // Leave those fields untouched here so root guard cleanup stays store-light. + graph_budget_enabled_ = false; + remaining_graph_memory_bytes_ = std::numeric_limits::max(); if (meta_string_table_active_) { meta_string_table_.reset(); meta_string_table_active_ = false; diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index ee4a6e9dbc..c405795be9 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -505,28 +505,11 @@ class ReadContext { } } - FORY_ALWAYS_INLINE bool init_graph_budget_known(size_t root_bytes) { - const int64_t configured = config_->max_graph_memory_bytes; - if (FORY_PREDICT_FALSE(configured > 0)) { - return init_explicit_graph_budget(configured); - } - if constexpr (sizeof(size_t) <= sizeof(uint32_t)) { - constexpr size_t max_root_bytes = - (std::numeric_limits::max() - kKnownGraphBudgetSlackBytes) / - kKnownGraphBudgetMultiplier; - if (FORY_PREDICT_FALSE(root_bytes > max_root_bytes)) { - return set_graph_memory_error( - "root input size overflows automatic graph memory budget"); - } - } - remaining_graph_memory_bytes_ = - root_bytes * kKnownGraphBudgetMultiplier + kKnownGraphBudgetSlackBytes; - graph_budget_state_ = kGraphBudgetReady; - return true; + FORY_ALWAYS_INLINE bool init_graph_budget() { + return init_graph_budget(0); } - FORY_ALWAYS_INLINE bool init_graph_budget_known(size_t root_bytes, - size_t reserve_bytes) { + FORY_ALWAYS_INLINE bool init_graph_budget(size_t reserve_bytes) { const int64_t configured = config_->max_graph_memory_bytes; if (FORY_PREDICT_FALSE(configured > 0)) { if constexpr (sizeof(size_t) < sizeof(uint64_t)) { @@ -540,61 +523,14 @@ class ReadContext { return init_graph_budget_limit(static_cast(configured), reserve_bytes); } - if constexpr (sizeof(size_t) <= sizeof(uint32_t)) { - constexpr size_t max_root_bytes = - (std::numeric_limits::max() - kKnownGraphBudgetSlackBytes) / - kKnownGraphBudgetMultiplier; - if (FORY_PREDICT_FALSE(root_bytes > max_root_bytes)) { - return set_graph_memory_error( - "root input size overflows automatic graph memory budget"); - } - } - return init_graph_budget_limit(root_bytes * kKnownGraphBudgetMultiplier + - kKnownGraphBudgetSlackBytes, - reserve_bytes); - } - - FORY_ALWAYS_INLINE bool init_graph_budget_unknown() { - const int64_t configured = config_->max_graph_memory_bytes; - if (FORY_PREDICT_FALSE(configured > 0)) { - return init_explicit_graph_budget(configured); - } - remaining_graph_memory_bytes_ = kUnknownGraphBudgetBytes; - graph_budget_state_ = kGraphBudgetReady; + graph_budget_enabled_ = false; + remaining_graph_memory_bytes_ = std::numeric_limits::max(); return true; } - FORY_ALWAYS_INLINE bool init_graph_budget_unknown(size_t reserve_bytes) { - const int64_t configured = config_->max_graph_memory_bytes; - if (FORY_PREDICT_FALSE(configured > 0)) { - if constexpr (sizeof(size_t) < sizeof(uint64_t)) { - if (FORY_PREDICT_FALSE( - static_cast(configured) > - static_cast(std::numeric_limits::max()))) { - return set_graph_memory_error( - "max_graph_memory_bytes does not fit size_t"); - } - } - return init_graph_budget_limit(static_cast(configured), - reserve_bytes); - } - return init_graph_budget_limit(kUnknownGraphBudgetBytes, reserve_bytes); - } - - FORY_ALWAYS_INLINE void defer_graph_budget_known(size_t root_bytes) { - pending_graph_root_bytes_ = root_bytes; - graph_budget_state_ = kGraphBudgetPendingKnown; - } - - FORY_ALWAYS_INLINE void defer_graph_budget_unknown() { - graph_budget_state_ = kGraphBudgetPendingUnknown; - } - FORY_ALWAYS_INLINE bool reserve_graph_memory(size_t bytes) { - if (FORY_PREDICT_FALSE(graph_budget_state_ != kGraphBudgetReady)) { - if (FORY_PREDICT_FALSE(!materialize_graph_budget())) { - return false; - } + if (FORY_PREDICT_FALSE(!graph_budget_enabled_)) { + return true; } const size_t remaining = remaining_graph_memory_bytes_; if (FORY_PREDICT_FALSE(bytes > remaining)) { @@ -606,7 +542,7 @@ class ReadContext { FORY_ALWAYS_INLINE bool init_graph_budget_limit(size_t limit, size_t reserve_bytes) { - graph_budget_state_ = kGraphBudgetReady; + graph_budget_enabled_ = true; if (FORY_PREDICT_FALSE(reserve_bytes > limit)) { return set_graph_memory_exceeded(reserve_bytes, limit); } @@ -614,18 +550,6 @@ class ReadContext { return true; } - template - FORY_ALWAYS_INLINE bool reserve_counted_graph_memory(uint32_t length) { - constexpr size_t kMaxLength = - static_cast(std::numeric_limits::max()); - if constexpr (elem_bytes <= - std::numeric_limits::max() / kMaxLength) { - return reserve_graph_memory(static_cast(length) * elem_bytes); - } else { - return reserve_counted_graph_checked(length, elem_bytes); - } - } - // =========================================================================== // Read methods with Error& parameter // All methods accept Error& as parameter for reduced overhead. @@ -781,23 +705,10 @@ class ReadContext { inline const Config &config() const { return *config_; } private: - static constexpr size_t kKnownGraphBudgetMultiplier = 8; - static constexpr size_t kKnownGraphBudgetSlackBytes = 64 * 1024; - static constexpr size_t kUnknownGraphBudgetBytes = 128ULL * 1024ULL * 1024ULL; - static constexpr uint8_t kGraphBudgetReady = 0; - static constexpr uint8_t kGraphBudgetPendingKnown = 1; - static constexpr uint8_t kGraphBudgetPendingUnknown = 2; - FORY_NOINLINE Result check_remote_type_meta_limit(const TypeMeta &type_meta); void record_remote_type_meta(const std::string &type_key); - FORY_NOINLINE bool reserve_counted_graph_checked(uint32_t length, - size_t elem_bytes); - FORY_NOINLINE bool init_explicit_graph_budget(int64_t configured); - FORY_NOINLINE bool materialize_graph_budget(); FORY_NOINLINE bool set_graph_memory_error(const std::string &message); - FORY_NOINLINE bool set_graph_memory_overflow(uint32_t length, - size_t elem_bytes); FORY_NOINLINE bool set_graph_memory_exceeded(size_t bytes, size_t remaining); // Error state - accumulated during deserialization, checked at the end @@ -808,8 +719,7 @@ class ReadContext { std::unique_ptr type_resolver_; RefReader ref_reader_; uint32_t current_dyn_depth_; - uint8_t graph_budget_state_ = kGraphBudgetReady; - size_t pending_graph_root_bytes_ = 0; + bool graph_budget_enabled_ = false; size_t remaining_graph_memory_bytes_ = std::numeric_limits::max(); // Meta sharing state (for compatible mode) diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index 084df9d4d0..fc0fd862da 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -112,10 +112,9 @@ class ForyBuilder { /// Set maximum estimated graph memory for one root deserialization. /// - /// Use `-1` for automatic limits. Positive values are explicit byte limits. + /// Defaults to 128 MiB. Positive values are explicit byte limits; non-positive + /// values intentionally disable this protection. ForyBuilder &max_graph_memory_bytes(int64_t max_bytes) { - FORY_CHECK(max_bytes == -1 || max_bytes > 0) - << "max_graph_memory_bytes must be positive or -1 for auto"; config_.max_graph_memory_bytes = max_bytes; return *this; } @@ -514,19 +513,7 @@ class BaseFory { /// Protected constructor - only derived classes can instantiate. explicit BaseFory(const Config &config, std::shared_ptr resolver) - : config_(config), type_resolver_(std::move(resolver)) { - auto_graph_memory_budget_ = config_.max_graph_memory_bytes <= 0; - explicit_graph_memory_bytes_ = std::numeric_limits::max(); - if (config_.max_graph_memory_bytes > 0) { - if constexpr (sizeof(size_t) < sizeof(uint64_t)) { - graph_budget_limit_fits_size_t_ = - static_cast(config_.max_graph_memory_bytes) <= - static_cast(std::numeric_limits::max()); - } - explicit_graph_memory_bytes_ = - static_cast(config_.max_graph_memory_bytes); - } - } + : config_(config), type_resolver_(std::move(resolver)) {} // Non-copyable BaseFory(const BaseFory &) = delete; @@ -538,9 +525,6 @@ class BaseFory { Config config_; std::shared_ptr type_resolver_; - size_t explicit_graph_memory_bytes_ = std::numeric_limits::max(); - bool auto_graph_memory_budget_ = true; - bool graph_budget_limit_fits_size_t_ = true; mutable std::mutex registration_mutex_; mutable bool registration_locked_{false}; }; @@ -699,7 +683,7 @@ class Fory : public BaseFory { Buffer buffer(const_cast(data), static_cast(size), false); - return deserialize_buffer(buffer); + return deserialize_buffer(buffer); } /// Deserialize an object from a byte vector. @@ -725,7 +709,7 @@ class Fory : public BaseFory { if (FORY_PREDICT_FALSE(!finalized_)) { ensure_finalized(); } - return deserialize_buffer(buffer); + return deserialize_buffer(buffer); } /// Deserialize an object from an input stream. @@ -751,7 +735,7 @@ class Fory : public BaseFory { if (FORY_PREDICT_FALSE(!finalized_)) { ensure_finalized(); } - return deserialize_buffer(buffer); + return deserialize_buffer(buffer); } /// Deserialize an object from StdInputStream. @@ -840,92 +824,6 @@ class Fory : public BaseFory { ", local xlang=" + std::string(config_.xlang ? "true" : "false")); } - FORY_NOINLINE Error root_graph_config_too_large() const { - return Error::invalid_data("max_graph_memory_bytes does not fit size_t"); - } - - FORY_NOINLINE Error root_graph_budget_overflow() const { - return Error::invalid_data( - "root input size overflows automatic graph memory budget"); - } - - FORY_NOINLINE Error root_graph_budget_exceeded(size_t bytes, - size_t limit) const { - return Error::invalid_data( - "estimated graph memory request " + std::to_string(bytes) + - " bytes exceeds max_graph_memory_bytes remaining budget " + - std::to_string(limit) + " bytes"); - } - - template - FORY_ALWAYS_INLINE bool reserve_root_graph_self(size_t root_bytes) { - if constexpr (unknown_root) { - if constexpr (root_owner_bytes <= ReadContext::kUnknownGraphBudgetBytes) { - if (FORY_PREDICT_TRUE(auto_graph_memory_budget_)) { - return true; - } - } - } else if constexpr (root_owner_bytes <= - ReadContext::kKnownGraphBudgetSlackBytes) { - if (FORY_PREDICT_TRUE(auto_graph_memory_budget_)) { - return true; - } - } - if constexpr (sizeof(size_t) < sizeof(uint64_t)) { - if (FORY_PREDICT_FALSE(!graph_budget_limit_fits_size_t_)) { - read_ctx_->set_error(root_graph_config_too_large()); - return false; - } - } - if (FORY_PREDICT_FALSE(root_owner_bytes > explicit_graph_memory_bytes_)) { - if constexpr (sizeof(size_t) < sizeof(uint64_t)) { - if (FORY_PREDICT_FALSE(!graph_budget_limit_fits_size_t_)) { - read_ctx_->set_error(root_graph_config_too_large()); - return false; - } - } - if (FORY_PREDICT_FALSE(config_.max_graph_memory_bytes > 0)) { - read_ctx_->set_error(root_graph_budget_exceeded( - root_owner_bytes, explicit_graph_memory_bytes_)); - return false; - } - } - if constexpr (unknown_root) { - if constexpr (root_owner_bytes > ReadContext::kUnknownGraphBudgetBytes) { - if (FORY_PREDICT_FALSE(config_.max_graph_memory_bytes > 0)) { - return true; - } - read_ctx_->set_error(root_graph_budget_exceeded( - root_owner_bytes, ReadContext::kUnknownGraphBudgetBytes)); - return false; - } - } else if constexpr (root_owner_bytes > - ReadContext::kKnownGraphBudgetSlackBytes) { - if (FORY_PREDICT_FALSE(config_.max_graph_memory_bytes > 0)) { - return true; - } - if constexpr (sizeof(size_t) <= sizeof(uint32_t)) { - constexpr size_t max_root_bytes = - (std::numeric_limits::max() - - ReadContext::kKnownGraphBudgetSlackBytes) / - ReadContext::kKnownGraphBudgetMultiplier; - if (FORY_PREDICT_FALSE(root_bytes > max_root_bytes)) { - read_ctx_->set_error(root_graph_budget_overflow()); - return false; - } - } - const size_t limit = - root_bytes * ReadContext::kKnownGraphBudgetMultiplier + - ReadContext::kKnownGraphBudgetSlackBytes; - if (FORY_PREDICT_FALSE(root_owner_bytes > limit)) { - read_ctx_->set_error( - root_graph_budget_exceeded(root_owner_bytes, limit)); - return false; - } - } - return true; - } - /// Core serialization implementation. /// TypeMeta is written inline using streaming protocol (no deferred writing). template @@ -975,10 +873,8 @@ class Fory : public BaseFory { return result; } - template + template FORY_ALWAYS_INLINE Result deserialize_buffer(Buffer &buffer) { - const size_t root_bytes = unknown_root ? 0 : buffer.remaining_size(); - Error header_error; const uint8_t header = buffer.read_uint8(header_error); if (FORY_PREDICT_FALSE(!header_error.ok())) { @@ -995,30 +891,13 @@ class Fory : public BaseFory { constexpr size_t root_owner_bytes = graph_value_owner_self_bytes(); constexpr bool has_child_budget = has_graph_budget_children_v; if constexpr (root_owner_bytes != 0) { - if constexpr (has_child_budget) { - if constexpr (unknown_root) { - if (FORY_PREDICT_FALSE( - !read_ctx_->init_graph_budget_unknown(root_owner_bytes))) { - return Unexpected(read_ctx_->take_error()); - } - } else { - if (FORY_PREDICT_FALSE(!read_ctx_->init_graph_budget_known( - root_bytes, root_owner_bytes))) { - return Unexpected(read_ctx_->take_error()); - } - } - } else { - if (FORY_PREDICT_FALSE( - (!reserve_root_graph_self( - root_bytes)))) { - return Unexpected(read_ctx_->take_error()); - } + if (FORY_PREDICT_FALSE( + !read_ctx_->init_graph_budget(root_owner_bytes))) { + return Unexpected(read_ctx_->take_error()); } } else if constexpr (has_child_budget) { - if constexpr (unknown_root) { - read_ctx_->defer_graph_budget_unknown(); - } else { - read_ctx_->defer_graph_budget_known(root_bytes); + if (FORY_PREDICT_FALSE(!read_ctx_->init_graph_budget())) { + return Unexpected(read_ctx_->take_error()); } } } diff --git a/cpp/fory/serialization/graph_memory_budget_test.cc b/cpp/fory/serialization/graph_memory_budget_test.cc index e212d9ee56..b8edb86abc 100644 --- a/cpp/fory/serialization/graph_memory_budget_test.cc +++ b/cpp/fory/serialization/graph_memory_budget_test.cc @@ -38,7 +38,7 @@ namespace fory { namespace serialization { namespace { -constexpr size_t kKnownBudgetSlack = 64 * 1024; +constexpr int64_t kDefaultGraphMemoryBytes = 128LL * 1024LL * 1024LL; struct BudgetItem { int32_t id = 0; @@ -94,7 +94,8 @@ template auto with_fory(int64_t max_graph_memory_bytes, Fn &&fn) { } template std::vector serialize_value(const T &value) { - auto bytes = with_fory(-1, [&](Fory &fory) { return fory.serialize(value); }); + auto bytes = with_fory(kDefaultGraphMemoryBytes, + [&](Fory &fory) { return fory.serialize(value); }); EXPECT_TRUE(bytes.ok()) << bytes.error().to_string(); return std::move(bytes).value(); } @@ -125,40 +126,58 @@ void expect_budget_boundary(const T &value, size_t required) { EXPECT_EQ(exact_result.value(), value); } -TEST(GraphMemoryBudgetTest, KnownLengthAutoBudget) { +TEST(GraphMemoryBudgetTest, FixedDefaultBudgetAndDisable) { Config config; - config.max_graph_memory_bytes = -1; ReadContext context(config, std::make_unique()); - constexpr size_t root_bytes = 17; - const size_t expected = root_bytes * 8 + kKnownBudgetSlack; - ASSERT_TRUE(context.init_graph_budget_known(root_bytes)); - ASSERT_TRUE(context.reserve_graph_memory(expected)); + ASSERT_TRUE(context.init_graph_budget()); + ASSERT_TRUE(context.reserve_graph_memory( + static_cast(kDefaultGraphMemoryBytes))); ASSERT_FALSE(context.reserve_graph_memory(1)); EXPECT_EQ(context.take_error().code(), ErrorCode::InvalidData); + + Config disabled_config; + disabled_config.max_graph_memory_bytes = 0; + ReadContext disabled(disabled_config, std::make_unique()); + ASSERT_TRUE(disabled.init_graph_budget()); + ASSERT_TRUE(disabled.reserve_graph_memory(std::numeric_limits::max())); } -TEST(GraphMemoryBudgetTest, StreamAutoBudget) { - constexpr size_t count = 10000; +TEST(GraphMemoryBudgetTest, RootKindsShareConfiguredBudget) { + constexpr size_t count = 3; std::vector> value(count); auto bytes = serialize_value(value); - const size_t known_limit = bytes.size() * 8 + kKnownBudgetSlack; - ASSERT_GT(nested_empty_budget(count), known_limit); + const size_t required = nested_empty_budget(count); - auto known_result = with_fory(-1, [&](Fory &fory) { - return fory.deserialize>>(bytes); - }); - ASSERT_FALSE(known_result.ok()); - EXPECT_EQ(known_result.error().code(), ErrorCode::InvalidData); + auto byte_result = with_fory(static_cast(required - 1), + [&](Fory &fory) { + return fory.deserialize< + std::vector>>( + bytes); + }); + ASSERT_FALSE(byte_result.ok()); + EXPECT_EQ(byte_result.error().code(), ErrorCode::InvalidData); std::string input(reinterpret_cast(bytes.data()), bytes.size()); std::istringstream source(input); StdInputStream stream(source, 8); - auto stream_result = with_fory(-1, [&](Fory &fory) { - return fory.deserialize>>(stream); + auto stream_result = with_fory(static_cast(required - 1), + [&](Fory &fory) { + return fory.deserialize< + std::vector>>( + stream); + }); + ASSERT_FALSE(stream_result.ok()); + EXPECT_EQ(stream_result.error().code(), ErrorCode::InvalidData); + + std::istringstream exact_source(input); + StdInputStream exact_stream(exact_source, 8); + auto exact_result = with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>>( + exact_stream); }); - ASSERT_TRUE(stream_result.ok()) << stream_result.error().to_string(); - EXPECT_EQ(stream_result.value(), value); + ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); + EXPECT_EQ(exact_result.value(), value); } TEST(GraphMemoryBudgetTest, ExplicitOverride) { diff --git a/cpp/fory/serialization/map_serializer.h b/cpp/fory/serialization/map_serializer.h index 82b41c322e..9d95d5c042 100644 --- a/cpp/fory/serialization/map_serializer.h +++ b/cpp/fory/serialization/map_serializer.h @@ -82,6 +82,28 @@ struct MapReserver +inline bool reserve_map_storage(ReadContext &ctx, uint32_t length) { + constexpr size_t kMaxLength = + static_cast(std::numeric_limits::max()); + if constexpr (elem_bytes <= + std::numeric_limits::max() / kMaxLength) { + return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); + } else { + if (FORY_PREDICT_FALSE( + elem_bytes != 0 && + static_cast(length) > + std::numeric_limits::max() / elem_bytes)) { + ctx.set_error(Error::invalid_data( + "graph memory estimate overflows: length=" + + std::to_string(length) + " elementBytes=" + + std::to_string(elem_bytes))); + return false; + } + return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); + } +} + template inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { // Lazy error propagation may continue into later readers; do not let that @@ -97,8 +119,7 @@ inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { std::numeric_limits::max() - sizeof(Value), "map entry memory estimate overflows"); constexpr size_t elem_bytes = sizeof(Key) + sizeof(Value); - if (FORY_PREDICT_FALSE( - (!ctx.template reserve_counted_graph_memory(length)))) { + if (FORY_PREDICT_FALSE((!reserve_map_storage(ctx, length)))) { return false; } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { diff --git a/csharp/src/Fory/Config.cs b/csharp/src/Fory/Config.cs index 207ca40b77..4aeed16315 100644 --- a/csharp/src/Fory/Config.cs +++ b/csharp/src/Fory/Config.cs @@ -38,12 +38,6 @@ internal Config( { throw new ArgumentOutOfRangeException(nameof(maxDepth), "MaxDepth must be greater than 0."); } - if (maxGraphMemoryBytes != -1 && maxGraphMemoryBytes <= 0) - { - throw new ArgumentOutOfRangeException( - nameof(maxGraphMemoryBytes), - "MaxGraphMemoryBytes must be positive or -1 for auto."); - } if (maxTypeFields <= 0) { throw new ArgumentOutOfRangeException(nameof(maxTypeFields), "MaxTypeFields must be greater than 0."); @@ -127,7 +121,7 @@ public sealed class ForyBuilder private bool? _compatible; private bool _checkStructVersion; private int _maxDepth = 20; - private long _maxGraphMemoryBytes = -1; + private long _maxGraphMemoryBytes = 128L * 1024 * 1024; private int _maxTypeFields = 512; private int _maxTypeMetaBytes = 4096; private int _maxSchemaVersionsPerType = 10; @@ -185,17 +179,10 @@ public ForyBuilder MaxDepth(int value) /// /// Sets the maximum estimated graph memory accepted during one root deserialization. - /// Use -1 for the automatic root-size-based limit, or a positive byte limit. + /// Positive values are byte limits. Explicit non-positive values disable this budget. /// public ForyBuilder MaxGraphMemoryBytes(long value) { - if (value != -1 && value <= 0) - { - throw new ArgumentOutOfRangeException( - nameof(value), - "MaxGraphMemoryBytes must be positive or -1 for auto."); - } - _maxGraphMemoryBytes = value; return this; } diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs index 6591682b7b..79bfa9dfa6 100644 --- a/csharp/src/Fory/Fory.cs +++ b/csharp/src/Fory/Fory.cs @@ -190,7 +190,7 @@ public T Deserialize(ReadOnlySpan payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); - _readContext.InitGraphBudgetKnown(payload.Length); + _readContext.InitGraphBudget(); T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -211,7 +211,7 @@ public T Deserialize(byte[] payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); - _readContext.InitGraphBudgetKnown(payload.Length); + _readContext.InitGraphBudget(); T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -232,7 +232,7 @@ public T Deserialize(ref ReadOnlySequence payload) byte[] bytes = payload.ToArray(); ByteReader reader = _readContext.Reader; reader.Reset(bytes); - _readContext.InitGraphBudgetKnown(bytes.Length); + _readContext.InitGraphBudget(); T value = DeserializeFromReader(reader); payload = payload.Slice(reader.Cursor); return value; diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index 5098fdd532..9cece92aa6 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -22,7 +22,6 @@ namespace Apache.Fory; public sealed class ReadContext { private const int MinRemoteTypeMetaLimit = 8192; - internal const long KnownGraphBudgetSlackBytes = 64 * 1024; private readonly ReusableArray _typeMetaRefs = new(); private readonly UInt64Map _typeMetasByHeader = new(); @@ -76,19 +75,20 @@ public ReadContext( internal RefReader RefReader { get; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void InitGraphBudgetKnown(int rootBytes) + internal void InitGraphBudget() { long limit = _config.MaxGraphMemoryBytes; - if (limit < 0) + if (limit <= 0) { - limit = (long)rootBytes * 8 + KnownGraphBudgetSlackBytes; + _graphMemoryLimitBytes = 0; + _remainingGraphMemoryBytes = long.MaxValue; + return; } _graphMemoryLimitBytes = limit; _remainingGraphMemoryBytes = limit; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] /// /// Reserves estimated graph memory for the current root deserialization. /// @@ -96,40 +96,24 @@ internal void InitGraphBudgetKnown(int rootBytes) /// Serializer owners compute owner-specific formulas and pass raw bytes here. This /// accounting does not replace byte-availability checks before backing allocation. /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] public void ReserveGraphMemory(long bytes) { - long remaining = _remainingGraphMemoryBytes; - if ((ulong)bytes > (ulong)remaining) + if (bytes < 0) { - ThrowGraphBudgetExceeded(bytes, remaining, _graphMemoryLimitBytes); + ThrowGraphBudgetOverflow(); } - - _remainingGraphMemoryBytes = remaining - bytes; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - /// - /// Reserves multiplied by estimated - /// graph-owner bytes for the current root deserialization. - /// - /// - /// This helper owns only overflow-safe arithmetic; concrete serializers and generated - /// serializers still own the collection, array, and map storage formulas. - /// - public void ReserveCountedGraphMemory(int count, long elementBytes) - { - if (count < 0 || elementBytes < 0) + if (_graphMemoryLimitBytes <= 0) { - ThrowGraphBudgetOverflow(); + return; } - - uint length = (uint)count; - if (elementBytes != 0 && length > long.MaxValue / elementBytes) + long remaining = _remainingGraphMemoryBytes; + if (bytes > remaining) { - ThrowGraphBudgetOverflow(); + ThrowGraphBudgetExceeded(bytes, remaining, _graphMemoryLimitBytes); } - ReserveGraphMemory((long)length * elementBytes); + _remainingGraphMemoryBytes = remaining - bytes; } [MethodImpl(MethodImplOptions.NoInlining)] diff --git a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs index fee7ce963c..2971dcf065 100644 --- a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs +++ b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs @@ -85,10 +85,11 @@ public sealed class GraphMemoryBudgetTests private const long BudgetArrayHolderBytes = ObjectBytes + ReferenceBytes; private const long GeneratedGraphHolderBytes = ObjectBytes + ReferenceBytes; private const long BudgetValueBytes = 4; + private const long DefaultGraphMemoryBytes = 128L * 1024 * 1024; private static int ElementBytes() => typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; - private static ForyRuntime NewFory(long maxGraphMemoryBytes = -1) + private static ForyRuntime NewFory(long maxGraphMemoryBytes = DefaultGraphMemoryBytes) { return ForyRuntime.Builder() .Compatible(false) @@ -126,19 +127,24 @@ private static long MapBudget(int count) } [Fact] - public void KnownLengthAutoBudgetUsesInputBytes() + public void DefaultFixedBudgetAndDisable() { - const int rootBytes = 17; - long expected = rootBytes * 8 + ReadContext.KnownGraphBudgetSlackBytes; - ReadContext context = new(new ByteReader([]), new TypeResolver(), NewFory().Config); + Assert.Equal(DefaultGraphMemoryBytes, NewFory().Config.MaxGraphMemoryBytes); + Assert.Equal(0, NewFory(0).Config.MaxGraphMemoryBytes); + Assert.Equal(-2, NewFory(-2).Config.MaxGraphMemoryBytes); - context.InitGraphBudgetKnown(rootBytes); - context.ReserveGraphMemory(expected); + ReadContext context = new(new ByteReader([]), new TypeResolver(), NewFory().Config); + context.InitGraphBudget(); + context.ReserveGraphMemory(DefaultGraphMemoryBytes); Assert.Throws(() => context.ReserveGraphMemory(ReferenceBytes)); + + ReadContext disabled = new(new ByteReader([]), new TypeResolver(), NewFory(0).Config); + disabled.InitGraphBudget(); + disabled.ReserveGraphMemory(long.MaxValue); } [Fact] - public void ReadOnlySequenceUsesKnownLengthRoot() + public void ReadOnlySequenceUsesSameBudget() { const int count = 6; List> value = Enumerable.Range(0, count).Select(_ => new List()).ToList(); @@ -149,7 +155,7 @@ public void ReadOnlySequenceUsesKnownLengthRoot() } [Fact] - public void ExplicitConfigOverridesAutoBudget() + public void ExplicitConfigOverridesDefault() { List value = Enumerable.Range(0, 8).Select(i => new BudgetItem { Id = i }).ToList(); byte[] bytes = Serialize(value); diff --git a/dart/packages/fory/lib/src/config.dart b/dart/packages/fory/lib/src/config.dart index 3fea4e99df..6232285dbe 100644 --- a/dart/packages/fory/lib/src/config.dart +++ b/dart/packages/fory/lib/src/config.dart @@ -28,7 +28,7 @@ final class Config { static const int defaultMaxTypeMetaBytes = 4096; static const int defaultMaxSchemaVersionsPerType = 10; static const int defaultMaxAverageSchemaVersionsPerType = 3; - static const int defaultMaxGraphMemoryBytes = -1; + static const int defaultMaxGraphMemoryBytes = 128 * 1024 * 1024; /// Enables compatible struct encoding and decoding. /// @@ -59,7 +59,7 @@ final class Config { /// Maximum estimated graph memory per root deserialization. /// - /// `-1` means auto. Positive values are explicit byte limits. + /// Positive values are explicit byte limits. Non-positive values disable enforcement. final int maxGraphMemoryBytes; /// Creates an immutable configuration object. @@ -87,9 +87,5 @@ final class Config { assert( maxAverageSchemaVersionsPerType > 0, 'maxAverageSchemaVersionsPerType must be positive', - ), - assert( - maxGraphMemoryBytes == -1 || maxGraphMemoryBytes > 0, - 'maxGraphMemoryBytes must be -1 or positive', ); } diff --git a/dart/packages/fory/lib/src/context/read_context.dart b/dart/packages/fory/lib/src/context/read_context.dart index a47b382c33..a6b78262af 100644 --- a/dart/packages/fory/lib/src/context/read_context.dart +++ b/dart/packages/fory/lib/src/context/read_context.dart @@ -45,8 +45,6 @@ import 'package:fory/src/types/uint64.dart'; /// deserialization operation. Application code normally interacts with [Fory] /// instead of preparing contexts directly. final class ReadContext { - static const int _knownRootBudgetMultiplier = 8; - static const int _knownRootBudgetSlackBytes = 64 * 1024; static const int _maxSafeBudgetBytes = 9007199254740991; /// Effective runtime configuration for the active operation. @@ -74,11 +72,7 @@ final class ReadContext { void prepare(Buffer buffer) { _buffer = buffer; final configured = config.maxGraphMemoryBytes; - final limit = - configured > 0 - ? configured - : buffer.readableBytes * _knownRootBudgetMultiplier + - _knownRootBudgetSlackBytes; + final limit = configured > 0 ? configured : 0; if (limit > _maxSafeBudgetBytes) { _throwGraphMemoryOverflow(limit); } @@ -117,6 +111,9 @@ final class ReadContext { if (bytes < 0 || bytes > _maxSafeBudgetBytes) { _throwGraphMemoryOverflow(bytes); } + if (_effectiveGraphMemoryBytes <= 0) { + return; + } final remaining = _remainingGraphMemoryBytes - bytes; if (remaining < 0) { _throwGraphMemoryExceeded(bytes); diff --git a/dart/packages/fory/lib/src/fory.dart b/dart/packages/fory/lib/src/fory.dart index b39f3a0db1..bebc7f7896 100644 --- a/dart/packages/fory/lib/src/fory.dart +++ b/dart/packages/fory/lib/src/fory.dart @@ -64,14 +64,6 @@ final class Fory { Config.defaultMaxAverageSchemaVersionsPerType, int maxGraphMemoryBytes = Config.defaultMaxGraphMemoryBytes, }) { - if (maxGraphMemoryBytes != Config.defaultMaxGraphMemoryBytes && - maxGraphMemoryBytes <= 0) { - throw ArgumentError.value( - maxGraphMemoryBytes, - 'maxGraphMemoryBytes', - 'must be -1 or positive', - ); - } final config = Config( compatible: compatible, checkStructVersion: checkStructVersion, diff --git a/dart/packages/fory/test/graph_memory_budget_test.dart b/dart/packages/fory/test/graph_memory_budget_test.dart index 2cb2269978..60318b5c1b 100644 --- a/dart/packages/fory/test/graph_memory_budget_test.dart +++ b/dart/packages/fory/test/graph_memory_budget_test.dart @@ -30,6 +30,7 @@ import 'package:test/test.dart'; part 'graph_memory_budget_test.fory.dart'; const Matcher _throwsGraphBudget = ThrowsGraphBudget(); +const int _defaultGraphMemoryBytes = 128 * 1024 * 1024; const int _objectBytes = 1; const int _referenceBytes = 4; @@ -113,7 +114,10 @@ void _registerCompatibleArray(Fory fory) { ); } -ReadContext _readContext(Buffer buffer, {int maxGraphMemoryBytes = -1}) { +ReadContext _readContext( + Buffer buffer, { + int maxGraphMemoryBytes = _defaultGraphMemoryBytes, +}) { final config = Config(maxGraphMemoryBytes: maxGraphMemoryBytes); final resolver = TypeResolver(config); return ReadContext(config, resolver, RefReader(), MetaStringReader(resolver)) @@ -130,27 +134,36 @@ Object? _readWithBudget(Object? value, int budget) { void main() { group('graph memory budget', () { - test('known length auto derives from input bytes', () { + test('fixed default applies to roots', () { final buffer = Buffer.wrap(Uint8List(17)); final context = _readContext(buffer); - expect(context.effectiveGraphMemoryBytes, equals(17 * 8 + 64 * 1024)); expect( - () => context.reserveGraphMemory(17 * 8 + 64 * 1024), + context.effectiveGraphMemoryBytes, + equals(_defaultGraphMemoryBytes), + ); + expect( + () => context.reserveGraphMemory(_defaultGraphMemoryBytes), returnsNormally, ); expect(() => context.reserveGraphMemory(1), _throwsGraphBudget); }); - test('explicit config overrides auto', () { + test('explicit config overrides default and non-positive disables', () { final buffer = Buffer.wrap(Uint8List(4096)); final context = _readContext(buffer, maxGraphMemoryBytes: 31); expect(context.effectiveGraphMemoryBytes, equals(31)); expect(() => context.reserveGraphMemory(31), returnsNormally); expect(() => context.reserveGraphMemory(1), _throwsGraphBudget); - expect(() => Fory(maxGraphMemoryBytes: 0), throwsArgumentError); - expect(() => Fory(maxGraphMemoryBytes: -2), throwsArgumentError); + + final disabled = _readContext(buffer, maxGraphMemoryBytes: 0); + expect(disabled.effectiveGraphMemoryBytes, equals(0)); + expect( + () => disabled.reserveGraphMemory(_defaultGraphMemoryBytes + 1), + returnsNormally, + ); + expect(() => Fory(maxGraphMemoryBytes: -2), returnsNormally); }); test('uses parent storage for nested empty containers', () { diff --git a/docs/guide/cpp/configuration.md b/docs/guide/cpp/configuration.md index bd13ff7e27..93900070e8 100644 --- a/docs/guide/cpp/configuration.md +++ b/docs/guide/cpp/configuration.md @@ -107,10 +107,10 @@ auto fory = Fory::builder() .build(); ``` -Use `-1` for the automatic limit. For byte-array and `Buffer` roots, the -automatic limit is the root input size multiplied by `8`, plus `64 KiB`. For -stream roots, the automatic limit is `128 MiB` because the full root size is not -known up front. Positive values always override the automatic limit. +The default limit is a fixed `128 MiB` for byte-array, `Buffer`, and stream +roots. Positive values override the default. Explicit non-positive values +disable this budget and can expose deserialization DoS risk from compact inputs +that materialize large object graphs. This budget is a portable lower-bound estimate for shallow materialized graph owners such as dynamic collection backing storage, map key/value storage, @@ -227,18 +227,18 @@ auto fory = Fory::builder().build_thread_safe(); // Returns ThreadSafeFory ## Configuration Summary -| Option | Description | Default | -| ------------------------------------------------ | ------------------------------------------------- | ------- | -| `xlang(bool)` | Use xlang mode | `true` | -| `compatible(bool)` | Enable schema evolution | `true` | -| `track_ref(bool)` | Enable reference tracking | `true` | -| `max_graph_memory_bytes(int64_t)` | Max estimated graph memory per root read | `-1` | -| `max_dyn_depth(uint32_t)` | Maximum nesting depth for dynamic types | `5` | -| `max_type_fields(uint32_t)` | Max fields in one received struct metadata body | `512` | -| `max_type_meta_bytes(uint32_t)` | Max encoded bytes in one received metadata body | `4096` | -| `max_schema_versions_per_type(uint32_t)` | Max remote metadata versions for one logical type | `10` | -| `max_average_schema_versions_per_type(uint32_t)` | Average remote metadata versions across types | `3` | -| `check_struct_version(bool)` | Enable struct version checking | `false` | +| Option | Description | Default | +| ------------------------------------------------ | ------------------------------------------------- | --------- | +| `xlang(bool)` | Use xlang mode | `true` | +| `compatible(bool)` | Enable schema evolution | `true` | +| `track_ref(bool)` | Enable reference tracking | `true` | +| `max_graph_memory_bytes(int64_t)` | Max estimated graph memory per root read | `128 MiB` | +| `max_dyn_depth(uint32_t)` | Maximum nesting depth for dynamic types | `5` | +| `max_type_fields(uint32_t)` | Max fields in one received struct metadata body | `512` | +| `max_type_meta_bytes(uint32_t)` | Max encoded bytes in one received metadata body | `4096` | +| `max_schema_versions_per_type(uint32_t)` | Max remote metadata versions for one logical type | `10` | +| `max_average_schema_versions_per_type(uint32_t)` | Average remote metadata versions across types | `3` | +| `check_struct_version(bool)` | Enable struct version checking | `false` | ## Security diff --git a/docs/guide/csharp/configuration.md b/docs/guide/csharp/configuration.md index 5c045316f5..d08c622791 100644 --- a/docs/guide/csharp/configuration.md +++ b/docs/guide/csharp/configuration.md @@ -35,17 +35,17 @@ ThreadSafeFory threadSafe = Fory.Builder().BuildThreadSafe(); `Fory.Builder().Build()` uses: -| Option | Default | Description | -| --------------------------------- | ------- | ------------------------------------------------- | -| `TrackRef` | `false` | Reference tracking disabled | -| `Compatible` | `true` | Compatible schema-evolution metadata enabled | -| `CheckStructVersion` | `false` | Struct schema hash checks disabled | -| `MaxDepth` | `20` | Max dynamic nesting depth | -| `MaxGraphMemoryBytes` | `-1` | Auto graph memory budget | -| `MaxTypeFields` | `512` | Max fields in one received struct metadata body | -| `MaxTypeMetaBytes` | `4096` | Max encoded bytes in one received metadata body | -| `MaxSchemaVersionsPerType` | `10` | Max remote metadata versions for one logical type | -| `MaxAverageSchemaVersionsPerType` | `3` | Average remote metadata versions across types | +| Option | Default | Description | +| --------------------------------- | ----------- | ------------------------------------------------- | +| `TrackRef` | `false` | Reference tracking disabled | +| `Compatible` | `true` | Compatible schema-evolution metadata enabled | +| `CheckStructVersion` | `false` | Struct schema hash checks disabled | +| `MaxDepth` | `20` | Max dynamic nesting depth | +| `MaxGraphMemoryBytes` | `134217728` | Graph memory budget | +| `MaxTypeFields` | `512` | Max fields in one received struct metadata body | +| `MaxTypeMetaBytes` | `4096` | Max encoded bytes in one received metadata body | +| `MaxSchemaVersionsPerType` | `10` | Max remote metadata versions for one logical type | +| `MaxAverageSchemaVersionsPerType` | `3` | Average remote metadata versions across types | ## Builder Options @@ -107,9 +107,9 @@ Fory fory = Fory.Builder() .Build(); ``` -Use `-1` for the default automatic limit. For current C# inputs, auto uses the root input byte -length times `8`, plus `64 KiB`. A positive value overrides the automatic limit. `0` and negative -values other than `-1` are rejected. +The default limit is a fixed `128 MiB` for all root input forms. A positive value overrides the +default. Passing an explicit non-positive value disables this budget and can expose deserialization +DoS risk from compact inputs that materialize large object graphs. ### `MaxTypeFields(int value)` diff --git a/docs/guide/dart/configuration.md b/docs/guide/dart/configuration.md index 2c8b6baea3..d832c28d7b 100644 --- a/docs/guide/dart/configuration.md +++ b/docs/guide/dart/configuration.md @@ -115,12 +115,7 @@ Dart lists, sets, maps, object/reference arrays, structs, objects, and compatibl materialization. It does not count strings, binary values, or dense typed-array payloads, which are protected by byte-availability checks. -The default is `-1`, which means auto. Dart root inputs are memory-backed, so auto derives from the -root input size: - -```text -inputBytes * 8 + 64 KiB -``` +The default is a fixed `128 MiB` and is not derived from input size. Set a positive value when a trusted workload legitimately contains compact, container-heavy payloads: @@ -129,18 +124,21 @@ payloads: final fory = Fory(maxGraphMemoryBytes: 256 * 1024 * 1024); ``` +Passing an explicit non-positive value disables this budget and can expose deserialization DoS risk +from compact inputs that materialize large object graphs. + ## Defaults -| Option | Default | -| --------------------------------- | ------- | -| `compatible` | `true` | -| `checkStructVersion` | `false` | -| `maxDepth` | 256 | -| `maxTypeFields` | 512 | -| `maxTypeMetaBytes` | 4096 | -| `maxSchemaVersionsPerType` | 10 | -| `maxAverageSchemaVersionsPerType` | 3 | -| `maxGraphMemoryBytes` | -1 | +| Option | Default | +| --------------------------------- | --------- | +| `compatible` | `true` | +| `checkStructVersion` | `false` | +| `maxDepth` | 256 | +| `maxTypeFields` | 512 | +| `maxTypeMetaBytes` | 4096 | +| `maxSchemaVersionsPerType` | 10 | +| `maxAverageSchemaVersionsPerType` | 3 | +| `maxGraphMemoryBytes` | 134217728 | ## Xlang Notes @@ -157,8 +155,8 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkStructVersion: true` with `compatible: false` for intentional same-schema payloads. - Set `maxDepth` to reject unexpectedly deep payload shapes. -- Keep `maxGraphMemoryBytes` at the auto default for most inputs, or set an explicit positive byte - limit for known trusted graph-heavy payloads. +- Keep `maxGraphMemoryBytes` at the default for most inputs, or set an explicit positive byte limit + for known trusted graph-heavy payloads. Avoid disabling it for untrusted data. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer generated schemas and explicit field metadata over broad dynamic fields for untrusted input. diff --git a/docs/guide/go/configuration.md b/docs/guide/go/configuration.md index c260312c29..af21f22c29 100644 --- a/docs/guide/go/configuration.md +++ b/docs/guide/go/configuration.md @@ -33,17 +33,17 @@ f := fory.New(fory.WithXlang(true)) Default settings: -| Option | Default | Description | -| ------------------------------- | ------- | ------------------------------------------------- | -| TrackRef | false | Reference tracking disabled | -| MaxDepth | 20 | Maximum nesting depth | -| IsXlang | true | Xlang mode enabled | -| Compatible | true | Compatible schema-evolution metadata enabled | -| MaxGraphMemoryBytes | -1 | Automatic graph memory limit per root read | -| MaxTypeFields | 512 | Max fields in one received struct metadata body | -| MaxTypeMetaBytes | 4096 | Max encoded bytes in one received metadata body | -| MaxSchemaVersionsPerType | 10 | Max remote metadata versions for one logical type | -| MaxAverageSchemaVersionsPerType | 3 | Average remote metadata versions across types | +| Option | Default | Description | +| ------------------------------- | --------- | ------------------------------------------------- | +| TrackRef | false | Reference tracking disabled | +| MaxDepth | 20 | Maximum nesting depth | +| IsXlang | true | Xlang mode enabled | +| Compatible | true | Compatible schema-evolution metadata enabled | +| MaxGraphMemoryBytes | 134217728 | Graph memory limit per root read | +| MaxTypeFields | 512 | Max fields in one received struct metadata body | +| MaxTypeMetaBytes | 4096 | Max encoded bytes in one received metadata body | +| MaxSchemaVersionsPerType | 10 | Max remote metadata versions for one logical type | +| MaxAverageSchemaVersionsPerType | 3 | Average remote metadata versions across types | ### With Options @@ -137,19 +137,14 @@ Limit estimated shallow graph memory accepted during one root deserialization: f := fory.New(fory.WithMaxGraphMemoryBytes(256 * 1024 * 1024)) ``` -The default `-1` selects an automatic limit. Byte-slice roots use: - -```text -inputBytes * 8 + 64 KiB -``` - -`DeserializeFromReader` and `DeserializeFromStream` use `128 MiB` because the -full root length is unknown. The budget covers lower-bound slice backing -storage, map key/value storage, sets, generated object reads, and materialized -struct field storage. Strings, binary blobs, and primitive dense array owners -keep their byte-availability checks and are not reserved against this budget. -Set a positive value when a service needs a stricter or larger limit for trusted -data. +The default limit is a fixed `128 MiB` for byte-slice and stream roots. A +positive value overrides the default. Passing an explicit non-positive value +disables this budget and can expose deserialization DoS risk from compact inputs +that materialize large object graphs. The budget covers lower-bound slice +backing storage, map key/value storage, sets, generated object reads, and +materialized struct field storage. Strings, binary blobs, and primitive dense +array owners keep their byte-availability checks and are not reserved against +this budget. ### WithMaxTypeFields diff --git a/docs/guide/java/configuration.md b/docs/guide/java/configuration.md index e5bf5d461e..4aef1c5ba9 100644 --- a/docs/guide/java/configuration.md +++ b/docs/guide/java/configuration.md @@ -38,7 +38,7 @@ This page documents all configuration options available through `ForyBuilder`. | `registerGuavaTypes` | Whether to pre-register Guava types such as `RegularImmutableMap`/`RegularImmutableList`. These types are not public API, but seem pretty stable. | `true` | | `requireClassRegistration` | Disabling may allow unknown classes to be deserialized, potentially causing security risks. | `true` | | `maxDepth` | Set max depth for deserialization, when depth exceeds, an exception will be thrown. This can be used to refuse deserialization DDOS attack. | `50` | -| `maxGraphMemoryBytes` | Maximum estimated shallow graph memory accepted during one root deserialization. `-1` derives an automatic limit from the input shape: known-length inputs use `inputBytes * 8 + 64 KiB`, and stream or unknown-length inputs use `128 MiB`. Positive values set an explicit byte limit. | `-1` | +| `maxGraphMemoryBytes` | Maximum estimated shallow graph memory accepted during one root deserialization. The default is a fixed `128 MiB`; positive values set an explicit byte limit. Explicit non-positive values disable this budget and can expose deserialization DoS risk from compact inputs that materialize large object graphs. | `134217728` | | `maxTypeFields` | Maximum fields accepted in one received remote struct metadata body. | `512` | | `maxTypeMetaBytes` | Maximum encoded body bytes accepted for one received TypeDef or TypeMeta body, excluding the 8-byte header and any extended-size varint. | `4096` | | `maxSchemaVersionsPerType` | Maximum accepted remote metadata versions for one logical type. | `10` | @@ -100,8 +100,10 @@ Security-related options: - `requireClassRegistration(true)` restricts deserialization to registered classes. - `withMaxDepth(...)` rejects unexpectedly deep object graphs. - `withMaxGraphMemoryBytes(...)` bounds estimated shallow graph memory during one root - deserialization. Keep `-1` for the automatic input-shaped default, or set a positive byte limit - when trusted payloads need a larger or smaller limit. + deserialization. The default is a fixed `128 MiB`; set a positive byte limit when trusted + workloads need a larger or smaller limit. Passing an explicit non-positive value disables this + budget and can expose deserialization DoS risk from compact inputs that materialize large object + graphs. - `withMaxTypeFields(...)` and `withMaxTypeMetaBytes(...)` bound the field count and encoded body size of one received remote metadata body. - `withMaxSchemaVersionsPerType(...)` and diff --git a/docs/guide/javascript/configuration.md b/docs/guide/javascript/configuration.md index 0db0f5d7ab..6b02d50828 100644 --- a/docs/guide/javascript/configuration.md +++ b/docs/guide/javascript/configuration.md @@ -43,7 +43,7 @@ const fory = new Fory({ ref: true, compatible: true, maxDepth: 100, - maxGraphMemoryBytes: -1, + maxGraphMemoryBytes: 128 * 1024 * 1024, maxTypeFields: 512, maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, @@ -52,19 +52,19 @@ const fory = new Fory({ }); ``` -| Option | Default | Description | -| --------------------------------- | ------- | ------------------------------------------------------------------------------------- | -| `ref` | `false` | Enable reference tracking for shared or circular object graphs | -| `compatible` | `true` | Allow field additions/removals without breaking existing messages | -| `maxDepth` | `50` | Maximum nesting depth. Must be `>= 2`. Increase for deeply nested structures | -| `maxGraphMemoryBytes` | `-1` | Maximum estimated shallow graph memory accepted during one root deserialization | -| `maxTypeFields` | `512` | Maximum fields accepted in one received remote struct metadata body | -| `maxTypeMetaBytes` | `4096` | Maximum encoded body bytes accepted for one received TypeMeta body | -| `maxSchemaVersionsPerType` | `10` | Maximum accepted remote metadata versions for one logical type | -| `maxAverageSchemaVersionsPerType` | `3` | Average accepted remote metadata versions across accepted remote types | -| `useSliceString` | `false` | Optional string-reading optimization for Node.js. Leave at default unless benchmarked | -| `hps` | unset | Optional fast string helper from `@apache-fory/hps` (Node.js 20+) | -| `hooks.afterCodeGenerated` | unset | Callback to inspect the generated serializer code, useful for debugging | +| Option | Default | Description | +| --------------------------------- | --------- | ------------------------------------------------------------------------------------- | +| `ref` | `false` | Enable reference tracking for shared or circular object graphs | +| `compatible` | `true` | Allow field additions/removals without breaking existing messages | +| `maxDepth` | `50` | Maximum nesting depth. Must be `>= 2`. Increase for deeply nested structures | +| `maxGraphMemoryBytes` | `128 MiB` | Maximum estimated shallow graph memory accepted during one root deserialization | +| `maxTypeFields` | `512` | Maximum fields accepted in one received remote struct metadata body | +| `maxTypeMetaBytes` | `4096` | Maximum encoded body bytes accepted for one received TypeMeta body | +| `maxSchemaVersionsPerType` | `10` | Maximum accepted remote metadata versions for one logical type | +| `maxAverageSchemaVersionsPerType` | `3` | Average accepted remote metadata versions across accepted remote types | +| `useSliceString` | `false` | Optional string-reading optimization for Node.js. Leave at default unless benchmarked | +| `hps` | unset | Optional fast string helper from `@apache-fory/hps` (Node.js 20+) | +| `hooks.afterCodeGenerated` | unset | Callback to inspect the generated serializer code, useful for debugging | ## Reference Tracking @@ -99,10 +99,7 @@ generated from Fory schema IDL. See [Schema Evolution](schema-evolution.md). `maxGraphMemoryBytes` limits estimated shallow graph memory accepted during one root deserialization. The budget covers materialized arrays, sets, object arrays, maps, structs, and objects; it is not an exact JavaScript heap limit. -The default `-1` derives an automatic limit from the input bytes. JavaScript -deserializes from `Uint8Array` roots, so the automatic limit is `inputBytes \* 8 - -- 64 KiB`. +The default is a fixed `128 MiB` and is not derived from input size. Use a positive byte value to set an explicit lower or higher limit: @@ -112,6 +109,10 @@ const fory = new Fory({ }); ``` +Passing an explicit non-positive value disables this budget and can expose +deserialization DoS risk from compact inputs that materialize large object +graphs. + String, binary, and dedicated dense primitive array payloads keep their normal byte-size checks and do not consume this graph budget. Raise the limit only for trusted workloads that legitimately contain very compact object graphs. diff --git a/docs/guide/python/configuration.md b/docs/guide/python/configuration.md index 3acd0b12e8..601daab5af 100644 --- a/docs/guide/python/configuration.md +++ b/docs/guide/python/configuration.md @@ -60,22 +60,22 @@ class ThreadSafeFory: ## Parameters -| Parameter | Type | Default | Description | -| -------------------------------------- | ------------------------------- | ------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `xlang` | `bool` | `True` | Use xlang mode. Set `False` for Python native mode. | -| `ref` | `bool` | `False` | Enable reference tracking for shared/circular references. Disable for better performance if your data has no shared references. | -| `strict` | `bool` | `True` | Require type registration for security. Keep this enabled for production unless a policy owns trust decisions. | -| `compatible` | `bool \| None` | `None` | Schema evolution mode. `None` enables compatible mode in both xlang and native mode. Set `False` only when every reader and writer uses the same schema. | -| `max_depth` | `int` | `50` | Maximum deserialization depth for security, preventing stack overflow attacks. | -| `max_type_fields` | `int` | `512` | Maximum fields accepted in one received remote struct metadata body. | -| `max_type_meta_bytes` | `int` | `4096` | Maximum encoded body bytes accepted for one received TypeDef body, excluding the 8-byte header and any extended-size varint. | -| `max_schema_versions_per_type` | `int` | `10` | Maximum accepted remote metadata versions for one logical type. | -| `max_average_schema_versions_per_type` | `int` | `3` | Average accepted remote metadata versions across accepted remote types. The effective global floor is `8192` schemas. | -| `max_graph_memory_bytes` | `int` | `-1` | Maximum estimated shallow graph memory for one root deserialization. `-1` selects the automatic limit. | -| `policy` | `DeserializationPolicy \| None` | `None` | Deserialization policy used for security checks. Strongly recommended when `strict=False`. | -| `field_nullable` | `bool` | `False` | Treat dataclass fields as nullable by default. | -| `meta_compressor` | `Any` | `None` | Optional metadata compressor used for compatible-mode metadata encoding. | -| `fory_factory` | `Callable \| None` | `None` | `ThreadSafeFory` factory hook. When set, `ThreadSafeFory` creates instances via this callback; otherwise it forwards `**kwargs` to `Fory` construction. | +| Parameter | Type | Default | Description | +| -------------------------------------- | ------------------------------- | ----------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `xlang` | `bool` | `True` | Use xlang mode. Set `False` for Python native mode. | +| `ref` | `bool` | `False` | Enable reference tracking for shared/circular references. Disable for better performance if your data has no shared references. | +| `strict` | `bool` | `True` | Require type registration for security. Keep this enabled for production unless a policy owns trust decisions. | +| `compatible` | `bool \| None` | `None` | Schema evolution mode. `None` enables compatible mode in both xlang and native mode. Set `False` only when every reader and writer uses the same schema. | +| `max_depth` | `int` | `50` | Maximum deserialization depth for security, preventing stack overflow attacks. | +| `max_type_fields` | `int` | `512` | Maximum fields accepted in one received remote struct metadata body. | +| `max_type_meta_bytes` | `int` | `4096` | Maximum encoded body bytes accepted for one received TypeDef body, excluding the 8-byte header and any extended-size varint. | +| `max_schema_versions_per_type` | `int` | `10` | Maximum accepted remote metadata versions for one logical type. | +| `max_average_schema_versions_per_type` | `int` | `3` | Average accepted remote metadata versions across accepted remote types. The effective global floor is `8192` schemas. | +| `max_graph_memory_bytes` | `int` | `134217728` | Maximum estimated shallow graph memory for one root deserialization. Explicit non-positive values disable this budget. | +| `policy` | `DeserializationPolicy \| None` | `None` | Deserialization policy used for security checks. Strongly recommended when `strict=False`. | +| `field_nullable` | `bool` | `False` | Treat dataclass fields as nullable by default. | +| `meta_compressor` | `Any` | `None` | Optional metadata compressor used for compatible-mode metadata encoding. | +| `fory_factory` | `Callable \| None` | `None` | `ThreadSafeFory` factory hook. When set, `ThreadSafeFory` creates instances via this callback; otherwise it forwards `**kwargs` to `Fory` construction. | ## Key Methods @@ -227,9 +227,10 @@ Received remote metadata is also limited: - `max_average_schema_versions_per_type` limits the average across accepted remote types. - `max_graph_memory_bytes` limits estimated shallow graph memory created during one root deserialization, including materialized lists, tuples, sets, dicts, object arrays, structs, and - Python objects. The default `-1` uses `input_bytes * 8 + 64 KiB` for known-length inputs and - `128 MiB` for stream inputs. Set a positive byte value for trusted payloads that legitimately - contain larger object graphs. + Python objects. The default is a fixed `128 MiB` for all root input forms. Set a positive byte + value for trusted payloads that legitimately contain larger or smaller object graphs. Passing an + explicit non-positive value disables this budget and can expose deserialization DoS risk from + compact inputs that materialize large object graphs. These limits do not change `strict`, `policy`, dynamic loading, unknown-class handling, or schema-evolution semantics. diff --git a/docs/guide/rust/configuration.md b/docs/guide/rust/configuration.md index 249af3ee71..2a14d4cace 100644 --- a/docs/guide/rust/configuration.md +++ b/docs/guide/rust/configuration.md @@ -114,20 +114,9 @@ let fory = Fory::builder() `max_graph_memory_bytes(...)` limits estimated shallow graph memory accepted during one root read. The budget covers `Vec`/collection element storage, map key/value storage, and materialized struct -or object field storage; it is not an exact process heap limit. The default is `-1`, which selects -an automatic limit based on the input size: - -```rust -let fory = Fory::builder().max_graph_memory_bytes(-1).build(); -``` - -For byte-slice and `Reader` roots, the automatic limit is: - -```text -input bytes * 8 + 64 KiB -``` - -Set a positive byte value when trusted payloads need a larger or smaller limit: +or object field storage; it is not an exact process heap limit. The default is a fixed `128 MiB` for +all root input forms. Set a positive byte value when trusted payloads need a larger or smaller +limit: ```rust let fory = Fory::builder() @@ -135,6 +124,9 @@ let fory = Fory::builder() .build(); ``` +Passing an explicit non-positive value disables this budget and can expose deserialization DoS risk +from compact inputs that materialize large object graphs. + ### Explicit Xlang Examples Set `.xlang(true)` explicitly for xlang serialization examples: @@ -174,16 +166,16 @@ let fory = Fory::builder() ## Configuration Summary -| Option | Description | Default | -| --------------------------------------------- | ------------------------------------------------- | ------- | -| `compatible(bool)` | Enable schema evolution | `true` | -| `xlang(bool)` | Use xlang mode | `true` | -| `max_dyn_depth(u32)` | Maximum nesting depth for dynamic types | `5` | -| `max_graph_memory_bytes(i64)` | Estimated graph memory per root read | `-1` | -| `max_type_fields(usize)` | Max fields in one received struct metadata body | `512` | -| `max_type_meta_bytes(usize)` | Max encoded bytes in one received metadata body | `4096` | -| `max_schema_versions_per_type(usize)` | Max remote metadata versions for one logical type | `10` | -| `max_average_schema_versions_per_type(usize)` | Average remote metadata versions across types | `3` | +| Option | Description | Default | +| --------------------------------------------- | ------------------------------------------------- | --------- | +| `compatible(bool)` | Enable schema evolution | `true` | +| `xlang(bool)` | Use xlang mode | `true` | +| `max_dyn_depth(u32)` | Maximum nesting depth for dynamic types | `5` | +| `max_graph_memory_bytes(i64)` | Estimated graph memory per root read | `128 MiB` | +| `max_type_fields(usize)` | Max fields in one received struct metadata body | `512` | +| `max_type_meta_bytes(usize)` | Max encoded bytes in one received metadata body | `4096` | +| `max_schema_versions_per_type(usize)` | Max remote metadata versions for one logical type | `10` | +| `max_average_schema_versions_per_type(usize)` | Average remote metadata versions across types | `3` | ## Compatible Mode diff --git a/docs/guide/swift/configuration.md b/docs/guide/swift/configuration.md index 9086b4cf3c..75fd8d281b 100644 --- a/docs/guide/swift/configuration.md +++ b/docs/guide/swift/configuration.md @@ -94,9 +94,9 @@ let fory = Fory(compatible: false, checkClassVersion: true) `maxDepth` bounds decoded payload nesting depth. `maxGraphMemoryBytes` bounds estimated shallow graph memory accepted during one root -deserialization. Swift roots are currently `Data` or `ByteBuffer`, so auto uses the root input byte -length times `8`, plus `64 KiB`. Use `-1` for the default automatic limit; a positive value -overrides it. `0` and negative values other than `-1` are rejected. +deserialization. The default limit is a fixed `128 MiB` for all root input forms. A positive value +overrides the default. Passing an explicit non-positive value disables this budget and can expose +deserialization DoS risk from compact inputs that materialize large object graphs. Compatible-mode remote metadata is also limited: @@ -111,7 +111,7 @@ Compatible-mode remote metadata is also limited: ```swift let fory = Fory( maxDepth: 5, - maxGraphMemoryBytes: -1, + maxGraphMemoryBytes: 128 * 1024 * 1024, maxTypeFields: 512, maxTypeMetaBytes: 4096, maxSchemaVersionsPerType: 10, diff --git a/docs/security/deserialization.md b/docs/security/deserialization.md index 618d7e0bc5..2ec48856bf 100644 --- a/docs/security/deserialization.md +++ b/docs/security/deserialization.md @@ -211,17 +211,18 @@ Runtimes should enforce a root-deserialization budget for estimated shallow memo materialized graph. This is cumulative accounting for graph owners created by one root read; it is not exact heap measurement and it is not a raw element-slot limit. -The public configuration is `maxGraphMemoryBytes`. `-1` means automatic input-shaped budgeting. -Positive user configuration always wins. For known-length root input, the automatic budget is -`inputBytes * 8 + 64 KiB`. For true stream or otherwise unknown-length root input, the automatic -budget is fixed at `128 MiB`. Stream budgeting should not depend on dynamic bytes-read accounting. +The public configuration is `maxGraphMemoryBytes`. The default is a fixed `128 MiB` for all root +input forms; positive user configuration overrides the default. Explicit non-positive configuration +disables this budget and can expose deserialization DoS risk from compact inputs that materialize +large object graphs. The budget is not derived from root input size, and stream budgeting should not +depend on dynamic bytes-read accounting. Graph budget accounting should: - happen in root-operation read state, with cleanup owned by the root deserialization `finally`; -- keep read context/read state limited to raw byte reservation and generic counted-byte arithmetic; - collection, map, array, struct, and object storage formulas belong in the concrete serializer or - generated serializer owner; +- keep read context/read state limited to raw byte reservation; counted arithmetic and collection, + map, array, struct, and object storage formulas belong in the concrete serializer or generated + serializer owner; - reject arithmetic overflow before comparing budget or allocating; - estimate lower-bound shallow owner storage: independently materialized collections, maps, sets, and reference arrays reserve nonzero shallow self cost plus backing/reference/inline storage, and diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index 3e5076d0db..4d6c9e867c 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -401,17 +401,17 @@ real owner invariant. Materializing readers should also reserve a root-operation estimated graph memory budget before allocation or size hinting. The budget belongs to `ReadContext` or the equivalent root read state, not to serializers and not to -ambient thread-local state. Positive `maxGraphMemoryBytes` configuration wins; -auto configuration uses `inputBytes * 8 + 64 KiB` for known-length root input -and fixed `128 MiB` for true stream or unknown-length root input. Do not add -dynamic stream bytes-read accounting for this budget. - -Read context or equivalent read state owns only raw byte accounting and generic -counted-byte arithmetic, such as reserving `bytes` or `count * elementBytes` -with overflow checks. It must not expose collection, map, array, struct, or +ambient thread-local state. `maxGraphMemoryBytes` defaults to a fixed `128 MiB`; +positive configuration overrides the default; explicit non-positive +configuration disables graph-memory enforcement. Do not derive this budget from +root input size, and do not add dynamic stream bytes-read accounting for this +budget. + +Read context or equivalent read state owns only raw byte reservation. It must +not expose counted arithmetic helpers or collection, map, array, struct, or object semantic reservation APIs. Concrete serializers and generated serializer owners compute the storage constants and formulas for the owner path they -allocate. +allocate, including counted-byte overflow checks. The budget estimates lower-bound shallow memory for materialized graph owners, not exact heap bytes. Reserve self storage exactly once at the owner that stores diff --git a/go/fory/array.go b/go/fory/array.go index 4a7921cba7..4a6b3d2588 100644 --- a/go/fory/array.go +++ b/go/fory/array.go @@ -318,7 +318,12 @@ func (s arrayDynSerializer) Write(ctx *WriteContext, refMode RefMode, writeType func (s arrayDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { // Create a temp slice to read into, then copy back to array sliceType := reflect.SliceOf(value.Type().Elem()) - if !ctx.ReserveCountedGraphMemory(value.Len(), int64(value.Type().Elem().Size())) { + elemBytes := int64(value.Type().Elem().Size()) + if int64(value.Len()) > maxGraphCount(elemBytes) { + ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", value.Len(), elemBytes) + return + } + if !ctx.ReserveGraphMemory(int64(value.Len()) * elemBytes) { return } tempSlice := reflect.MakeSlice(sliceType, value.Len(), value.Len()) diff --git a/go/fory/codegen/decoder.go b/go/fory/codegen/decoder.go index 16e53f8a7f..931e36b614 100644 --- a/go/fory/codegen/decoder.go +++ b/go/fory/codegen/decoder.go @@ -25,6 +25,28 @@ import ( "github.com/apache/fory/go/fory" ) +func writeGraphReservation(buf *bytes.Buffer, indent, countExpr, elemBytesExpr string) { + fmt.Fprintf(buf, "%s{\n", indent) + fmt.Fprintf(buf, "%s\tgraphCount := %s\n", indent, countExpr) + fmt.Fprintf(buf, "%s\tgraphElemBytes := int64(%s)\n", indent, elemBytesExpr) + fmt.Fprintf(buf, "%s\tif graphCount < 0 {\n", indent) + fmt.Fprintf(buf, "%s\t\tctx.SetError(fory.DeserializationErrorf(\"negative graph element count: %%d\", graphCount))\n", indent) + fmt.Fprintf(buf, "%s\t\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s\tif graphElemBytes < 0 {\n", indent) + fmt.Fprintf(buf, "%s\t\tctx.SetError(fory.DeserializationErrorf(\"negative graph element size: %%d\", graphElemBytes))\n", indent) + fmt.Fprintf(buf, "%s\t\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s\tif graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes {\n", indent) + fmt.Fprintf(buf, "%s\t\tctx.SetError(fory.DeserializationErrorf(\"graph memory estimate overflows: length=%%d elementBytes=%%d\", graphCount, graphElemBytes))\n", indent) + fmt.Fprintf(buf, "%s\t\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s\tif !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) {\n", indent) + fmt.Fprintf(buf, "%s\t\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) +} + // generateReadTyped generates the strongly-typed ReadData method func generateReadTyped(buf *bytes.Buffer, s *StructInfo) error { fmt.Fprintf(buf, "// ReadTyped provides strongly-typed deserialization with no reflection overhead\n") @@ -178,9 +200,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveCountedGraphMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) - fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") - fmt.Fprintf(buf, "\t\t\t}\n") + writeGraphReservation(buf, "\t\t\t", "sliceLen", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make([]any, 0)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -209,9 +229,7 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveCountedGraphMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) - fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") - fmt.Fprintf(buf, "\t\t\t\t}\n") + writeGraphReservation(buf, "\t\t\t\t", "sliceLen", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make([]any, 0)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -516,9 +534,7 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveCountedGraphMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) - fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") - fmt.Fprintf(buf, "\t\t\t}\n") + writeGraphReservation(buf, "\t\t\t", "sliceLen", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make(%s, 0)\n", fieldAccess, sliceType.String()) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -537,9 +553,7 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveCountedGraphMemory(sliceLen, %s) {\n", unsafeSizeExpr(elemType)) - fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") - fmt.Fprintf(buf, "\t\t\t\t}\n") + writeGraphReservation(buf, "\t\t\t\t", "sliceLen", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s, 0)\n", fieldAccess, sliceType.String()) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -566,9 +580,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) - fmt.Fprintf(buf, "%sif !ctx.ReserveCountedGraphMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) - fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) - fmt.Fprintf(buf, "%s}\n", indent) + writeGraphReservation(buf, indent, "sliceLen", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make([]any, 0)\n", indent, fieldAccess) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -592,9 +604,7 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) - fmt.Fprintf(buf, "%sif !ctx.ReserveCountedGraphMemory(sliceLen, %s) {\n", indent, unsafeSizeExpr(elemType)) - fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) - fmt.Fprintf(buf, "%s}\n", indent) + writeGraphReservation(buf, indent, "sliceLen", unsafeSizeExpr(elemType)) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make(%s, 0)\n", indent, fieldAccess, sliceType.String()) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -861,9 +871,7 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\tif !ctx.ReserveCountedGraphMemory(mapLen, %s + %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) - fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") - fmt.Fprintf(buf, "\t\t\t}\n") + writeGraphReservation(buf, "\t\t\t", "mapLen", unsafeSizeExpr(keyType)+" + "+unsafeSizeExpr(valueType)) fmt.Fprintf(buf, "\t\t\tif mapLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make(%s)\n", fieldAccess, mapType.String()) fmt.Fprintf(buf, "\t\t\t} else {\n") @@ -882,9 +890,7 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") fmt.Fprintf(buf, "\t\t\t\t}\n") - fmt.Fprintf(buf, "\t\t\t\tif !ctx.ReserveCountedGraphMemory(mapLen, %s + %s) {\n", unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) - fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") - fmt.Fprintf(buf, "\t\t\t\t}\n") + writeGraphReservation(buf, "\t\t\t\t", "mapLen", unsafeSizeExpr(keyType)+" + "+unsafeSizeExpr(valueType)) fmt.Fprintf(buf, "\t\t\t\tif mapLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s)\n", fieldAccess, mapType.String()) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -920,9 +926,7 @@ func generateMapReadInlineNoNull(buf *bytes.Buffer, mapType *types.Map, fieldAcc fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) fmt.Fprintf(buf, "%s}\n", indent) - fmt.Fprintf(buf, "%sif !ctx.ReserveCountedGraphMemory(mapLen, %s + %s) {\n", indent, unsafeSizeExpr(keyType), unsafeSizeExpr(valueType)) - fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) - fmt.Fprintf(buf, "%s}\n", indent) + writeGraphReservation(buf, indent, "mapLen", unsafeSizeExpr(keyType)+" + "+unsafeSizeExpr(valueType)) fmt.Fprintf(buf, "%sif mapLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make(%s)\n", indent, fieldAccess, mapType.String()) fmt.Fprintf(buf, "%s} else {\n", indent) diff --git a/go/fory/fory.go b/go/fory/fory.go index 5b4d168878..b08d9ecee3 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -83,7 +83,7 @@ func defaultConfig() Config { MaxDepth: 20, IsXlang: true, MaxTypeFields: 512, - MaxGraphMemoryBytes: -1, + MaxGraphMemoryBytes: 128 * 1024 * 1024, MaxTypeMetaBytes: 4096, MaxSchemaVersionsPerType: 10, MaxAverageSchemaVersionsPerType: 3, @@ -113,11 +113,8 @@ func WithMaxDepth(depth int) Option { } // WithMaxGraphMemoryBytes sets the maximum estimated graph memory accepted during one root deserialization. -// Use -1 for the automatic input-shaped limit. +// Non-positive values disable graph-memory enforcement. func WithMaxGraphMemoryBytes(size int64) Option { - if size != -1 && size <= 0 { - panic("MaxGraphMemoryBytes must be positive or -1 for auto") - } return func(f *Fory) { f.config.MaxGraphMemoryBytes = size } @@ -576,7 +573,7 @@ func (f *Fory) Deserialize(data []byte, v any) error { defer f.resetReadState() f.readCtx.SetData(data) target := reflect.ValueOf(v).Elem() - if err := f.initRootGraphBudget(target, len(data), false); err != nil { + if err := f.initRootGraphBudget(target); err != nil { return err } @@ -586,7 +583,7 @@ func (f *Fory) Deserialize(data []byte, v any) error { } // Deserialize the value - TypeMeta is read inline using streaming protocol - f.readRootValue(target) + f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -671,7 +668,7 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = buf target := reflect.ValueOf(v).Elem() - if err := f.initRootGraphBudget(target, buf.readableBytes(), false); err != nil { + if err := f.initRootGraphBudget(target); err != nil { f.readCtx.buffer = origBuffer return err } @@ -683,7 +680,7 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { } // Deserialize the value - TypeMeta is read inline using streaming protocol - f.readRootValue(target) + f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { f.readCtx.buffer = origBuffer return f.readCtx.TakeError() @@ -787,7 +784,7 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers } target := rv.Elem() - if err := f.initRootGraphBudget(target, buffer.readableBytes(), false); err != nil { + if err := f.initRootGraphBudget(target); err != nil { return err } @@ -798,7 +795,7 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers } // Deserialize the value - TypeMeta is read inline using streaming protocol - f.readRootValue(target) + f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -1054,7 +1051,7 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { *[]byte, *[]int8, *[]int16, *[]int32, *[]int64, *[]int, *[]float32, *[]float64, *[]bool: default: targetVal = reflect.ValueOf(target).Elem() - if err := f.initRootGraphBudget(targetVal, len(data), false); err != nil { + if err := f.initRootGraphBudget(targetVal); err != nil { return err } } @@ -1213,36 +1210,33 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { } } -func (f *Fory) initRootGraphBudget(target reflect.Value, rootInputBytes int, unknownLengthInput bool) error { - if target.IsValid() && target.Type() == f.rootGraphSkipType { +func (f *Fory) initRootGraphBudget(target reflect.Value) error { + if !target.IsValid() { + f.readCtx.initGraphMemoryBudget() + if f.readCtx.HasError() { + return f.readCtx.TakeError() + } return nil } - return f.initRootGraphBudgetSlow(target, rootInputBytes, unknownLengthInput) -} - -func (f *Fory) readRootValue(target reflect.Value) { - if target.IsValid() { - targetType := target.Type() - if targetType.Kind() == reflect.Struct && !f.typeResolver.IsUnionType(targetType) { - f.readCtx.ReadStruct(target) - return - } + targetType := target.Type() + if targetType == f.rootGraphSkipType { + return nil } - f.readCtx.ReadValue(target, RefModeTracking, true) + return f.initRootGraphBudgetSlow(targetType) } //go:noinline -func (f *Fory) initRootGraphBudgetSlow(target reflect.Value, rootInputBytes int, unknownLengthInput bool) error { - bytes, hasChildren, isStruct := f.rootGraphInfo(target) +func (f *Fory) initRootGraphBudgetSlow(targetType reflect.Type) error { + bytes, hasChildren, isStruct := f.rootGraphInfo(targetType) if !isStruct { - f.readCtx.initGraphMemoryBudget(rootInputBytes, unknownLengthInput) + f.readCtx.initGraphMemoryBudget() if f.readCtx.HasError() { return f.readCtx.TakeError() } return nil } if hasChildren { - f.readCtx.initGraphMemoryBudget(rootInputBytes, unknownLengthInput) + f.readCtx.initGraphMemoryBudget() if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -1251,18 +1245,17 @@ func (f *Fory) initRootGraphBudgetSlow(target reflect.Value, rootInputBytes int, } return nil } - if f.config.MaxGraphMemoryBytes <= 0 && bytes <= knownRootBudgetSlackBytes { - f.rootGraphSkipType = target.Type() + if f.config.MaxGraphMemoryBytes <= 0 || bytes <= f.config.MaxGraphMemoryBytes { + f.rootGraphSkipType = targetType return nil } - return f.checkRootGraphSelf(bytes, rootInputBytes, unknownLengthInput) + return f.checkRootGraphSelf(bytes) } -func (f *Fory) rootGraphInfo(target reflect.Value) (int64, bool, bool) { - if !target.IsValid() || target.Kind() != reflect.Struct { +func (f *Fory) rootGraphInfo(targetType reflect.Type) (int64, bool, bool) { + if targetType == nil || targetType.Kind() != reflect.Struct { return 0, false, false } - targetType := target.Type() if targetType == dateReflectType || targetType == timeReflectType { return 0, false, true } @@ -1277,26 +1270,13 @@ func (f *Fory) rootGraphInfo(target reflect.Value) (int64, bool, bool) { return bytes, hasChildren, true } -func (f *Fory) checkRootGraphSelf(bytes int64, rootInputBytes int, unknownLengthInput bool) error { +func (f *Fory) checkRootGraphSelf(bytes int64) error { if bytes <= 0 { return nil } limit := f.config.MaxGraphMemoryBytes if limit <= 0 { - if unknownLengthInput { - limit = streamRootBudgetBytes - } else { - if rootInputBytes < 0 { - return DeserializationErrorf("root input size must be non-negative: %d", rootInputBytes) - } - if int64(rootInputBytes) > (MaxInt64-knownRootBudgetSlackBytes)/knownRootBudgetMultiplier { - return DeserializationErrorf("root input size %d overflows automatic graph memory budget", rootInputBytes) - } - if bytes <= knownRootBudgetSlackBytes { - return nil - } - limit = int64(rootInputBytes)*knownRootBudgetMultiplier + knownRootBudgetSlackBytes - } + return nil } if bytes > limit { return DeserializationErrorf( diff --git a/go/fory/graph_memory_budget_test.go b/go/fory/graph_memory_budget_test.go index 7c266d72e2..b3362aa4f7 100644 --- a/go/fory/graph_memory_budget_test.go +++ b/go/fory/graph_memory_budget_test.go @@ -44,37 +44,37 @@ func graphOwnerSizeOf[T any]() int64 { } func TestGraphMemoryBudgetConfig(t *testing.T) { - require.Equal(t, int64(-1), New().config.MaxGraphMemoryBytes) + require.Equal(t, int64(128*1024*1024), New().config.MaxGraphMemoryBytes) require.Equal(t, int64(123), New(WithMaxGraphMemoryBytes(123)).config.MaxGraphMemoryBytes) - require.Panics(t, func() { New(WithMaxGraphMemoryBytes(0)) }) - require.Panics(t, func() { New(WithMaxGraphMemoryBytes(-2)) }) + require.Equal(t, int64(0), New(WithMaxGraphMemoryBytes(0)).config.MaxGraphMemoryBytes) + require.Equal(t, int64(-2), New(WithMaxGraphMemoryBytes(-2)).config.MaxGraphMemoryBytes) } -func TestGraphMemoryBudgetAutoLimits(t *testing.T) { +func TestGraphMemoryBudgetFixedDefaultAndDisable(t *testing.T) { ctx := NewReadContext(false) - ctx.initGraphMemoryBudget(10, false) + ctx.initGraphMemoryBudget() require.False(t, ctx.HasError()) - require.Equal(t, int64(10)*knownRootBudgetMultiplier+knownRootBudgetSlackBytes, ctx.graphMemoryLimitBytes) + require.Equal(t, int64(128*1024*1024), ctx.graphMemoryLimitBytes) require.True(t, ctx.ReserveGraphMemory(ctx.graphMemoryLimitBytes)) require.False(t, ctx.ReserveGraphMemory(1)) require.Contains(t, ctx.CheckError().Error(), "maxGraphMemoryBytes") ctx = NewReadContext(false) - ctx.initGraphMemoryBudget(10, true) + ctx.maxGraphMemoryBytes = 0 + ctx.initGraphMemoryBudget() + require.False(t, ctx.HasError()) + require.Equal(t, int64(0), ctx.graphMemoryLimitBytes) + require.True(t, ctx.ReserveGraphMemory(MaxInt64)) require.False(t, ctx.HasError()) - require.Equal(t, streamRootBudgetBytes, ctx.graphMemoryLimitBytes) - require.True(t, ctx.ReserveGraphMemory(streamRootBudgetBytes)) - require.False(t, ctx.ReserveGraphMemory(1)) - require.Contains(t, ctx.CheckError().Error(), "maxGraphMemoryBytes") ctx = NewReadContext(false) ctx.maxGraphMemoryBytes = 77 - ctx.initGraphMemoryBudget(10, true) + ctx.initGraphMemoryBudget() require.False(t, ctx.HasError()) require.Equal(t, int64(77), ctx.graphMemoryLimitBytes) } -func TestGraphMemoryBudgetKnownVsStreamRoot(t *testing.T) { +func TestGraphMemoryBudgetRootKindsShareDefault(t *testing.T) { writer := New(WithCompatible(false)) values := make([]any, 12000) for i := range values { @@ -85,8 +85,8 @@ func TestGraphMemoryBudgetKnownVsStreamRoot(t *testing.T) { var fromBytes []any err = New(WithCompatible(false)).Deserialize(data, &fromBytes) - require.Error(t, err) - require.Contains(t, err.Error(), "maxGraphMemoryBytes") + require.NoError(t, err) + require.Len(t, fromBytes, len(values)) var fromStream []any err = New(WithCompatible(false)).DeserializeFromReader(bytes.NewReader(data), &fromStream) @@ -163,9 +163,9 @@ func TestGraphMemoryBudgetMapAndOverflow(t *testing.T) { require.Contains(t, err.Error(), "maxGraphMemoryBytes") ctx := NewReadContext(false) - ctx.initGraphMemoryBudget(0, true) - require.False(t, ctx.ReserveCountedGraphMemory(MaxInt, MaxInt64)) - require.Contains(t, ctx.CheckError().Error(), "overflows") + ctx.initGraphMemoryBudget() + require.False(t, ctx.ReserveGraphMemory(-1)) + require.Contains(t, ctx.CheckError().Error(), "non-negative") } func TestGraphMemoryBudgetSlicesAndInlineValues(t *testing.T) { diff --git a/go/fory/map.go b/go/fory/map.go index 9d68706838..a5f2acaccc 100644 --- a/go/fory/map.go +++ b/go/fory/map.go @@ -310,7 +310,15 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { ctx.setGraphMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) return } - if !ctx.ReserveCountedGraphMemory(size, elemBytes) { + if size < 0 { + ctx.setGraphMemoryError("negative graph element count: %d", size) + return + } + if int64(size) > maxGraphCount(elemBytes) { + ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", size, elemBytes) + return + } + if !ctx.ReserveGraphMemory(int64(size) * elemBytes) { return } if size == 0 { diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go index d0ed63a632..0a658b2755 100644 --- a/go/fory/map_primitive.go +++ b/go/fory/map_primitive.go @@ -94,7 +94,15 @@ func readTypedMapSize(ctx *ReadContext, elemBytes int64, maxLength int64) (int, if ctx.HasError() { return 0, false } - if !ctx.reserveCountedGraphMemory(size, elemBytes, maxLength) { + if size < 0 { + ctx.setGraphMemoryError("negative graph element count: %d", size) + return 0, false + } + if int64(size) > maxLength { + ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", size, elemBytes) + return 0, false + } + if !ctx.ReserveGraphMemory(int64(size) * elemBytes) { return 0, false } if size == 0 { diff --git a/go/fory/reader.go b/go/fory/reader.go index 5579659b96..0fe3c3481e 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -49,12 +49,6 @@ type ReadContext struct { remainingGraphMemoryBytes int64 } -const ( - knownRootBudgetMultiplier = int64(8) - knownRootBudgetSlackBytes = int64(64 * 1024) - streamRootBudgetBytes = int64(128 * 1024 * 1024) -) - var referenceSlotBytes = int64(unsafe.Sizeof(uintptr(0))) var stringElementBytes = graphSizeOf[string]() var stringMaxLength = maxGraphCount(stringElementBytes) @@ -140,9 +134,9 @@ func NewReadContext(trackRef bool) *ReadContext { refReader: NewRefReader(trackRef), trackRef: trackRef, maxDepth: 128, // Default maximum nesting depth - maxGraphMemoryBytes: -1, - graphMemoryLimitBytes: MaxInt64, - remainingGraphMemoryBytes: MaxInt64, + maxGraphMemoryBytes: 128 * 1024 * 1024, + graphMemoryLimitBytes: 128 * 1024 * 1024, + remainingGraphMemoryBytes: 128 * 1024 * 1024, } } @@ -162,67 +156,26 @@ func (c *ReadContext) Reset() { } } -func (c *ReadContext) initGraphMemoryBudget(rootInputBytes int, unknownLengthInput bool) { +func (c *ReadContext) initGraphMemoryBudget() { limit := c.maxGraphMemoryBytes if limit <= 0 { - if unknownLengthInput { - limit = streamRootBudgetBytes - } else { - if rootInputBytes < 0 { - c.setGraphMemoryError("root input size must be non-negative: %d", rootInputBytes) - return - } - if int64(rootInputBytes) > (MaxInt64-knownRootBudgetSlackBytes)/knownRootBudgetMultiplier { - c.setGraphMemoryError("root input size %d overflows automatic graph memory budget", rootInputBytes) - return - } - limit = int64(rootInputBytes)*knownRootBudgetMultiplier + knownRootBudgetSlackBytes - } + c.graphMemoryLimitBytes = 0 + c.remainingGraphMemoryBytes = MaxInt64 + return } c.graphMemoryLimitBytes = limit c.remainingGraphMemoryBytes = limit } -// ReserveCountedGraphMemory reserves length * elementBytes estimated graph bytes. -func (c *ReadContext) ReserveCountedGraphMemory(length int, elemBytes int64) bool { - if length < 0 { - c.setGraphMemoryError("negative graph element count: %d", length) - return false - } - if elemBytes < 0 { - c.setGraphMemoryError("negative graph element size: %d", elemBytes) - return false - } - if length == 0 { - return true - } - return c.reserveCountedGraphMemory(length, elemBytes, maxGraphCount(elemBytes)) -} - -func (c *ReadContext) reserveCountedGraphMemory(length int, elemBytes int64, maxLength int64) bool { - if length == 0 { - return true - } - if int64(length) > maxLength { - c.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) - return false - } - bytes := int64(length) * elemBytes - remaining := c.remainingGraphMemoryBytes - if bytes > remaining { - c.setGraphMemoryExceeded(bytes, remaining) - return false - } - c.remainingGraphMemoryBytes = remaining - bytes - return true -} - // ReserveGraphMemory reserves raw estimated graph-owner bytes. func (c *ReadContext) ReserveGraphMemory(bytes int64) bool { if bytes < 0 { c.setGraphMemoryError("estimated graph memory must be non-negative, got %d bytes", bytes) return false } + if c.graphMemoryLimitBytes <= 0 { + return true + } remaining := c.remainingGraphMemoryBytes if bytes > remaining { c.setGraphMemoryExceeded(bytes, remaining) @@ -715,7 +668,15 @@ func (c *ReadContext) readStringSliceData() []string { if c.HasError() { return nil } - if !c.reserveCountedGraphMemory(length, stringElementBytes, stringMaxLength) { + if length < 0 { + c.setGraphMemoryError("negative graph element count: %d", length) + return nil + } + if int64(length) > stringMaxLength { + c.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, stringElementBytes) + return nil + } + if !c.ReserveGraphMemory(int64(length) * stringElementBytes) { return nil } if length == 0 { @@ -924,15 +885,16 @@ func (c *ReadContext) ReadValue(value reflect.Value, refMode RefMode, readType b c.SetError(DeserializationError("invalid reflect.Value")) return } + valueType := value.Type() // Handle array targets (arrays are serialized as slices) - if value.Type().Kind() == reflect.Array { + if valueType.Kind() == reflect.Array { c.ReadArrayValue(value, refMode, readType) return } // For any types, we need to read the actual type from the buffer first - if value.Type().Kind() == reflect.Interface { + if valueType.Kind() == reflect.Interface { // Handle ref tracking based on refMode var refID int32 = int32(NotNullValueFlag) if refMode == RefModeTracking { @@ -1041,14 +1003,13 @@ func (c *ReadContext) ReadValue(value reflect.Value, refMode RefMode, readType b return } - if typeInfo := c.getTypeInfoByType(value.Type()); typeInfo != nil && typeInfo.Serializer != nil { + if typeInfo := c.getTypeInfoByType(valueType); typeInfo != nil && typeInfo.Serializer != nil { typeInfo.Serializer.Read(c, refMode, readType, false, value) return } // For struct types, use optimized ReadStruct path when using full ref tracking and type info. // Unions use a custom serializer and must bypass ReadStruct. - valueType := value.Type() if refMode == RefModeTracking && readType && !c.typeResolver.IsUnionType(valueType) { if valueType.Kind() == reflect.Struct { c.ReadStruct(value) diff --git a/go/fory/set.go b/go/fory/set.go index 437e679938..a014d91f5d 100644 --- a/go/fory/set.go +++ b/go/fory/set.go @@ -325,7 +325,7 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { ctx.setGraphMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) return } - if !ctx.ReserveCountedGraphMemory(length, elemBytes) { + if !ctx.ReserveGraphMemory(0) { return } // Initialize empty set if length is 0 @@ -373,7 +373,15 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { ctx.setGraphMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) return } - if !ctx.ReserveCountedGraphMemory(length, elemBytes) { + if length < 0 { + ctx.setGraphMemoryError("negative graph element count: %d", length) + return + } + if int64(length) > maxGraphCount(elemBytes) { + ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) + return + } + if !ctx.ReserveGraphMemory(int64(length) * elemBytes) { return } diff --git a/go/fory/slice.go b/go/fory/slice.go index 7975feb432..67dbc708e2 100644 --- a/go/fory/slice.go +++ b/go/fory/slice.go @@ -319,8 +319,18 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } return } - if !isArrayType && !ctx.reserveCountedGraphMemory(length, s.elemBytes, s.maxLength) { - return + if !isArrayType { + if length < 0 { + ctx.setGraphMemoryError("negative graph element count: %d", length) + return + } + if int64(length) > s.maxLength { + ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes) + return + } + if !ctx.ReserveGraphMemory(int64(length) * s.elemBytes) { + return + } } // ReadData collection flags diff --git a/go/fory/slice_dyn.go b/go/fory/slice_dyn.go index b65721f9f8..35594be458 100644 --- a/go/fory/slice_dyn.go +++ b/go/fory/slice_dyn.go @@ -287,8 +287,18 @@ func (s *sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, exp value.Set(reflect.MakeSlice(sliceType, 0, 0)) return } - if !allocatedByCaller && !ctx.reserveCountedGraphMemory(length, s.elemBytes, s.maxLength) { - return + if !allocatedByCaller { + if length < 0 { + ctx.setGraphMemoryError("negative graph element count: %d", length) + return + } + if int64(length) > s.maxLength { + ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes) + return + } + if !ctx.ReserveGraphMemory(int64(length) * s.elemBytes) { + return + } } collectFlag := buf.ReadInt8(ctxErr) diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index 3e2fc5e2d8..ec5c99d723 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -652,7 +652,15 @@ func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { return } ptr := (*[]string)(value.Addr().UnsafePointer()) - if !ctx.reserveCountedGraphMemory(length, stringElementBytes, stringMaxLength) { + if length < 0 { + ctx.setGraphMemoryError("negative graph element count: %d", length) + return + } + if int64(length) > stringMaxLength { + ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, stringElementBytes) + return + } + if !ctx.ReserveGraphMemory(int64(length) * stringElementBytes) { return } if length == 0 { diff --git a/go/fory/slice_primitive_list.go b/go/fory/slice_primitive_list.go index efda751dde..3f920ad17e 100644 --- a/go/fory/slice_primitive_list.go +++ b/go/fory/slice_primitive_list.go @@ -179,7 +179,15 @@ func (s primitiveListSerializer) ReadData(ctx *ReadContext, value reflect.Value) if ctx.HasError() { return } - if !ctx.reserveCountedGraphMemory(length, s.elemBytes, s.maxLength) { + if length < 0 { + ctx.setGraphMemoryError("negative graph element count: %d", length) + return + } + if int64(length) > s.maxLength { + ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes) + return + } + if !ctx.ReserveGraphMemory(int64(length) * s.elemBytes) { return } if length == 0 { @@ -243,7 +251,7 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val } if length == 0 { if value.Kind() == reflect.Slice { - if !ctx.reserveCountedGraphMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { + if !ctx.ReserveGraphMemory(0) { return } value.Set(reflect.MakeSlice(value.Type(), 0, 0)) @@ -284,7 +292,15 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val return } if value.Kind() == reflect.Slice { - if !ctx.reserveCountedGraphMemory(length, s.listReader.elemBytes, s.listReader.maxLength) { + if length < 0 { + ctx.setGraphMemoryError("negative graph element count: %d", length) + return + } + if int64(length) > s.listReader.maxLength { + ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.listReader.elemBytes) + return + } + if !ctx.ReserveGraphMemory(int64(length) * s.listReader.elemBytes) { return } temp := reflect.New(value.Type()).Elem() diff --git a/go/fory/stream.go b/go/fory/stream.go index a032c7a64f..f8f1e63430 100644 --- a/go/fory/stream.go +++ b/go/fory/stream.go @@ -97,7 +97,7 @@ func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = is.buffer target := reflect.ValueOf(v).Elem() - if err := f.initRootGraphBudget(target, 0, true); err != nil { + if err := f.initRootGraphBudget(target); err != nil { f.readCtx.buffer = origBuffer f.resetReadState() return err @@ -112,7 +112,7 @@ func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { return f.readCtx.TakeError() } - f.readRootValue(target) + f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -129,7 +129,7 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { // Always reset to enforce stateless semantics. f.readCtx.buffer.ResetWithReader(r, 0) target := reflect.ValueOf(v).Elem() - if err := f.initRootGraphBudget(target, 0, true); err != nil { + if err := f.initRootGraphBudget(target); err != nil { return err } @@ -138,7 +138,7 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { return f.readCtx.TakeError() } - f.readRootValue(target) + f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { return f.readCtx.TakeError() } diff --git a/go/fory/tests/structs_fory_gen.go b/go/fory/tests/structs_fory_gen.go index 945f3f5f58..2fb5ad9623 100644 --- a/go/fory/tests/structs_fory_gen.go +++ b/go/fory/tests/structs_fory_gen.go @@ -1,6 +1,6 @@ // Code generated by forygen. DO NOT EDIT. // source: structs.go -// generated at: 2026-06-30T22:09:02+08:00 +// generated at: 2026-07-01T02:06:03+08:00 package fory @@ -190,8 +190,24 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { - return ctx.TakeError() + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(any)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if sliceLen == 0 { v.DynamicSlice = make([]any, 0) @@ -221,8 +237,24 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(any)))) { - return ctx.TakeError() + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(any)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if sliceLen == 0 { v.DynamicSlice = make([]any, 0) @@ -678,8 +710,24 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(mapLen, int64(unsafe.Sizeof(*new(int)))+int64(unsafe.Sizeof(*new(int)))) { - return ctx.TakeError() + { + graphCount := mapLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(int))) + int64(unsafe.Sizeof(*new(int)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if mapLen == 0 { v.IntMap = make(map[int]int) @@ -728,8 +776,24 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(mapLen, int64(unsafe.Sizeof(*new(int)))+int64(unsafe.Sizeof(*new(int)))) { - return ctx.TakeError() + { + graphCount := mapLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(int))) + int64(unsafe.Sizeof(*new(int)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if mapLen == 0 { v.IntMap = make(map[int]int) @@ -780,8 +844,24 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(int)))) { - return ctx.TakeError() + { + graphCount := mapLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(string))) + int64(unsafe.Sizeof(*new(int)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if mapLen == 0 { v.MixedMap = make(map[string]int) @@ -830,8 +910,24 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(int)))) { - return ctx.TakeError() + { + graphCount := mapLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(string))) + int64(unsafe.Sizeof(*new(int)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if mapLen == 0 { v.MixedMap = make(map[string]int) @@ -882,8 +978,24 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(string)))) { - return ctx.TakeError() + { + graphCount := mapLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(string))) + int64(unsafe.Sizeof(*new(string)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if mapLen == 0 { v.StringMap = make(map[string]string) @@ -932,8 +1044,24 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(mapLen, int64(unsafe.Sizeof(*new(string)))+int64(unsafe.Sizeof(*new(string)))) { - return ctx.TakeError() + { + graphCount := mapLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(string))) + int64(unsafe.Sizeof(*new(string)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if mapLen == 0 { v.StringMap = make(map[string]string) @@ -1299,8 +1427,24 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { - return ctx.TakeError() + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(bool)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if sliceLen == 0 { v.BoolSlice = make([]bool, 0) @@ -1341,8 +1485,24 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(bool)))) { - return ctx.TakeError() + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(bool)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if sliceLen == 0 { v.BoolSlice = make([]bool, 0) @@ -1385,8 +1545,24 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { - return ctx.TakeError() + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(float64)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if sliceLen == 0 { v.FloatSlice = make([]float64, 0) @@ -1427,8 +1603,24 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(float64)))) { - return ctx.TakeError() + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(float64)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if sliceLen == 0 { v.FloatSlice = make([]float64, 0) @@ -1471,8 +1663,24 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { - return ctx.TakeError() + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(int32)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if sliceLen == 0 { v.IntSlice = make([]int32, 0) @@ -1513,8 +1721,24 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(int32)))) { - return ctx.TakeError() + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(int32)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if sliceLen == 0 { v.IntSlice = make([]int32, 0) @@ -1557,8 +1781,24 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { - return ctx.TakeError() + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(string)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if sliceLen == 0 { v.StringSlice = make([]string, 0) @@ -1607,8 +1847,24 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if ctx.HasError() { return ctx.TakeError() } - if !ctx.ReserveCountedGraphMemory(sliceLen, int64(unsafe.Sizeof(*new(string)))) { - return ctx.TakeError() + { + graphCount := sliceLen + graphElemBytes := int64(int64(unsafe.Sizeof(*new(string)))) + if graphCount < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element count: %d", graphCount)) + return ctx.TakeError() + } + if graphElemBytes < 0 { + ctx.SetError(fory.DeserializationErrorf("negative graph element size: %d", graphElemBytes)) + return ctx.TakeError() + } + if graphCount != 0 && graphElemBytes != 0 && int64(graphCount) > fory.MaxInt64/graphElemBytes { + ctx.SetError(fory.DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", graphCount, graphElemBytes)) + return ctx.TakeError() + } + if !ctx.ReserveGraphMemory(int64(graphCount) * graphElemBytes) { + return ctx.TakeError() + } } if sliceLen == 0 { v.StringSlice = make([]string, 0) diff --git a/java/fory-core/src/main/java/org/apache/fory/Fory.java b/java/fory-core/src/main/java/org/apache/fory/Fory.java index b961e9c582..de76df2bd9 100644 --- a/java/fory-core/src/main/java/org/apache/fory/Fory.java +++ b/java/fory-core/src/main/java/org/apache/fory/Fory.java @@ -410,7 +410,7 @@ private ForyException processCopyError(Throwable e) { @Override public Object deserialize(byte[] bytes) { - return deserialize(MemoryUtils.wrap(bytes), (Iterable) null, false, bytes.length); + return deserialize(MemoryUtils.wrap(bytes), (Iterable) null); } @Override @@ -420,22 +420,21 @@ public Object deserialize(ByteBuffer byteBuffer) { @Override public T deserialize(byte[] bytes, Class type) { - return deserialize(MemoryUtils.wrap(bytes), type, false, bytes.length); + return deserializeRoot(MemoryUtils.wrap(bytes), type); } @Override public T deserialize(MemoryBuffer buffer, Class type) { - return deserialize(buffer, type, false, buffer.remaining()); + return deserializeRoot(buffer, type); } - private T deserialize( - MemoryBuffer buffer, Class type, boolean unknownLengthInput, int rootInputBytes) { + private T deserializeRoot(MemoryBuffer buffer, Class type) { ensureRegistrationFinished(); byte bitmap = buffer.readByte(); if (bitmap != headerBitmap) { checkHeaderBitmapWithoutOutOfBand(bitmap); } - readContext.prepare(buffer, null, false, rootInputBytes, unknownLengthInput); + readContext.prepare(buffer, null, false); try { try { jitContext.lock(); @@ -456,7 +455,7 @@ private T deserialize( @Override public T deserialize(ForyInputStream inputStream, Class type) { try { - return deserialize(inputStream.getBuffer(), type, true, 0); + return deserializeRoot(inputStream.getBuffer(), type); } finally { inputStream.shrinkBuffer(); } @@ -464,7 +463,7 @@ public T deserialize(ForyInputStream inputStream, Class type) { @Override public T deserialize(ForyReadableChannel channel, Class type) { - return deserialize(channel.getBuffer(), type, true, 0); + return deserializeRoot(channel.getBuffer(), type); } @Override @@ -492,14 +491,10 @@ public Object deserialize(MemoryBuffer buffer) { */ @Override public Object deserialize(MemoryBuffer buffer, Iterable outOfBandBuffers) { - return deserialize(buffer, outOfBandBuffers, false, buffer.remaining()); + return deserializeRoot(buffer, outOfBandBuffers); } - private Object deserialize( - MemoryBuffer buffer, - Iterable outOfBandBuffers, - boolean unknownLengthInput, - int rootInputBytes) { + private Object deserializeRoot(MemoryBuffer buffer, Iterable outOfBandBuffers) { ensureRegistrationFinished(); byte bitmap = buffer.readByte(); boolean peerOutOfBandEnabled = false; @@ -518,11 +513,7 @@ private Object deserialize( + "produced with bufferCallback null."); } readContext.prepare( - buffer, - peerOutOfBandEnabled ? outOfBandBuffers : null, - peerOutOfBandEnabled, - rootInputBytes, - unknownLengthInput); + buffer, peerOutOfBandEnabled ? outOfBandBuffers : null, peerOutOfBandEnabled); try { try { jitContext.lock(); @@ -549,7 +540,7 @@ public Object deserialize(ForyInputStream inputStream) { public Object deserialize(ForyInputStream inputStream, Iterable outOfBandBuffers) { try { MemoryBuffer buf = inputStream.getBuffer(); - return deserialize(buf, outOfBandBuffers, true, 0); + return deserializeRoot(buf, outOfBandBuffers); } finally { inputStream.shrinkBuffer(); } @@ -563,7 +554,7 @@ public Object deserialize(ForyReadableChannel channel) { @Override public Object deserialize(ForyReadableChannel channel, Iterable outOfBandBuffers) { MemoryBuffer buf = channel.getBuffer(); - return deserialize(buf, outOfBandBuffers, true, 0); + return deserializeRoot(buf, outOfBandBuffers); } @SuppressWarnings("unchecked") diff --git a/java/fory-core/src/main/java/org/apache/fory/config/Config.java b/java/fory-core/src/main/java/org/apache/fory/config/Config.java index 323f4f7bd4..869bce3ef1 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/Config.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/Config.java @@ -322,7 +322,7 @@ public int maxAverageSchemaVersionsPerType() { return maxAverageSchemaVersionsPerType; } - /** Returns the root-operation estimated graph memory limit in bytes, or -1 for auto. */ + /** Returns the root-operation estimated graph memory limit in bytes. Non-positive disables it. */ public long maxGraphMemoryBytes() { return maxGraphMemoryBytes; } diff --git a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java index b5975502ee..facd26013c 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java @@ -103,7 +103,7 @@ public final class ForyBuilder { int maxTypeMetaBytes = 4096; int maxSchemaVersionsPerType = 10; int maxAverageSchemaVersionsPerType = 3; - long maxGraphMemoryBytes = -1; + long maxGraphMemoryBytes = 128L * 1024 * 1024; float mapRefLoadFactor = 0.51f; boolean forVirtualThread = false; TypeChecker typeChecker; @@ -575,14 +575,10 @@ public ForyBuilder withMaxAverageSchemaVersionsPerType(int maxAverageSchemaVersi /** * Sets the maximum estimated graph memory accepted during one root deserialization. * - *

The default is {@code -1}, which derives an automatic per-root budget from the input shape. - * Positive values are explicit byte limits. Other values are invalid. + *

The default is a fixed 128 MiB. Positive values are explicit byte limits. Explicit + * non-positive values disable this budget. */ public ForyBuilder withMaxGraphMemoryBytes(long maxGraphMemoryBytes) { - Preconditions.checkArgument( - maxGraphMemoryBytes == -1 || maxGraphMemoryBytes > 0, - "maxGraphMemoryBytes must be positive or -1 for auto but got %s", - maxGraphMemoryBytes); this.maxGraphMemoryBytes = maxGraphMemoryBytes; recordAction(b -> b.withMaxGraphMemoryBytes(maxGraphMemoryBytes)); return this; diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index dcd39f1347..e28b31fb25 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -51,9 +51,6 @@ */ @SuppressWarnings({"rawtypes", "unchecked"}) public final class ReadContext { - private static final long KNOWN_ROOT_BUDGET_MULTIPLIER = 8L; - private static final long KNOWN_ROOT_BUDGET_SLACK_BYTES = 64L * 1024; - private static final long STREAM_ROOT_BUDGET_BYTES = 128L * 1024 * 1024; private final Config config; private final Generics generics; private final TypeResolver typeResolver; @@ -117,27 +114,19 @@ public ReadContext( public void prepare( MemoryBuffer buffer, Iterable outOfBandBuffers, - boolean peerOutOfBandEnabled, - int rootInputBytes, - boolean unknownLengthInput) { + boolean peerOutOfBandEnabled) { this.buffer = buffer; this.peerOutOfBandEnabled = peerOutOfBandEnabled; this.outOfBandBuffers = outOfBandBuffers == null ? null : outOfBandBuffers.iterator(); - initGraphMemoryBudget(rootInputBytes, unknownLengthInput); + initGraphMemoryBudget(); } - private void initGraphMemoryBudget(int rootInputBytes, boolean unknownLengthInput) { + private void initGraphMemoryBudget() { long limit = maxGraphMemoryBytes; if (limit <= 0) { - if (unknownLengthInput) { - limit = STREAM_ROOT_BUDGET_BYTES; - } else { - if (rootInputBytes < 0) { - throw new IllegalArgumentException( - "Root input size must be non-negative: " + rootInputBytes); - } - limit = rootInputBytes * KNOWN_ROOT_BUDGET_MULTIPLIER + KNOWN_ROOT_BUDGET_SLACK_BYTES; - } + graphMemoryLimitBytes = 0; + remainingGraphMemoryBytes = Long.MAX_VALUE; + return; } graphMemoryLimitBytes = limit; remainingGraphMemoryBytes = limit; @@ -349,6 +338,9 @@ public void reserveGraphMemory(long bytes) { if (bytes < 0) { throwNegativeGraphMemory(bytes); } + if (graphMemoryLimitBytes <= 0) { + return; + } long remaining = remainingGraphMemoryBytes; if (bytes > remaining) { throwGraphMemoryExceeded(bytes, remaining); diff --git a/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java b/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java index 94c0e893b8..63e3ffcdc1 100644 --- a/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java +++ b/java/fory-core/src/test/java/org/apache/fory/ForyTestBase.java @@ -344,7 +344,7 @@ public static void withWriteContext( public static T withReadContext( Fory fory, MemoryBuffer buffer, Function action) { ReadContext context = (ReadContext) ReflectionUtils.getObjectFieldValue(fory, "readContext"); - context.prepare(buffer, null, false, buffer.remaining(), false); + context.prepare(buffer, null, false); try { return action.apply(context); } finally { diff --git a/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectInputTest.java b/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectInputTest.java index 3b379c7b3a..71f73582c2 100644 --- a/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectInputTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectInputTest.java @@ -53,7 +53,7 @@ public void testForyStructInput(boolean compressNumber) throws IOException { buffer.writeFloat32(4.1f); buffer.writeFloat64(4.2); new StringSerializer(fory.getConfig()).writeString(buffer, "abc"); - fory.getReadContext().prepare(buffer, null, false, buffer.remaining(), false); + fory.getReadContext().prepare(buffer, null, false); try (MemoryBufferObjectInput input = new MemoryBufferObjectInput(fory.getConfig(), fory.getReadContext())) { assertEquals(input.readByte(), 1); diff --git a/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectOutputTest.java b/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectOutputTest.java index f5b69b8b13..a04a0e9f3c 100644 --- a/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectOutputTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/io/MemoryBufferObjectOutputTest.java @@ -46,7 +46,7 @@ public void testForyStructOutput() throws IOException { output.writeChars("abc"); output.writeUTF("abc"); } - fory.getReadContext().prepare(buffer, null, false, buffer.remaining(), false); + fory.getReadContext().prepare(buffer, null, false); try (MemoryBufferObjectInput input = new MemoryBufferObjectInput(fory.getConfig(), fory.getReadContext())) { assertEquals(input.readByte(), 1); diff --git a/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java b/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java index 157b3950f5..0b94da34cb 100644 --- a/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java @@ -360,7 +360,7 @@ public void testRemoteTypeDefChecksTypeChecker() { ReadContext readContext = reader.getReadContext(); readContext.setMetaReadContext(new MetaReadContext()); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(256); - readContext.prepare(buffer, null, false, buffer.remaining(), false); + readContext.prepare(buffer, null, false); buffer.writeVarUInt32(0); typeDef.writeTypeDef(buffer); buffer.readerIndex(0); @@ -473,7 +473,7 @@ public void testExactLocalEnumTypeDefBypassesLimit() { ReadContext readContext = fory.getReadContext(); readContext.setMetaReadContext(new MetaReadContext()); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(256); - readContext.prepare(buffer, null, false, buffer.remaining(), false); + readContext.prepare(buffer, null, false); buffer.writeVarUInt32(0); exact.writeTypeDef(buffer); buffer.readerIndex(0); @@ -792,7 +792,7 @@ public void testWriteClassName() { } finally { fory.getWriteContext().reset(); } - fory.getReadContext().prepare(buffer, null, false, buffer.remaining(), false); + fory.getReadContext().prepare(buffer, null, false); try { Assert.assertSame(classResolver.readTypeInfo(fory.getReadContext()).getType(), getClass()); Assert.assertSame(classResolver.readTypeInfo(fory.getReadContext()).getType(), getClass()); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java index 752eb77686..df69696d64 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java @@ -373,15 +373,11 @@ private static Object readPrimitiveArrayBody( MemoryBuffer control = MemoryBuffer.newHeapBuffer(1); control.writeBoolean(false); readContext.prepare( - control, - Collections.singletonList(MemoryUtils.wrap(new byte[byteSize])), - true, - control.remaining(), - false); + control, Collections.singletonList(MemoryUtils.wrap(new byte[byteSize])), true); } else { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(byteSize); - readContext.prepare(buffer, null, false, buffer.remaining(), false); + readContext.prepare(buffer, null, false); } return fory.getSerializer(arrayType).read(readContext); } @@ -392,7 +388,7 @@ private static Object readTruncatedPrimitiveArrayBody( MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(byteSize); MemoryBuffer truncated = MemoryBuffer.fromByteArray(buffer.getBytes(0, buffer.writerIndex())); - readContext.prepare(truncated, null, false, truncated.remaining(), false); + readContext.prepare(truncated, null, false); return fory.getSerializer(arrayType).read(readContext); } @@ -400,7 +396,7 @@ private static Object readPrimitiveArrayRawBody(Fory fory, Class arrayType) { ReadContext readContext = fory.getReadContext(); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); writeNegativeDecodedVarUInt32(buffer); - readContext.prepare(buffer, null, false, buffer.remaining(), false); + readContext.prepare(buffer, null, false); return fory.getSerializer(arrayType).read(readContext); } @@ -408,7 +404,7 @@ private static Object readObjectArrayBody(Fory fory, Class arrayType, int num ReadContext readContext = fory.getReadContext(); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(numElements); - readContext.prepare(buffer, null, false, buffer.remaining(), false); + readContext.prepare(buffer, null, false); return fory.getSerializer(arrayType).read(readContext); } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java index 492bee83f8..ad25757355 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java @@ -144,7 +144,7 @@ public void testNullableListBodyBounds() throws Exception { MemoryBuffer buffer = MemoryUtils.buffer(0); Fory fory = builder().build(); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false, buffer.remaining(), false); + readContext.prepare(buffer, null, false); try { InvocationTargetException exception = Assert.expectThrows( diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java index 2ec74b12a0..2b0505e077 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java @@ -179,7 +179,7 @@ public void testThrowableReadsMainWireOrderWithCyclicCause() { payload.readerIndex(0); ReadContext readContext = fory.getReadContext(); - readContext.prepare(payload, null, false, payload.remaining(), false); + readContext.prepare(payload, null, false); readContext.preserveRefId(); Serializer serializer = (Serializer) fory.getTypeResolver().getSerializer(RuntimeException.class); @@ -251,7 +251,7 @@ public void testThrowableRejectsMismatchedClassLayerCount() { payload.readerIndex(0); ReadContext readContext = fory.getReadContext(); - readContext.prepare(payload, null, false, payload.remaining(), false); + readContext.prepare(payload, null, false); Serializer serializer = (Serializer) fory.getTypeResolver().getSerializer(CustomException.class); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java index f14559c511..867eba616d 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java @@ -19,10 +19,12 @@ package org.apache.fory.serializer; +import static org.apache.fory.io.ForyStreamReader.of; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; +import java.io.ByteArrayInputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -37,27 +39,23 @@ import org.testng.annotations.Test; public class GraphMemoryBudgetTest extends ForyTestBase { - private static final long KNOWN_ROOT_MULTIPLIER = 8L; - private static final long KNOWN_ROOT_SLACK_BYTES = 64L * 1024; - private static final long STREAM_ROOT_BYTES = 128L * 1024 * 1024; + private static final long DEFAULT_GRAPH_MEMORY_BYTES = 128L * 1024 * 1024; private static final int REFERENCE_BYTES = 4; private static final int OBJECT_SELF_BYTES = 1; @Test - public void testConfigValidation() { - assertEquals(newFory(-1).getConfig().maxGraphMemoryBytes(), -1); + public void testConfigDefaultsAndDisable() { + assertEquals(builder().build().getConfig().maxGraphMemoryBytes(), DEFAULT_GRAPH_MEMORY_BYTES); assertEquals(newFory(123).getConfig().maxGraphMemoryBytes(), 123); - assertThrows(IllegalArgumentException.class, () -> builder().withMaxGraphMemoryBytes(0)); - assertThrows(IllegalArgumentException.class, () -> builder().withMaxGraphMemoryBytes(-2)); + assertEquals(newFory(0).getConfig().maxGraphMemoryBytes(), 0); + assertEquals(newFory(-2).getConfig().maxGraphMemoryBytes(), -2); } @Test - public void testKnownAutoBudget() { - Fory fory = newFory(-1); - ReadContext readContext = prepareContext(fory, 17, false); + public void testDefaultFixedBudget() { + ReadContext readContext = prepareContext(builder().build()); try { - long budget = knownAutoBytes(17); - readContext.reserveGraphMemory(budget); + readContext.reserveGraphMemory(DEFAULT_GRAPH_MEMORY_BYTES); assertThrows(InsecureException.class, () -> readContext.reserveGraphMemory(1)); } finally { readContext.reset(); @@ -65,12 +63,11 @@ public void testKnownAutoBudget() { } @Test - public void testStreamAutoBudget() { - Fory fory = newFory(-1); - ReadContext readContext = prepareContext(fory, 17, true); + public void testDisabledBudget() { + ReadContext readContext = prepareContext(newFory(0)); try { - readContext.reserveGraphMemory(STREAM_ROOT_BYTES); - assertThrows(InsecureException.class, () -> readContext.reserveGraphMemory(1)); + readContext.reserveGraphMemory(DEFAULT_GRAPH_MEMORY_BYTES + 1); + readContext.reserveGraphMemory(Long.MAX_VALUE); } finally { readContext.reset(); } @@ -79,7 +76,7 @@ public void testStreamAutoBudget() { @Test public void testExplicitBudgetWins() { Fory fory = newFory(7); - ReadContext readContext = prepareContext(fory, 1024 * 1024, false); + ReadContext readContext = prepareContext(fory); try { readContext.reserveGraphMemory(7); assertThrows(InsecureException.class, () -> readContext.reserveGraphMemory(1)); @@ -91,17 +88,21 @@ public void testExplicitBudgetWins() { @Test public void testNestedEmptyContainersUseParentStorage() { List value = emptyLists(1); - byte[] bytes = newFory(-1).serialize(value); + byte[] bytes = builder().build().serialize(value); long required = collectionBytes(1) + collectionBytes(0); assertThrows(InsecureException.class, () -> newFory(required - 1).deserialize(bytes)); + assertThrows( + InsecureException.class, + () -> newFory(required - 1).deserialize(of(new ByteArrayInputStream(bytes)))); assertEquals(newFory(required).deserialize(bytes), value); + assertEquals(newFory(required).deserialize(of(new ByteArrayInputStream(bytes))), value); } @Test public void testSiblingBudgetIsCumulative() { List value = nullLists(2, 64); - byte[] bytes = newFory(-1).serialize(value); + byte[] bytes = builder().build().serialize(value); long firstChildOnly = collectionBytes(2) + collectionBytes(64); assertThrows(InsecureException.class, () -> newFory(firstChildOnly).deserialize(bytes)); @@ -111,7 +112,7 @@ public void testSiblingBudgetIsCumulative() { @Test public void testMapBudgetAndOverflow() { Fory fory = newFory(mapBytes(1) - 1); - ReadContext readContext = prepareContext(fory, 8, false); + ReadContext readContext = prepareContext(fory); try { assertThrows(InsecureException.class, () -> readContext.reserveGraphMemory(mapBytes(1))); } finally { @@ -119,7 +120,7 @@ public void testMapBudgetAndOverflow() { } Fory exactFory = newFory(mapBytes(1)); - ReadContext exactContext = prepareContext(exactFory, 8, false); + ReadContext exactContext = prepareContext(exactFory); try { exactContext.reserveGraphMemory(mapBytes(1)); assertThrows(InsecureException.class, () -> exactContext.reserveGraphMemory(1)); @@ -130,9 +131,9 @@ public void testMapBudgetAndOverflow() { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(8); buffer.writeVarUInt32Small7(Integer.MAX_VALUE); buffer = trimBuffer(buffer); - Fory reader = newFory(STREAM_ROOT_BYTES); + Fory reader = newFory(DEFAULT_GRAPH_MEMORY_BYTES); ReadContext mapContext = reader.getReadContext(); - mapContext.prepare(buffer, null, false, buffer.remaining(), false); + mapContext.prepare(buffer, null, false); try { assertThrows( DeserializationException.class, @@ -147,7 +148,7 @@ public void testObjectArrayBudget() { Fory exactFory = newFory(1); ReadContext exactContext = exactFory.getReadContext(); MemoryBuffer exactBuffer = objectArraySizeBuffer(0); - exactContext.prepare(exactBuffer, null, false, exactBuffer.remaining(), false); + exactContext.prepare(exactBuffer, null, false); try { Object[] array = (Object[]) exactFory.getSerializer(Object[].class).read(exactContext); assertEquals(array.length, 0); @@ -158,7 +159,7 @@ public void testObjectArrayBudget() { Fory slotFory = newFory(objectArrayBytes(2) - 1); ReadContext slotContext = slotFory.getReadContext(); MemoryBuffer slotBuffer = objectArraySizeBuffer(2); - slotContext.prepare(slotBuffer, null, false, slotBuffer.remaining(), false); + slotContext.prepare(slotBuffer, null, false); try { assertThrows( InsecureException.class, () -> slotFory.getSerializer(Object[].class).read(slotContext)); @@ -170,7 +171,7 @@ public void testObjectArrayBudget() { @Test public void testPojoGraphBudget() { Pojo value = new Pojo(7, 9L, "child string is skipped as a leaf"); - byte[] bytes = newFory(-1).serialize(value); + byte[] bytes = builder().build().serialize(value); long required = pojoBytes(); assertThrows(InsecureException.class, () -> newFory(required - 1, false).deserialize(bytes)); @@ -185,7 +186,7 @@ public void testNestedEmptyPojoGraphBudget() { ArrayList value = new ArrayList<>(); value.add(new EmptyPojo()); value.add(new EmptyPojo()); - byte[] bytes = newFory(-1).serialize(value); + byte[] bytes = builder().build().serialize(value); long required = collectionBytes(2) + 2L * emptyPojoBytes(); assertThrows(InsecureException.class, () -> newFory(required - 1).deserialize(bytes)); @@ -219,7 +220,7 @@ public void testTruncatedCollectionStillFails() { buffer.writeByte(0); buffer = trimBuffer(buffer); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false, buffer.remaining(), false); + readContext.prepare(buffer, null, false); try { assertThrows( IndexOutOfBoundsException.class, @@ -237,11 +238,10 @@ private static Fory newFory(long maxGraphMemoryBytes, boolean codegen) { return builder().withMaxGraphMemoryBytes(maxGraphMemoryBytes).withCodegen(codegen).build(); } - private static ReadContext prepareContext( - Fory fory, int rootInputBytes, boolean unknownLengthInput) { + private static ReadContext prepareContext(Fory fory) { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(0); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false, rootInputBytes, unknownLengthInput); + readContext.prepare(buffer, null, false); return readContext; } @@ -265,10 +265,6 @@ private static long pojoBytes() { return OBJECT_SELF_BYTES + 4 + 8 + REFERENCE_BYTES; } - private static long knownAutoBytes(int inputBytes) { - return inputBytes * KNOWN_ROOT_MULTIPLIER + KNOWN_ROOT_SLACK_BYTES; - } - private static List emptyLists(int numElements) { List root = new ArrayList<>(numElements); for (int i = 0; i < numElements; i++) { diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java index 4c7d650331..bc945d60ff 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java @@ -307,7 +307,7 @@ private static Object readPrimitiveListBody(Fory fory, Class listType, int he MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(headerSize); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false, buffer.remaining(), false); + readContext.prepare(buffer, null, false); return fory.getSerializer(listType).read(readContext); } @@ -315,7 +315,7 @@ private static Object readPrimitiveListRawBody(Fory fory, Class listType) { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); writeNegativeDecodedVarUInt32(buffer); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false, buffer.remaining(), false); + readContext.prepare(buffer, null, false); return fory.getSerializer(listType).read(readContext); } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/ChildContainerSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/ChildContainerSerializersTest.java index cfc4cfd60f..afc536be28 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/ChildContainerSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/ChildContainerSerializersTest.java @@ -143,7 +143,7 @@ public void testChildCollectionRejectsMismatchedClassLayerCount() { payload.readerIndex(0); ReadContext readContext = fory.getReadContext(); - readContext.prepare(payload, null, false, payload.remaining(), false); + readContext.prepare(payload, null, false); Serializer serializer = (Serializer) fory.getTypeResolver().getSerializer(ChildArrayList.class); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java index c8ccbc8310..ff42616527 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java @@ -1421,7 +1421,7 @@ public void testBitSetRejectsNegativeBinary() { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); writeNegativeDecodedVarUInt32(buffer); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false, buffer.remaining(), false); + readContext.prepare(buffer, null, false); Assert.expectThrows( DeserializationException.class, () -> fory.getSerializer(BitSet.class).read(readContext)); } diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index 3d2e4b9f2b..3fc57aa761 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -531,8 +531,6 @@ export class WriteContext { export class ReadContext { private static readonly MIN_REMOTE_TYPE_META_LIMIT = 8192; - private static readonly KNOWN_ROOT_BUDGET_MULTIPLIER = 8; - private static readonly KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024; readonly reader: BinaryReader; readonly refReader: RefReader; @@ -574,10 +572,7 @@ export class ReadContext { this.typeMeta = []; this._depth = 0; this.effectiveGraphMemoryBytes = - this.maxGraphMemoryBytes > 0 - ? this.maxGraphMemoryBytes - : bytes.byteLength * ReadContext.KNOWN_ROOT_BUDGET_MULTIPLIER + - ReadContext.KNOWN_ROOT_BUDGET_SLACK_BYTES; + this.maxGraphMemoryBytes > 0 ? this.maxGraphMemoryBytes : 0; this.remainingGraphMemoryBytes = this.effectiveGraphMemoryBytes; } @@ -585,6 +580,9 @@ export class ReadContext { if (!Number.isSafeInteger(bytes) || bytes < 0) { this.throwGraphMemoryOverflow(bytes); } + if (this.effectiveGraphMemoryBytes <= 0) { + return; + } const remaining = this.remainingGraphMemoryBytes - bytes; if (remaining < 0) { this.throwGraphBudgetExceeded(bytes); diff --git a/javascript/packages/core/lib/fory.ts b/javascript/packages/core/lib/fory.ts index 0e11fd6c2a..f17f5ab7c5 100644 --- a/javascript/packages/core/lib/fory.ts +++ b/javascript/packages/core/lib/fory.ts @@ -38,7 +38,7 @@ const DEFAULT_MAX_TYPE_FIELDS = 512 as const; const DEFAULT_MAX_TYPE_META_BYTES = 4096 as const; const DEFAULT_MAX_SCHEMA_VERSIONS_PER_TYPE = 10 as const; const DEFAULT_MAX_AVERAGE_SCHEMA_VERSIONS_PER_TYPE = 3 as const; -const DEFAULT_MAX_GRAPH_MEMORY_BYTES = -1 as const; +const DEFAULT_MAX_GRAPH_MEMORY_BYTES = 128 * 1024 * 1024; export default class Fory { readonly typeResolver: TypeResolver; readonly anySerializer: Serializer; @@ -108,12 +108,9 @@ export default class Fory { } const maxGraphMemoryBytes = config?.maxGraphMemoryBytes ?? DEFAULT_MAX_GRAPH_MEMORY_BYTES; - if ( - !Number.isSafeInteger(maxGraphMemoryBytes) || - (maxGraphMemoryBytes !== -1 && maxGraphMemoryBytes <= 0) - ) { + if (!Number.isSafeInteger(maxGraphMemoryBytes)) { throw new Error( - `maxGraphMemoryBytes must be -1 or a positive safe integer but got ${maxGraphMemoryBytes}`, + `maxGraphMemoryBytes must be a safe integer but got ${maxGraphMemoryBytes}`, ); } return { diff --git a/javascript/test/graphMemoryBudget.test.ts b/javascript/test/graphMemoryBudget.test.ts index d3bb3ce912..b7b21158ae 100644 --- a/javascript/test/graphMemoryBudget.test.ts +++ b/javascript/test/graphMemoryBudget.test.ts @@ -20,7 +20,7 @@ import Fory, { Type } from "../packages/core/index"; import { describe, expect, test } from "@jest/globals"; -const KNOWN_SLACK_BYTES = 64 * 1024; +const DEFAULT_GRAPH_MEMORY_BYTES = 128 * 1024 * 1024; const OBJECT_BYTES = 1; const REFERENCE_BYTES = 4; @@ -41,26 +41,19 @@ function deserializeAny(bytes: Uint8Array, maxGraphMemoryBytes: number) { } describe("graph memory budget", () => { - test("uses known length auto budget", () => { - const inputBytes = 17; + test("uses fixed default budget", () => { const fory = new Fory({ compatible: false }); - const budget = inputBytes * 8 + KNOWN_SLACK_BYTES; - fory.readContext.reset(new Uint8Array(inputBytes)); - expect(() => fory.readContext.reserveGraphMemory(budget)).not.toThrow(); + fory.readContext.reset(new Uint8Array(17)); + expect(() => + fory.readContext.reserveGraphMemory(DEFAULT_GRAPH_MEMORY_BYTES), + ).not.toThrow(); expect(() => fory.readContext.reserveGraphMemory(1)).toThrow( /maxGraphMemoryBytes/, ); }); - test("validates explicit config", () => { - expect(() => new Fory({ maxGraphMemoryBytes: 0 })).toThrow( - /maxGraphMemoryBytes/, - ); - expect(() => new Fory({ maxGraphMemoryBytes: -2 })).toThrow( - /maxGraphMemoryBytes/, - ); - + test("handles explicit config and disable", () => { const fory = new Fory({ maxGraphMemoryBytes: 24 }); fory.readContext.reset(new Uint8Array(1)); expect(() => fory.readContext.reserveGraphMemory(0)).not.toThrow(); @@ -68,6 +61,13 @@ describe("graph memory budget", () => { expect(() => fory.readContext.reserveGraphMemory(1)).toThrow( /maxGraphMemoryBytes/, ); + + const disabled = new Fory({ maxGraphMemoryBytes: 0 }); + disabled.readContext.reset(new Uint8Array(1)); + expect(() => + disabled.readContext.reserveGraphMemory(Number.MAX_SAFE_INTEGER), + ).not.toThrow(); + expect(() => new Fory({ maxGraphMemoryBytes: -2 })).not.toThrow(); }); test("uses parent storage for nested empty containers", () => { diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index 9a3e65f6de..9082d00acd 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -140,7 +140,7 @@ def __init__( max_type_meta_bytes: int = 4096, max_schema_versions_per_type: int = 10, max_average_schema_versions_per_type: int = 3, - max_graph_memory_bytes: int = -1, + max_graph_memory_bytes: int = 128 * 1024 * 1024, policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, @@ -186,7 +186,8 @@ def __init__( across accepted remote types. max_graph_memory_bytes: Maximum estimated graph memory per root - deserialization. `-1` means auto; positive values are explicit byte limits. + deserialization. Defaults to 128 MiB; positive values are explicit byte limits, + and non-positive values intentionally disable this protection. policy: Custom deserialization policy for security checks. When provided, it controls which types can be deserialized, overriding the default policy. @@ -220,10 +221,10 @@ def __init__( raise ValueError("max_average_schema_versions_per_type must be a positive integer") if ( not isinstance(max_graph_memory_bytes, int) - or (max_graph_memory_bytes != -1 and max_graph_memory_bytes <= 0) or max_graph_memory_bytes > (1 << 63) - 1 + or max_graph_memory_bytes < -(1 << 63) ): - raise ValueError("max_graph_memory_bytes must be -1 or a positive 63-bit integer") + raise ValueError("max_graph_memory_bytes must be a 63-bit integer") self.max_graph_memory_bytes = max_graph_memory_bytes self.config = Config( xlang=xlang, @@ -572,7 +573,6 @@ def _deserialize( buffers=buffers, unsupported_objects=unsupported_objects, peer_out_of_band_enabled=peer_out_of_band_enabled, - root_input_bytes=buffer.size() - reader_index, ) return read_context.read_ref() diff --git a/python/pyfory/context.pxi b/python/pyfory/context.pxi index 10ac94bf45..14405b46f5 100644 --- a/python/pyfory/context.pxi +++ b/python/pyfory/context.pxi @@ -30,9 +30,6 @@ STRING_TYPE_ID = TypeId.STRING SMALL_STRING_THRESHOLD = 16 cdef int32_t MAX_CACHED_META_STRINGS = 8192 cdef int32_t MAX_CACHED_META_STRING_LENGTH = 2048 -cdef int64_t _KNOWN_ROOT_BUDGET_MULTIPLIER = 8 -cdef int64_t _KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 -cdef int64_t _STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 cdef int64_t _MAX_GRAPH_MEMORY_BYTES = 9223372036854775807 @@ -793,26 +790,16 @@ cdef class ReadContext: buffers=None, unsupported_objects=None, bint peer_out_of_band_enabled=False, - int64_t root_input_bytes=-1, ): cdef int64_t limit - if self.max_graph_memory_bytes > 0: - limit = self.max_graph_memory_bytes - elif buffer.has_input_stream(): - limit = _STREAM_ROOT_BUDGET_BYTES - else: - if root_input_bytes < 0: - root_input_bytes = buffer.size() - buffer.get_reader_index() - if root_input_bytes > (_MAX_GRAPH_MEMORY_BYTES - _KNOWN_ROOT_BUDGET_SLACK_BYTES) // _KNOWN_ROOT_BUDGET_MULTIPLIER: - raise ValueError("max_graph_memory_bytes auto budget overflow") - limit = root_input_bytes * _KNOWN_ROOT_BUDGET_MULTIPLIER + _KNOWN_ROOT_BUDGET_SLACK_BYTES + limit = self.max_graph_memory_bytes if self.max_graph_memory_bytes > 0 else 0 self.buffer = buffer self.c_buffer = buffer.c_buffer self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled self.graph_memory_limit_bytes = limit - self.remaining_graph_memory_bytes = limit + self.remaining_graph_memory_bytes = limit if limit > 0 else _MAX_GRAPH_MEMORY_BYTES self.depth = 0 cpdef inline reset(self): @@ -837,6 +824,8 @@ cdef class ReadContext: raise ValueError("Estimated graph memory is negative") if num_bytes > _MAX_GRAPH_MEMORY_BYTES: raise ValueError("Estimated graph memory overflow") + if self.graph_memory_limit_bytes <= 0: + return if num_bytes > self.remaining_graph_memory_bytes: used = self.graph_memory_limit_bytes - self.remaining_graph_memory_bytes raise ValueError( @@ -848,6 +837,8 @@ cdef class ReadContext: cdef inline void reserve_graph_memory_fast(self, int64_t num_bytes): cdef int64_t used + if self.graph_memory_limit_bytes <= 0: + return if num_bytes > self.remaining_graph_memory_bytes: used = self.graph_memory_limit_bytes - self.remaining_graph_memory_bytes raise ValueError( @@ -857,22 +848,12 @@ cdef class ReadContext: ) self.remaining_graph_memory_bytes -= num_bytes - cpdef inline reserve_graph_memory(self, int64_t num_bytes): - self.reserve_graph_memory_c(num_bytes) - - cdef inline void reserve_counted_graph_memory_c( - self, - int64_t count, - int64_t element_bytes, - ): - if count < 0 or element_bytes < 0: + cpdef inline reserve_graph_memory(self, num_bytes): + if num_bytes < 0: raise ValueError("Estimated graph memory is negative") - if element_bytes != 0 and count > _MAX_GRAPH_MEMORY_BYTES // element_bytes: + if num_bytes > _MAX_GRAPH_MEMORY_BYTES: raise ValueError("Estimated graph memory overflow") - self.reserve_graph_memory_c(count * element_bytes) - - cpdef inline reserve_counted_graph_memory(self, int64_t count, int64_t element_bytes): - self.reserve_counted_graph_memory_c(count, element_bytes) + self.reserve_graph_memory_c(num_bytes) cpdef inline add_context_object(self, key, obj): self.context_objects[id(key)] = obj diff --git a/python/pyfory/context.py b/python/pyfory/context.py index 62ac7810a9..d894516d80 100644 --- a/python/pyfory/context.py +++ b/python/pyfory/context.py @@ -37,9 +37,6 @@ FLOAT64_TYPE_ID = TypeId.FLOAT64 BOOL_TYPE_ID = TypeId.BOOL STRING_TYPE_ID = TypeId.STRING -_KNOWN_ROOT_BUDGET_MULTIPLIER = 8 -_KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 -_STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 _MAX_GRAPH_MEMORY_BYTES = (1 << 63) - 1 @@ -529,26 +526,14 @@ def prepare( buffers=None, unsupported_objects=None, peer_out_of_band_enabled=False, - root_input_bytes=None, ): - if self.max_graph_memory_bytes > 0: - limit = self.max_graph_memory_bytes - elif buffer.has_input_stream(): - limit = _STREAM_ROOT_BUDGET_BYTES - else: - if root_input_bytes is None: - root_input_bytes = buffer.size() - buffer.get_reader_index() - if root_input_bytes < 0: - raise ValueError("root input byte count is negative") - if root_input_bytes > (_MAX_GRAPH_MEMORY_BYTES - _KNOWN_ROOT_BUDGET_SLACK_BYTES) // _KNOWN_ROOT_BUDGET_MULTIPLIER: - raise ValueError("max_graph_memory_bytes auto budget overflow") - limit = root_input_bytes * _KNOWN_ROOT_BUDGET_MULTIPLIER + _KNOWN_ROOT_BUDGET_SLACK_BYTES + limit = self.max_graph_memory_bytes if self.max_graph_memory_bytes > 0 else 0 self.buffer = buffer self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled self.graph_memory_limit_bytes = limit - self.remaining_graph_memory_bytes = limit + self.remaining_graph_memory_bytes = limit if limit > 0 else _MAX_GRAPH_MEMORY_BYTES self.depth = 0 def reset(self): @@ -571,6 +556,8 @@ def reserve_graph_memory(self, num_bytes): raise ValueError("Estimated graph memory is negative") if num_bytes > _MAX_GRAPH_MEMORY_BYTES: raise ValueError("Estimated graph memory overflow") + if self.graph_memory_limit_bytes <= 0: + return remaining = self.remaining_graph_memory_bytes if num_bytes > remaining: used = self.graph_memory_limit_bytes - remaining @@ -581,13 +568,6 @@ def reserve_graph_memory(self, num_bytes): ) self.remaining_graph_memory_bytes = remaining - num_bytes - def reserve_counted_graph_memory(self, count, element_bytes): - if count < 0 or element_bytes < 0: - raise ValueError("Estimated graph memory is negative") - if element_bytes and count > _MAX_GRAPH_MEMORY_BYTES // element_bytes: - raise ValueError("Estimated graph memory overflow") - self.reserve_graph_memory(count * element_bytes) - def add_context_object(self, key, obj): self.context_objects[id(key)] = obj diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 54d72cc5cb..56e46b32c6 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -114,7 +114,8 @@ cdef class Config: max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. max_graph_memory_bytes: Maximum estimated graph memory per root - deserialization. -1 means auto; positive values are explicit byte limits. + deserialization. Defaults to 128 MiB; positive values are explicit byte limits, + and non-positive values intentionally disable this protection. field_nullable: Treats struct/dataclass fields as nullable by default. policy: Deserialization policy used for security-sensitive checks. meta_compressor: Optional typedef/meta compressor implementation. @@ -171,7 +172,8 @@ cdef class Config: max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. max_graph_memory_bytes: Maximum estimated graph memory per root - deserialization. -1 means auto; positive values are explicit byte limits. + deserialization. Defaults to 128 MiB; positive values are explicit byte limits, + and non-positive values intentionally disable this protection. field_nullable: Treat all struct fields as nullable by default. policy: Deserialization policy implementation. meta_compressor: Optional typedef/meta compressor. @@ -193,10 +195,10 @@ cdef class Config: raise ValueError("max_average_schema_versions_per_type must be a positive integer") if ( not isinstance(max_graph_memory_bytes, int) - or (max_graph_memory_bytes != -1 and max_graph_memory_bytes <= 0) or max_graph_memory_bytes > 9223372036854775807 + or max_graph_memory_bytes < -9223372036854775808 ): - raise ValueError("max_graph_memory_bytes must be -1 or a positive 63-bit integer") + raise ValueError("max_graph_memory_bytes must be a 63-bit integer") self.max_type_fields = max_type_fields self.max_type_meta_bytes = max_type_meta_bytes self.max_schema_versions_per_type = max_schema_versions_per_type @@ -861,7 +863,7 @@ cdef class Fory: max_type_meta_bytes=4096, max_schema_versions_per_type=10, max_average_schema_versions_per_type=3, - max_graph_memory_bytes=-1, + max_graph_memory_bytes=128 * 1024 * 1024, policy=None, field_nullable=False, meta_compressor=None, @@ -881,7 +883,8 @@ cdef class Fory: max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. max_graph_memory_bytes: Maximum estimated graph memory per root - deserialization. -1 means auto; positive values are explicit byte limits. + deserialization. Defaults to 128 MiB; positive values are explicit byte limits, + and non-positive values intentionally disable this protection. policy: Optional deserialization policy implementation. field_nullable: Treat struct fields as nullable by default. meta_compressor: Optional typedef/meta compressor implementation. @@ -901,10 +904,10 @@ cdef class Fory: self.max_depth = max_depth if ( not isinstance(max_graph_memory_bytes, int) - or (max_graph_memory_bytes != -1 and max_graph_memory_bytes <= 0) or max_graph_memory_bytes > 9223372036854775807 + or max_graph_memory_bytes < -9223372036854775808 ): - raise ValueError("max_graph_memory_bytes must be -1 or a positive 63-bit integer") + raise ValueError("max_graph_memory_bytes must be a 63-bit integer") self.max_graph_memory_bytes = max_graph_memory_bytes self.config = Config( xlang=xlang, @@ -1076,7 +1079,6 @@ cdef class Fory: cdef int32_t reader_index cdef uint8_t bitmap cdef bint peer_out_of_band_enabled - cdef int64_t root_input_bytes cdef int64_t graph_memory_limit if isinstance(buffer, bytes): buffer = Buffer(buffer) @@ -1093,13 +1095,7 @@ cdef class Fory: raise ValueError("Out-of-band buffers are required by the root header") if not peer_out_of_band_enabled and buffers is not None: raise ValueError("Out-of-band buffers were provided for an in-band root payload") - if self.max_graph_memory_bytes > 0: - graph_memory_limit = self.max_graph_memory_bytes - elif read_buffer.has_input_stream(): - graph_memory_limit = _STREAM_ROOT_BUDGET_BYTES - else: - root_input_bytes = read_buffer.size() - reader_index - graph_memory_limit = root_input_bytes * _KNOWN_ROOT_BUDGET_MULTIPLIER + _KNOWN_ROOT_BUDGET_SLACK_BYTES + graph_memory_limit = self.max_graph_memory_bytes if self.max_graph_memory_bytes > 0 else 0 # Keep the root context setup inline. Top-level deserialize is a hot path, # so it should not pay an extra method call just to bind the active buffer. read_context.buffer = read_buffer @@ -1110,7 +1106,7 @@ cdef class Fory: ) read_context.peer_out_of_band_enabled = peer_out_of_band_enabled read_context.graph_memory_limit_bytes = graph_memory_limit - read_context.remaining_graph_memory_bytes = graph_memory_limit + read_context.remaining_graph_memory_bytes = graph_memory_limit if graph_memory_limit > 0 else _MAX_GRAPH_MEMORY_BYTES read_context.depth = 0 return read_context.read_ref() diff --git a/python/pyfory/tests/test_graph_memory_budget.py b/python/pyfory/tests/test_graph_memory_budget.py index d777b719bf..514f1fbac2 100644 --- a/python/pyfory/tests/test_graph_memory_budget.py +++ b/python/pyfory/tests/test_graph_memory_budget.py @@ -31,9 +31,7 @@ np = None -KNOWN_ROOT_BUDGET_MULTIPLIER = 8 -KNOWN_ROOT_BUDGET_SLACK_BYTES = 64 * 1024 -STREAM_ROOT_BUDGET_BYTES = 128 * 1024 * 1024 +DEFAULT_GRAPH_MEMORY_BYTES = 128 * 1024 * 1024 REFERENCE_BYTES = struct.calcsize("P") OWNER_BYTES = 1 MAX_GRAPH_MEMORY_BYTES = (1 << 63) - 1 @@ -104,7 +102,7 @@ def object_memory(num_fields): return OWNER_BYTES + num_fields * REFERENCE_BYTES -def new_fory(limit=-1, *, xlang=True): +def new_fory(limit=DEFAULT_GRAPH_MEMORY_BYTES, *, xlang=True): return pyfory.Fory( xlang=xlang, ref=True, @@ -128,34 +126,38 @@ def varuint_payload(value): return buffer.to_bytes(0, buffer.get_writer_index()) -def test_known_length_auto_budget(): +def test_fixed_default_budget(): fory = new_fory(xlang=False) - root_input_bytes = 17 try: - fory.read_context.prepare(Buffer(b"x" * root_input_bytes), root_input_bytes=root_input_bytes) - expected = root_input_bytes * KNOWN_ROOT_BUDGET_MULTIPLIER + KNOWN_ROOT_BUDGET_SLACK_BYTES - assert fory.read_context.graph_memory_limit_bytes == expected - fory.read_context.reserve_graph_memory(expected) + fory.read_context.prepare(Buffer(b"x" * 17)) + assert fory.read_context.graph_memory_limit_bytes == DEFAULT_GRAPH_MEMORY_BYTES + fory.read_context.reserve_graph_memory(DEFAULT_GRAPH_MEMORY_BYTES) with pytest.raises(ValueError, match="Estimated graph memory budget exceeded"): fory.read_context.reserve_graph_memory(1) finally: fory.reset_read() -def test_stream_auto_budget(): +def test_stream_uses_fixed_default_budget(): fory = new_fory(xlang=False) try: buffer = Buffer.from_stream(OneByteStream(b"streamed")) - fory.read_context.prepare(buffer, root_input_bytes=1) - assert fory.read_context.graph_memory_limit_bytes == STREAM_ROOT_BUDGET_BYTES + fory.read_context.prepare(buffer) + assert fory.read_context.graph_memory_limit_bytes == DEFAULT_GRAPH_MEMORY_BYTES finally: fory.reset_read() -def test_explicit_config_overrides_auto(): +def test_explicit_config_and_disable(): value = [1] budget = collection_memory(1) assert expect_budget(value, budget) == value + disabled = new_fory(0, xlang=False) + try: + disabled.read_context.prepare(Buffer(b"x")) + disabled.read_context.reserve_graph_memory(MAX_GRAPH_MEMORY_BYTES) + finally: + disabled.reset_read() def test_nested_empty_containers_use_parent_storage(): @@ -190,7 +192,17 @@ def test_dynamic_object_owner_is_charged(): value.left = 1 value.right = "x" budget = object_memory(2) - restored = expect_budget(value, budget, xlang=False) + + writer = new_fory(xlang=False) + writer.register_type(BudgetObject) + data = writer.serialize(value) + with pytest.raises(ValueError, match="Estimated graph memory budget exceeded"): + reader = new_fory(budget - 1, xlang=False) + reader.register_type(BudgetObject) + reader.deserialize(data) + reader = new_fory(budget, xlang=False) + reader.register_type(BudgetObject) + restored = reader.deserialize(data) assert restored.left == value.left assert restored.right == value.right @@ -201,10 +213,8 @@ def test_map_entry_budget_and_overflow(): fory = new_fory(xlang=False) try: - fory.read_context.prepare(Buffer(b""), root_input_bytes=0) - max_map_entries = MAX_GRAPH_MEMORY_BYTES // (2 * REFERENCE_BYTES) with pytest.raises(ValueError, match="Estimated graph memory overflow"): - fory.read_context.reserve_counted_graph_memory(max_map_entries + 1, 2 * REFERENCE_BYTES) + fory.read_context.reserve_graph_memory(MAX_GRAPH_MEMORY_BYTES + 1) finally: fory.reset_read() @@ -243,7 +253,7 @@ def test_declared_large_list_still_needs_bytes(): fory = new_fory(10_000_000, xlang=False) serializer = ListSerializer(fory.type_resolver, list) try: - fory.read_context.prepare(Buffer(varuint_payload(1000)), root_input_bytes=1) + fory.read_context.prepare(Buffer(varuint_payload(1000))) with pytest.raises(Exception) as exc_info: serializer.read(fory.read_context) assert "Estimated graph memory" not in str(exc_info.value) @@ -251,7 +261,7 @@ def test_declared_large_list_still_needs_bytes(): fory.reset_read() -@pytest.mark.parametrize("limit", [0, -2, 1 << 63]) +@pytest.mark.parametrize("limit", [1 << 63, -(1 << 63) - 1]) def test_invalid_config(limit): with pytest.raises(ValueError, match="max_graph_memory_bytes"): new_fory(limit) diff --git a/rust/fory-core/src/config.rs b/rust/fory-core/src/config.rs index 64c76d98f0..76f4d652c2 100644 --- a/rust/fory-core/src/config.rs +++ b/rust/fory-core/src/config.rs @@ -41,7 +41,8 @@ pub struct Config { /// and preserved during serialization/deserialization. pub track_ref: bool, /// Maximum estimated graph memory accepted during one root deserialization. - /// `-1` selects the automatic input-shaped limit. + /// Defaults to 128 MiB. Positive values are explicit limits; non-positive + /// values intentionally disable this protection. pub max_graph_memory_bytes: i64, /// Maximum accepted field count in one received struct TypeMeta. pub max_type_fields: u32, @@ -64,7 +65,7 @@ impl Default for Config { max_dyn_depth: 5, check_struct_version: false, track_ref: false, - max_graph_memory_bytes: -1, + max_graph_memory_bytes: 128 * 1024 * 1024, max_type_fields: 512, max_type_meta_bytes: 4096, max_schema_versions_per_type: 10, diff --git a/rust/fory-core/src/context.rs b/rust/fory-core/src/context.rs index 549370b5bf..0f57f06b23 100644 --- a/rust/fory-core/src/context.rs +++ b/rust/fory-core/src/context.rs @@ -31,10 +31,6 @@ use crate::type_id as types; use crate::TypeId; use std::rc::Rc; -const KNOWN_ROOT_BUDGET_MULTIPLIER: usize = 8; -const KNOWN_ROOT_BUDGET_SLACK_BYTES: usize = 64 * 1024; -const MAX_GRAPH_COUNT: usize = u32::MAX as usize; - /// Thread-local context cache with fast path for single Fory instance. /// Uses (cached_id, context) for O(1) access when using same Fory instance repeatedly. /// Falls back to HashMap for multiple Fory instances per thread. @@ -454,42 +450,24 @@ impl<'a> ReadContext<'a> { } #[inline(always)] - pub(crate) fn init_graph_memory_budget( - &mut self, - root_input_bytes: usize, - ) -> Result<(), Error> { + pub(crate) fn init_graph_memory_budget(&mut self) -> Result<(), Error> { let limit = if self.max_graph_memory_bytes > 0 { usize::try_from(self.max_graph_memory_bytes) .map_err(|_| graph_memory_error("max_graph_memory_bytes does not fit usize"))? } else { - if root_input_bytes - > (usize::MAX - KNOWN_ROOT_BUDGET_SLACK_BYTES) / KNOWN_ROOT_BUDGET_MULTIPLIER - { - return Err(graph_memory_error( - "root input size overflows automatic graph memory budget", - )); - } - root_input_bytes * KNOWN_ROOT_BUDGET_MULTIPLIER + KNOWN_ROOT_BUDGET_SLACK_BYTES + 0 }; self.graph_memory_limit_bytes = limit; - self.remaining_graph_memory_bytes = limit; + self.remaining_graph_memory_bytes = if limit > 0 { limit } else { usize::MAX }; Ok(()) } - #[inline(always)] - pub(crate) fn reserve_counted_graph_memory( - &mut self, - len: u32, - elem_bytes: usize, - ) -> Result { - let len = len as usize; - self.reserve_counted_graph_bytes(len, elem_bytes)?; - Ok(len) - } - #[inline(always)] #[doc(hidden)] pub fn reserve_graph_memory(&mut self, bytes: usize) -> Result<(), Error> { + if self.graph_memory_limit_bytes == 0 { + return Ok(()); + } let remaining = self.remaining_graph_memory_bytes; if bytes > remaining { return Err(graph_memory_exceeded( @@ -502,31 +480,6 @@ impl<'a> ReadContext<'a> { Ok(()) } - #[inline(always)] - fn reserve_counted_graph_bytes(&mut self, len: usize, elem_bytes: usize) -> Result<(), Error> { - if len == 0 { - return Ok(()); - } - if elem_bytes <= usize::MAX / MAX_GRAPH_COUNT { - return self.reserve_graph_memory(len * elem_bytes); - } - self.reserve_counted_graph_checked(len, elem_bytes) - } - - #[cold] - #[inline(never)] - fn reserve_counted_graph_checked( - &mut self, - len: usize, - elem_bytes: usize, - ) -> Result<(), Error> { - let bytes = match len.checked_mul(elem_bytes) { - Some(bytes) => bytes, - None => return Err(graph_memory_overflow(len, elem_bytes)), - }; - self.reserve_graph_memory(bytes) - } - #[inline(always)] pub fn detach_reader(&mut self) -> Reader<'_> { mem::take(&mut self.reader) @@ -643,15 +596,6 @@ fn graph_memory_error(message: &'static str) -> Error { Error::invalid_data(message) } -#[cold] -#[inline(never)] -fn graph_memory_overflow(len: usize, elem_bytes: usize) -> Error { - Error::invalid_data(format!( - "graph memory estimate overflows: length={} elementBytes={}", - len, elem_bytes - )) -} - #[cold] #[inline(never)] fn graph_memory_exceeded(bytes: usize, remaining: usize, limit: usize) -> Error { diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index fd3bfdb827..ea237b96fc 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -263,12 +263,9 @@ impl ForyBuilder { /// Sets the maximum estimated graph memory accepted during one root deserialization. /// - /// Use `-1` for the automatic input-shaped limit. Positive values are explicit byte limits. + /// Defaults to 128 MiB. Positive values are explicit byte limits; non-positive + /// values intentionally disable this protection. pub fn max_graph_memory_bytes(mut self, max_bytes: i64) -> Self { - assert!( - max_bytes == -1 || max_bytes > 0, - "max_graph_memory_bytes must be positive or -1 for auto" - ); self.config.max_graph_memory_bytes = max_bytes; self } @@ -1000,7 +997,7 @@ impl Fory { self.with_read_context(|context| { let outlive_buffer = unsafe { mem::transmute::<&[u8], &[u8]>(bf) }; context.attach_reader(Reader::new(outlive_buffer)); - let result = match context.init_graph_memory_budget(bf.len()) { + let result = match context.init_graph_memory_budget() { Ok(()) => self.deserialize_with_context(context), Err(err) => { context.reset(); @@ -1068,9 +1065,8 @@ impl Fory { let outlive_buffer = unsafe { mem::transmute::<&[u8], &[u8]>(reader.bf) }; let mut new_reader = Reader::new(outlive_buffer); new_reader.set_cursor(reader.cursor); - let root_input_bytes = reader.bf.len().saturating_sub(reader.cursor); context.attach_reader(new_reader); - let result = match context.init_graph_memory_budget(root_input_bytes) { + let result = match context.init_graph_memory_budget() { Ok(()) => self.deserialize_with_context(context), Err(err) => { context.reset(); diff --git a/rust/fory-core/src/serializer/codec.rs b/rust/fory-core/src/serializer/codec.rs index 302019c757..e103c27b29 100644 --- a/rust/fory-core/src/serializer/codec.rs +++ b/rust/fory-core/src/serializer/codec.rs @@ -52,6 +52,20 @@ const DECL_VALUE_TYPE: u8 = 0b100000; const MAX_CHUNK_SIZE: u8 = 255; +#[inline(always)] +fn reserve_graph_storage( + context: &mut ReadContext, + len: u32, + elem_bytes: usize, +) -> Result { + let len = len as usize; + let bytes = len + .checked_mul(elem_bytes) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + context.reserve_graph_memory(bytes)?; + Ok(len) +} + #[inline(always)] pub fn field_ref_mode(field_type: &FieldType) -> RefMode { if field_type.track_ref { @@ -1710,7 +1724,7 @@ where fn read_data(context: &mut ReadContext) -> Result, Error> { let len = context.reader.read_var_u32()?; - context.reserve_counted_graph_memory(len, C::graph_storage_size())?; + reserve_graph_storage(context, len, C::graph_storage_size())?; if len == 0 { return Ok(Vec::new()); } @@ -1739,7 +1753,7 @@ where remote_field_type: &FieldType, ) -> Result, Error> { let len = context.reader.read_var_u32()?; - context.reserve_counted_graph_memory(len, C::graph_storage_size())?; + reserve_graph_storage(context, len, C::graph_storage_size())?; if len == 0 { return Ok(Vec::new()); } @@ -2285,7 +2299,7 @@ where let elem_bytes = KC::graph_storage_size() .checked_add(VC::graph_storage_size()) .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; - context.reserve_counted_graph_memory(len, elem_bytes)?; + reserve_graph_storage(context, len, elem_bytes)?; if len == 0 { return Ok(HashMap::new()); } @@ -2308,7 +2322,7 @@ where let elem_bytes = KC::graph_storage_size() .checked_add(VC::graph_storage_size()) .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; - let capacity = context.reserve_counted_graph_memory(len, elem_bytes)?; + let capacity = reserve_graph_storage(context, len, elem_bytes)?; if len == 0 { return Ok(HashMap::new()); } diff --git a/rust/fory-core/src/serializer/collection.rs b/rust/fory-core/src/serializer/collection.rs index 4b94f8ac32..ca26c2b050 100644 --- a/rust/fory-core/src/serializer/collection.rs +++ b/rust/fory-core/src/serializer/collection.rs @@ -42,6 +42,20 @@ fn check_collection_len(context: &ReadContext, len: u32) -> Result Ok(len) } +#[inline(always)] +fn reserve_collection_storage( + context: &mut ReadContext, + len: u32, + elem_bytes: usize, +) -> Result { + let len = len as usize; + let bytes = len + .checked_mul(elem_bytes) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + context.reserve_graph_memory(bytes)?; + Ok(len) +} + pub fn write_collection_type_info( context: &mut WriteContext, collection_type_id: u32, @@ -239,7 +253,7 @@ where C: FromIterator, { let len = context.reader.read_var_u32()?; - let len_usize = context.reserve_counted_graph_memory(len, T::fory_graph_storage_size())?; + let len_usize = reserve_collection_storage(context, len, T::fory_graph_storage_size())?; if len == 0 { return Ok(C::from_iter(std::iter::empty())); } @@ -284,7 +298,7 @@ where T: Serializer + ForyDefault, { let len = context.reader.read_var_u32()?; - let len_usize = context.reserve_counted_graph_memory(len, T::fory_graph_storage_size())?; + let len_usize = reserve_collection_storage(context, len, T::fory_graph_storage_size())?; if len == 0 { return Ok(Vec::new()); } @@ -733,7 +747,7 @@ where { let element_type = generic_field_type(remote_field_type, 0, "list")?; let len = context.reader.read_var_u32()?; - let len_usize = context.reserve_counted_graph_memory(len, T::fory_graph_storage_size())?; + let len_usize = reserve_collection_storage(context, len, T::fory_graph_storage_size())?; if len == 0 { return Ok(Vec::new()); } diff --git a/rust/fory-core/src/serializer/map.rs b/rust/fory-core/src/serializer/map.rs index 8264d79dec..79bb8c53b3 100644 --- a/rust/fory-core/src/serializer/map.rs +++ b/rust/fory-core/src/serializer/map.rs @@ -39,6 +39,20 @@ fn write_chunk_size(context: &mut WriteContext, header_offset: usize, size: u8) context.writer.set_bytes(header_offset + 1, &[size]); } +#[inline(always)] +fn reserve_map_storage( + context: &mut ReadContext, + len: u32, + elem_bytes: usize, +) -> Result { + let len = len as usize; + let bytes = len + .checked_mul(elem_bytes) + .ok_or_else(|| Error::invalid_data("graph memory estimate overflows"))?; + context.reserve_graph_memory(bytes)?; + Ok(len) +} + pub fn write_map_data<'a, K, V, I>( iter: I, length: usize, @@ -556,7 +570,7 @@ impl Vec> { (0..count).map(|_| Vec::new()).collect() } -fn assert_budget_error(err: Error, effective_limit: usize) { - let message = err.to_string(); - assert!( - message.contains("estimated graph memory request"), - "{message}" +#[test] +fn config_validation() { + assert_eq!( + Fory::builder().build().config().max_graph_memory_bytes, + DEFAULT_GRAPH_MEMORY_BYTES ); - assert!( - message.contains(&format!("effective limit {effective_limit}")), - "{message}" + assert_eq!( + Fory::builder() + .max_graph_memory_bytes(0) + .build() + .config() + .max_graph_memory_bytes, + 0 ); + assert_eq!( + Fory::builder() + .max_graph_memory_bytes(-2) + .build() + .config() + .max_graph_memory_bytes, + -2 + ); + let _ = Fory::builder().max_graph_memory_bytes(1).build(); } #[test] -fn config_validation() { - assert!(panic::catch_unwind(|| Fory::builder().max_graph_memory_bytes(0)).is_err()); - assert!(panic::catch_unwind(|| Fory::builder().max_graph_memory_bytes(-2)).is_err()); - let _ = Fory::builder().max_graph_memory_bytes(-1).build(); - let _ = Fory::builder().max_graph_memory_bytes(1).build(); +fn non_positive_budget_disables_enforcement() { + let value: Vec = Vec::new(); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); + let bytes = writer.serialize(&value).unwrap(); + + assert!(fory_with_budget(1) + .deserialize::>(&bytes) + .is_err()); + assert!(fory_with_budget(0) + .deserialize::>(&bytes) + .is_ok()); } #[test] -fn known_auto_budget() { +fn byte_root_uses_fixed_default_budget() { let value = compact_empty_lists(12000); - let writer = fory_with_budget(-1); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); let bytes = writer.serialize(&value).unwrap(); - let auto_limit = bytes.len() * 8 + 64 * 1024; - - let err = writer.deserialize::>>(&bytes).unwrap_err(); - assert_budget_error(err, auto_limit); - - let explicit = fory_with_budget(auto_limit as i64); - let err = explicit - .deserialize::>>(&bytes) - .unwrap_err(); - assert_budget_error(err, auto_limit); + let decoded = writer.deserialize::>>(&bytes).unwrap(); + assert_eq!(decoded, value); } #[test] -fn reader_known_auto_budget() { +fn reader_root_uses_fixed_default_budget() { let value = compact_empty_lists(12000); - let writer = fory_with_budget(-1); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); let bytes = writer.serialize(&value).unwrap(); - let auto_limit = bytes.len() * 8 + 64 * 1024; let mut reader = Reader::new(&bytes); - let err = writer + let decoded = writer .deserialize_from::>>(&mut reader) - .unwrap_err(); - assert_budget_error(err, auto_limit); + .unwrap(); + assert_eq!(decoded, value); } #[test] fn explicit_override() { let value = compact_empty_lists(12000); - let writer = fory_with_budget(-1); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); let bytes = writer.serialize(&value).unwrap(); - assert!(writer.deserialize::>>(&bytes).is_err()); let vec_bytes = mem::size_of::>(); let estimate = mem::size_of::>>() + value.len() * vec_bytes; + let limited = fory_with_budget((estimate - 1) as i64); + assert!(limited.deserialize::>>(&bytes).is_err()); let explicit = fory_with_budget(estimate as i64); let decoded: Vec> = explicit.deserialize(&bytes).unwrap(); assert_eq!(decoded, value); @@ -144,7 +156,7 @@ fn explicit_override() { #[test] fn empty_collection_owner_self() { let value: Vec = Vec::new(); - let writer = fory_with_budget(-1); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); let bytes = writer.serialize(&value).unwrap(); let limited = fory_with_budget((mem::size_of::>() - 1) as i64); @@ -158,7 +170,7 @@ fn empty_collection_owner_self() { #[test] fn empty_struct_owner_self() { let value = BudgetEmpty; - let writer = fory_with_budget(-1); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); let bytes = writer.serialize(&value).unwrap(); assert_eq!( @@ -188,7 +200,7 @@ fn sibling_cumulative_budget() { first: vec!["a".to_string()], second: vec!["b".to_string()], }; - let writer = fory_with_budget(-1); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); let bytes = writer.serialize(&value).unwrap(); let root = mem::size_of::() as i64; let one_vec = mem::size_of::() as i64; @@ -202,7 +214,7 @@ fn sibling_cumulative_budget() { #[test] fn map_budget() { let value: HashMap = HashMap::from([("a".to_string(), 1)]); - let writer = fory_with_budget(-1); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); let bytes = writer.serialize(&value).unwrap(); let required = (mem::size_of::>() + mem::size_of::() @@ -226,7 +238,7 @@ fn inline_value_vec_budget() { right: i + 1, }) .collect::>(); - let writer = fory_with_budget(-1); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); let bytes = writer.serialize(&value).unwrap(); let under_inline = mem::size_of::>() + value.len() * mem::size_of::(); @@ -244,7 +256,7 @@ fn box_vector_owner_self() { }) .collect::>(), ); - let writer = fory_with_budget(-1); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); let bytes = writer.serialize(&value).unwrap(); let required = mem::size_of::>() + value.len() * mem::size_of::(); @@ -264,7 +276,7 @@ fn compatible_list_array_budget() { let value = ListWireInts { values: (0..64).map(Some).collect(), }; - let writer = compatible_fory::(-1); + let writer = compatible_fory::(DEFAULT_GRAPH_MEMORY_BYTES); let bytes = writer.serialize(&value).unwrap(); let required = mem::size_of::() + 64 * mem::size_of::(); @@ -285,26 +297,30 @@ fn compatible_list_array_budget() { fn dense_paths_skipped() { let fory = fory_with_budget(1); - let string_bytes = fory_with_budget(-1) + let string_bytes = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES) .serialize(&"hello".to_string()) .unwrap(); let decoded: String = fory.deserialize(&string_bytes).unwrap(); assert_eq!(decoded, "hello"); let binary = vec![1_u8, 2, 3, 4]; - let binary_bytes = fory_with_budget(-1).serialize(&binary).unwrap(); + let binary_bytes = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES) + .serialize(&binary) + .unwrap(); let decoded: Vec = fory.deserialize(&binary_bytes).unwrap(); assert_eq!(decoded, binary); let ints = vec![1_i32, 2, 3, 4]; - let int_bytes = fory_with_budget(-1).serialize(&ints).unwrap(); + let int_bytes = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES) + .serialize(&ints) + .unwrap(); let decoded: Vec = fory.deserialize(&int_bytes).unwrap(); assert_eq!(decoded, ints); } #[test] fn byte_check_preserved() { - let writer = fory_with_budget(-1); + let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); let mut bytes = writer.serialize(&Vec::::new()).unwrap(); let last = bytes.len() - 1; bytes[last] = 64; diff --git a/swift/Sources/Fory/AnySerializer.swift b/swift/Sources/Fory/AnySerializer.swift index eca738293e..b1b0fd0782 100644 --- a/swift/Sources/Fory/AnySerializer.swift +++ b/swift/Sources/Fory/AnySerializer.swift @@ -20,15 +20,20 @@ import Foundation private let anyReferenceBytes = 4 private let anyArrayOwnerBytes = max(1, MemoryLayout<[Any]>.stride) +@inline(never) +private func throwAnyGraphMemoryOverflow() throws -> Never { + throw ForyError.invalidData("graph memory estimate overflows") +} + @inline(__always) private func reserveAnyReferenceArrayMemory(_ context: ReadContext, count: Int) throws { let (slotBytes, overflow) = count.multipliedReportingOverflow(by: anyReferenceBytes) if overflow { - try context.reserveGraphMemory(-1) + try throwAnyGraphMemoryOverflow() } let (bytes, addOverflow) = anyArrayOwnerBytes.addingReportingOverflow(slotBytes) if addOverflow { - try context.reserveGraphMemory(-1) + try throwAnyGraphMemoryOverflow() } try context.reserveGraphMemory(bytes) } @@ -39,12 +44,12 @@ private func reserveAnyReferenceMapMemory(_ context: ReadContext, _ type: M { let (slotBytes, overflow) = count.multipliedReportingOverflow(by: 2 * anyReferenceBytes) if overflow { - try context.reserveGraphMemory(-1) + try throwAnyGraphMemoryOverflow() } let ownerBytes = max(1, MemoryLayout.stride) let (bytes, addOverflow) = ownerBytes.addingReportingOverflow(slotBytes) if addOverflow { - try context.reserveGraphMemory(-1) + try throwAnyGraphMemoryOverflow() } try context.reserveGraphMemory(bytes) } diff --git a/swift/Sources/Fory/CollectionSerializers.swift b/swift/Sources/Fory/CollectionSerializers.swift index 3c2ed59f74..51efae5202 100644 --- a/swift/Sources/Fory/CollectionSerializers.swift +++ b/swift/Sources/Fory/CollectionSerializers.swift @@ -41,16 +41,29 @@ private func storedElementBytes(_ type: Element.Type) -> In type.isRefType ? storedReferenceBytes : max(1, MemoryLayout.stride) } +@inline(__always) +private func reserveGraphStorage( + _ context: ReadContext, + count: Int, + elementBytes: Int +) throws { + if count < 0 || elementBytes < 0 { + throw ForyError.invalidData("graph memory estimate overflows") + } + let (bytes, overflow) = count.multipliedReportingOverflow(by: elementBytes) + if overflow { + throw ForyError.invalidData("graph memory estimate overflows") + } + try context.reserveGraphMemory(bytes) +} + @inline(__always) private func reserveGraphArrayMemory( _ context: ReadContext, _ type: Element.Type, count: Int ) throws { - try context.reserveCountedGraphMemory( - count: count, - elementBytes: storedElementBytes(type) - ) + try reserveGraphStorage(context, count: count, elementBytes: storedElementBytes(type)) } @inline(__always) @@ -64,9 +77,9 @@ private func reserveGraphMapMemory( let valueBytes = storedElementBytes(value) let (elementBytes, overflow) = keyBytes.addingReportingOverflow(valueBytes) if overflow { - try context.reserveGraphMemory(-1) + throw ForyError.invalidData("graph memory estimate overflows") } - try context.reserveCountedGraphMemory(count: count, elementBytes: elementBytes) + try reserveGraphStorage(context, count: count, elementBytes: elementBytes) } private func primitiveArrayTypeID(for _: Element.Type) -> TypeId? { diff --git a/swift/Sources/Fory/FieldCodecs.swift b/swift/Sources/Fory/FieldCodecs.swift index bb914f52d1..9bab1391cb 100644 --- a/swift/Sources/Fory/FieldCodecs.swift +++ b/swift/Sources/Fory/FieldCodecs.swift @@ -29,13 +29,29 @@ private func serializerElementBytes(_ type: Element.Type) - type.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) } +@inline(__always) +private func reserveFieldStorage( + _ context: ReadContext, + count: Int, + elementBytes: Int +) throws { + if count < 0 || elementBytes < 0 { + throw ForyError.invalidData("graph memory estimate overflows") + } + let (bytes, overflow) = count.multipliedReportingOverflow(by: elementBytes) + if overflow { + throw ForyError.invalidData("graph memory estimate overflows") + } + try context.reserveGraphMemory(bytes) +} + @inline(__always) private func reserveFieldArrayStorage( _ context: ReadContext, _ codec: ElementCodec.Type, count: Int ) throws { - try context.reserveCountedGraphMemory(count: count, elementBytes: fieldElementBytes(codec)) + try reserveFieldStorage(context, count: count, elementBytes: fieldElementBytes(codec)) } @inline(__always) @@ -44,7 +60,7 @@ private func reserveSerializerArrayMemory( _ type: Element.Type, count: Int ) throws { - try context.reserveCountedGraphMemory(count: count, elementBytes: serializerElementBytes(type)) + try reserveFieldStorage(context, count: count, elementBytes: serializerElementBytes(type)) } @inline(__always) @@ -58,9 +74,9 @@ private func reserveFieldMapStorage 0, - "maxGraphMemoryBytes must be positive or -1 for auto") precondition(maxTypeFields > 0, "maxTypeFields must be positive") precondition(maxTypeMetaBytes > 0, "maxTypeMetaBytes must be positive") precondition(maxSchemaVersionsPerType > 0, "maxSchemaVersionsPerType must be positive") @@ -78,7 +75,7 @@ public final class Fory { compatible: Bool? = nil, checkClassVersion: Bool? = nil, maxDepth: Int = 5, - maxGraphMemoryBytes: Int64 = -1, + maxGraphMemoryBytes: Int64 = 128 * 1024 * 1024, maxTypeFields: Int = 512, maxTypeMetaBytes: Int = 4096, maxSchemaVersionsPerType: Int = 10, @@ -497,7 +494,7 @@ public final class Fory { _ body: (ReadContext) throws -> R ) throws -> R { readContext.buffer.replace(with: data) - try readContext.initGraphMemoryBudgetKnown(rootBytes: data.count) + try readContext.initGraphMemoryBudget() defer { readContext.reset() } @@ -559,7 +556,7 @@ public final class Fory { ) throws -> R { try typeResolver.finishRegistration() readContext.buffer.swapState(with: buffer) - try readContext.initGraphMemoryBudgetKnown(rootBytes: readContext.buffer.remaining) + try readContext.initGraphMemoryBudget() defer { readContext.buffer.swapState(with: buffer) readContext.reset() diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index 413bcf6b00..09f99fa749 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -20,10 +20,6 @@ import Foundation private let typeMetaSizeMask = 0xFF public final class ReadContext { - static let knownGraphBudgetSlackBytes = 64 * 1024 - static let unknownGraphBudgetBytes = 128 * 1024 * 1024 - private static let maxKnownGraphRootBytes = (Int.max - knownGraphBudgetSlackBytes) / 8 - public let buffer: ByteBuffer let typeResolver: TypeResolver public let trackRef: Bool @@ -59,15 +55,8 @@ public final class ReadContext { } @inline(__always) - func initGraphMemoryBudgetKnown(rootBytes: Int) throws { - var limit = maxGraphMemoryBytes - if limit < 0 { - if rootBytes > Self.maxKnownGraphRootBytes { - try throwGraphMemoryOverflow() - } - limit = rootBytes * 8 + Self.knownGraphBudgetSlackBytes - } - remainingGraphMemoryBytes = limit + func initGraphMemoryBudget() throws { + remainingGraphMemoryBytes = maxGraphMemoryBytes > 0 ? maxGraphMemoryBytes : Int.max } @inline(__always) @@ -75,26 +64,15 @@ public final class ReadContext { if bytes < 0 { try throwGraphMemoryOverflow() } + if maxGraphMemoryBytes <= 0 { + return + } if bytes > remainingGraphMemoryBytes { try throwGraphMemoryExceeded(bytes: bytes) } remainingGraphMemoryBytes -= bytes } - @inline(__always) - func reserveCountedGraphMemory( - count: Int, - elementBytes: Int - ) throws { - if count < 0 || elementBytes < 0 { - try throwGraphMemoryOverflow() - } - if elementBytes != 0 && count > Int.max / elementBytes { - try throwGraphMemoryOverflow() - } - try reserveGraphMemory(count * elementBytes) - } - @inline(never) private func throwGraphMemoryOverflow() throws -> Never { throw ForyError.invalidData("graph memory estimate overflows") diff --git a/swift/Tests/ForyTests/ForySwiftTests.swift b/swift/Tests/ForyTests/ForySwiftTests.swift index 303f0dfc6c..15ff5b1ab9 100644 --- a/swift/Tests/ForyTests/ForySwiftTests.swift +++ b/swift/Tests/ForyTests/ForySwiftTests.swift @@ -382,7 +382,7 @@ func namedInitializerBuildsConfig() { #expect(defaultConfig.config.compatible == true) #expect(defaultConfig.config.checkClassVersion == false) #expect(defaultConfig.config.maxDepth == 5) - #expect(defaultConfig.config.maxGraphMemoryBytes == -1) + #expect(defaultConfig.config.maxGraphMemoryBytes == 128 * 1024 * 1024) #expect(defaultConfig.config.maxTypeFields == 512) #expect(defaultConfig.config.maxTypeMetaBytes == 4096) #expect(defaultConfig.config.maxSchemaVersionsPerType == 10) diff --git a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift index 35181745ff..89fa19a701 100644 --- a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift +++ b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift @@ -45,7 +45,9 @@ private struct BudgetDenseHolder: Equatable { var dense: [Int32] = [] } -private func makeBudgetFory(maxGraphMemoryBytes: Int64 = -1) -> Fory { +private let defaultGraphMemoryBytes: Int64 = 128 * 1024 * 1024 + +private func makeBudgetFory(maxGraphMemoryBytes: Int64 = defaultGraphMemoryBytes) -> Fory { let fory = Fory( config: .init( trackRef: false, @@ -108,8 +110,7 @@ private func expectInvalidData(_ body: () throws -> Void) { } @Test -func knownLengthAutoBudgetUsesInputBytes() throws { - let expected = 17 * 8 + ReadContext.knownGraphBudgetSlackBytes +func fixedDefaultBudgetAndDisable() throws { let config = Config(trackRef: false, compatible: false) let context = ReadContext( buffer: ByteBuffer(), @@ -117,15 +118,24 @@ func knownLengthAutoBudgetUsesInputBytes() throws { config: config ) - try context.initGraphMemoryBudgetKnown(rootBytes: 17) - try context.reserveGraphMemory(expected) + try context.initGraphMemoryBudget() + try context.reserveGraphMemory(Int(defaultGraphMemoryBytes)) expectInvalidData { try context.reserveGraphMemory(testReferenceBytes) } + + let disabledConfig = Config(trackRef: false, compatible: false, maxGraphMemoryBytes: 0) + let disabled = ReadContext( + buffer: ByteBuffer(), + typeResolver: TypeResolver(config: disabledConfig), + config: disabledConfig + ) + try disabled.initGraphMemoryBudget() + try disabled.reserveGraphMemory(Int(defaultGraphMemoryBytes) + 1) } @Test -func byteBufferRootUsesKnownLengthAutoBudget() throws { +func byteBufferRootUsesFixedDefaultBudget() throws { let count = 6 let value = Array(repeating: [String](), count: count) let bytes = try makeBudgetFory().serialize(value) @@ -136,7 +146,7 @@ func byteBufferRootUsesKnownLengthAutoBudget() throws { } @Test -func explicitConfigOverridesAutoBudget() throws { +func explicitConfigOverridesDefault() throws { let values = (0..<16).map { "value-\($0)" } let bytes = try makeBudgetFory().serialize(values) let required = rootArrayBudget(String.self, count: values.count) From dff4a9c80a75b45de67864dad57a330af99d12cf Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Wed, 1 Jul 2026 02:38:08 +0800 Subject: [PATCH 17/54] perf(cpp): trim graph budget root init --- cpp/fory/serialization/context.cc | 7 ++++-- cpp/fory/serialization/context.h | 42 +++++++++++-------------------- cpp/fory/serialization/fory.h | 4 +-- 3 files changed, 21 insertions(+), 32 deletions(-) diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index 0c3045efc6..2342a2d219 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -434,7 +434,11 @@ uint32_t WriteContext::get_type_id_for_cache(const std::type_index &type_idx) { ReadContext::ReadContext(const Config &config, std::unique_ptr type_resolver) : buffer_(nullptr), config_(&config), - type_resolver_(std::move(type_resolver)), current_dyn_depth_(0) {} + type_resolver_(std::move(type_resolver)), current_dyn_depth_(0), + graph_memory_limit_bytes_(config.max_graph_memory_bytes > 0 + ? static_cast( + config.max_graph_memory_bytes) + : size_t{0}) {} ReadContext::~ReadContext() = default; @@ -760,7 +764,6 @@ void ReadContext::reset() { } reading_type_infos_.clear(); current_dyn_depth_ = 0; - graph_budget_enabled_ = false; remaining_graph_memory_bytes_ = std::numeric_limits::max(); if (meta_string_table_active_) { meta_string_table_.reset(); diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index c405795be9..601ffd7285 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -505,34 +505,30 @@ class ReadContext { } } + template FORY_ALWAYS_INLINE bool init_graph_budget() { - return init_graph_budget(0); - } - - FORY_ALWAYS_INLINE bool init_graph_budget(size_t reserve_bytes) { - const int64_t configured = config_->max_graph_memory_bytes; - if (FORY_PREDICT_FALSE(configured > 0)) { - if constexpr (sizeof(size_t) < sizeof(uint64_t)) { - if (FORY_PREDICT_FALSE( - static_cast(configured) > - static_cast(std::numeric_limits::max()))) { - return set_graph_memory_error( - "max_graph_memory_bytes does not fit size_t"); + const size_t limit = graph_memory_limit_bytes_; + if (FORY_PREDICT_TRUE(limit != 0)) { + if constexpr (ReserveBytes != 0) { + if (FORY_PREDICT_FALSE(ReserveBytes > limit)) { + return set_graph_memory_exceeded(ReserveBytes, limit); } + remaining_graph_memory_bytes_ = limit - ReserveBytes; + } else { + remaining_graph_memory_bytes_ = limit; } - return init_graph_budget_limit(static_cast(configured), - reserve_bytes); + return true; } - graph_budget_enabled_ = false; remaining_graph_memory_bytes_ = std::numeric_limits::max(); return true; } FORY_ALWAYS_INLINE bool reserve_graph_memory(size_t bytes) { - if (FORY_PREDICT_FALSE(!graph_budget_enabled_)) { + const size_t remaining = remaining_graph_memory_bytes_; + if (FORY_PREDICT_FALSE(remaining == + std::numeric_limits::max())) { return true; } - const size_t remaining = remaining_graph_memory_bytes_; if (FORY_PREDICT_FALSE(bytes > remaining)) { return set_graph_memory_exceeded(bytes, remaining); } @@ -540,16 +536,6 @@ class ReadContext { return true; } - FORY_ALWAYS_INLINE bool init_graph_budget_limit(size_t limit, - size_t reserve_bytes) { - graph_budget_enabled_ = true; - if (FORY_PREDICT_FALSE(reserve_bytes > limit)) { - return set_graph_memory_exceeded(reserve_bytes, limit); - } - remaining_graph_memory_bytes_ = limit - reserve_bytes; - return true; - } - // =========================================================================== // Read methods with Error& parameter // All methods accept Error& as parameter for reduced overhead. @@ -719,7 +705,7 @@ class ReadContext { std::unique_ptr type_resolver_; RefReader ref_reader_; uint32_t current_dyn_depth_; - bool graph_budget_enabled_ = false; + size_t graph_memory_limit_bytes_ = 0; size_t remaining_graph_memory_bytes_ = std::numeric_limits::max(); // Meta sharing state (for compatible mode) diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index fc0fd862da..a0fb717ad2 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -892,11 +892,11 @@ class Fory : public BaseFory { constexpr bool has_child_budget = has_graph_budget_children_v; if constexpr (root_owner_bytes != 0) { if (FORY_PREDICT_FALSE( - !read_ctx_->init_graph_budget(root_owner_bytes))) { + !read_ctx_->template init_graph_budget())) { return Unexpected(read_ctx_->take_error()); } } else if constexpr (has_child_budget) { - if (FORY_PREDICT_FALSE(!read_ctx_->init_graph_budget())) { + if (FORY_PREDICT_FALSE(!read_ctx_->template init_graph_budget<>())) { return Unexpected(read_ctx_->take_error()); } } From c3bc40fe3d51e0f6eb96f9a9bc076351812ff88b Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Wed, 1 Jul 2026 23:51:09 +0800 Subject: [PATCH 18/54] perf(go): trim graph budget root read overhead --- go/fory/fory.go | 94 ++++++++++++++++++++++++++++++++++++++--------- go/fory/reader.go | 34 +++++++++-------- 2 files changed, 96 insertions(+), 32 deletions(-) diff --git a/go/fory/fory.go b/go/fory/fory.go index b08d9ecee3..561a66ebc7 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -48,6 +48,11 @@ func splitRegisteredName(name string) (string, string, error) { return namespace, typeName, nil } +type ifaceWords struct { + typ unsafe.Pointer + data unsafe.Pointer +} + // ============================================================================ // Constants // ============================================================================ @@ -198,6 +203,9 @@ type Fory struct { rootGraphBytes int64 rootGraphHasChildren bool rootGraphSkipType reflect.Type + rootGraphSkipTypeID unsafe.Pointer + rootReadTypeID unsafe.Pointer + rootReadSerializer Serializer } // New creates a new Fory instance with the given options @@ -572,9 +580,17 @@ func (f *Fory) Serialize(value any) ([]byte, error) { func (f *Fory) Deserialize(data []byte, v any) error { defer f.resetReadState() f.readCtx.SetData(data) - target := reflect.ValueOf(v).Elem() - if err := f.initRootGraphBudget(target); err != nil { - return err + typeID := (*ifaceWords)(unsafe.Pointer(&v)).typ + var target reflect.Value + if typeID != f.rootGraphSkipTypeID { + target = reflect.ValueOf(v).Elem() + targetType := target.Type() + if err := f.initRootGraphBudgetType(targetType); err != nil { + return err + } + if targetType == f.rootGraphSkipType { + f.rootGraphSkipTypeID = typeID + } } readHeader(f.readCtx) @@ -583,7 +599,10 @@ func (f *Fory) Deserialize(data []byte, v any) error { } // Deserialize the value - TypeMeta is read inline using streaming protocol - f.readCtx.ReadValue(target, RefModeTracking, true) + if !target.IsValid() { + target = reflect.ValueOf(v).Elem() + } + f.readRootValue(target, typeID) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -668,7 +687,8 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = buf target := reflect.ValueOf(v).Elem() - if err := f.initRootGraphBudget(target); err != nil { + targetType := target.Type() + if err := f.initRootGraphBudgetType(targetType); err != nil { f.readCtx.buffer = origBuffer return err } @@ -784,7 +804,8 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers } target := rv.Elem() - if err := f.initRootGraphBudget(target); err != nil { + targetType := target.Type() + if err := f.initRootGraphBudgetType(targetType); err != nil { return err } @@ -1046,12 +1067,14 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { f.readCtx.SetData(data) var targetVal reflect.Value + var targetType reflect.Type switch any(target).(type) { case *bool, *int8, *int16, *int32, *int64, *int, *float32, *float64, *string, *[]byte, *[]int8, *[]int16, *[]int32, *[]int64, *[]int, *[]float32, *[]float64, *[]bool: default: targetVal = reflect.ValueOf(target).Elem() - if err := f.initRootGraphBudget(targetVal); err != nil { + targetType = targetVal.Type() + if err := f.initRootGraphBudgetType(targetType); err != nil { return err } } @@ -1195,8 +1218,8 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { // Slow path: use serializer-based deserialization if !targetVal.IsValid() { targetVal = reflect.ValueOf(target).Elem() + targetType = targetVal.Type() } - targetType := targetVal.Type() // Get serializer for the target type serializer, err := f.typeResolver.getSerializerByType(targetType, false) @@ -1212,16 +1235,25 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { func (f *Fory) initRootGraphBudget(target reflect.Value) error { if !target.IsValid() { + return f.initRootGraphBudgetType(nil) + } + return f.initRootGraphBudgetType(target.Type()) +} + +func (f *Fory) initRootGraphBudgetType(targetType reflect.Type) error { + if targetType == nil { f.readCtx.initGraphMemoryBudget() if f.readCtx.HasError() { return f.readCtx.TakeError() } return nil } - targetType := target.Type() if targetType == f.rootGraphSkipType { return nil } + if targetType == f.rootGraphType && f.rootGraphHasChildren { + return f.initRootGraphBudgetWithSelf(f.rootGraphBytes) + } return f.initRootGraphBudgetSlow(targetType) } @@ -1236,14 +1268,7 @@ func (f *Fory) initRootGraphBudgetSlow(targetType reflect.Type) error { return nil } if hasChildren { - f.readCtx.initGraphMemoryBudget() - if f.readCtx.HasError() { - return f.readCtx.TakeError() - } - if bytes != 0 && !f.readCtx.ReserveGraphMemory(bytes) { - return f.readCtx.TakeError() - } - return nil + return f.initRootGraphBudgetWithSelf(bytes) } if f.config.MaxGraphMemoryBytes <= 0 || bytes <= f.config.MaxGraphMemoryBytes { f.rootGraphSkipType = targetType @@ -1252,6 +1277,23 @@ func (f *Fory) initRootGraphBudgetSlow(targetType reflect.Type) error { return f.checkRootGraphSelf(bytes) } +func (f *Fory) initRootGraphBudgetWithSelf(bytes int64) error { + limit := f.config.MaxGraphMemoryBytes + if limit <= 0 { + f.readCtx.graphMemoryLimitBytes = 0 + f.readCtx.remainingGraphMemoryBytes = MaxInt64 + return nil + } + if bytes > limit { + return DeserializationErrorf( + "estimated graph memory request %d bytes exceeds maxGraphMemoryBytes remaining budget %d bytes out of effective limit %d bytes", + bytes, limit, limit) + } + f.readCtx.graphMemoryLimitBytes = limit + f.readCtx.remainingGraphMemoryBytes = limit - bytes + return nil +} + func (f *Fory) rootGraphInfo(targetType reflect.Type) (int64, bool, bool) { if targetType == nil || targetType.Kind() != reflect.Struct { return 0, false, false @@ -1285,3 +1327,21 @@ func (f *Fory) checkRootGraphSelf(bytes int64) error { } return nil } + +func (f *Fory) readRootValue(target reflect.Value, typeID unsafe.Pointer) { + serializer := f.rootReadSerializer + if typeID == f.rootReadTypeID && serializer != nil { + serializer.Read(f.readCtx, RefModeTracking, true, false, target) + return + } + targetType := target.Type() + if targetType.Kind() == reflect.Struct { + if typeInfo := f.readCtx.getTypeInfoByType(targetType); typeInfo != nil && typeInfo.Serializer != nil { + f.rootReadTypeID = typeID + f.rootReadSerializer = typeInfo.Serializer + typeInfo.Serializer.Read(f.readCtx, RefModeTracking, true, false, target) + return + } + } + f.readCtx.ReadValue(target, RefModeTracking, true) +} diff --git a/go/fory/reader.go b/go/fory/reader.go index 0fe3c3481e..de46143fc2 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -169,20 +169,24 @@ func (c *ReadContext) initGraphMemoryBudget() { // ReserveGraphMemory reserves raw estimated graph-owner bytes. func (c *ReadContext) ReserveGraphMemory(bytes int64) bool { - if bytes < 0 { - c.setGraphMemoryError("estimated graph memory must be non-negative, got %d bytes", bytes) - return false - } - if c.graphMemoryLimitBytes <= 0 { - return true - } - remaining := c.remainingGraphMemoryBytes - if bytes > remaining { - c.setGraphMemoryExceeded(bytes, remaining) - return false + if bytes >= 0 { + if c.graphMemoryLimitBytes <= 0 { + return true + } + remaining := c.remainingGraphMemoryBytes + if bytes <= remaining { + c.remainingGraphMemoryBytes = remaining - bytes + return true + } + return c.rejectGraphMemoryExceeded(bytes, remaining) } - c.remainingGraphMemoryBytes = remaining - bytes - return true + return c.rejectGraphMemoryBytes(bytes) +} + +//go:noinline +func (c *ReadContext) rejectGraphMemoryBytes(bytes int64) bool { + c.setGraphMemoryError("estimated graph memory must be non-negative, got %d bytes", bytes) + return false } //go:noinline @@ -191,10 +195,11 @@ func (c *ReadContext) setGraphMemoryError(format string, args ...any) { } //go:noinline -func (c *ReadContext) setGraphMemoryExceeded(bytes int64, remaining int64) { +func (c *ReadContext) rejectGraphMemoryExceeded(bytes int64, remaining int64) bool { c.SetError(DeserializationErrorf( "estimated graph memory request %d bytes exceeds maxGraphMemoryBytes remaining budget %d bytes out of effective limit %d bytes", bytes, remaining, c.graphMemoryLimitBytes)) + return false } // SetData sets new input data (for buffer reuse) @@ -886,7 +891,6 @@ func (c *ReadContext) ReadValue(value reflect.Value, refMode RefMode, readType b return } valueType := value.Type() - // Handle array targets (arrays are serialized as slices) if valueType.Kind() == reflect.Array { c.ReadArrayValue(value, refMode, readType) From 6b0c403e019a5dd6e4f3ae6247290914a91c5613 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Wed, 1 Jul 2026 23:58:39 +0800 Subject: [PATCH 19/54] docs: fix graph memory budget guidance --- .agents/languages/csharp.md | 6 +++--- .agents/languages/swift.md | 10 ++++++---- docs/guide/cpp/configuration.md | 5 +++-- docs/guide/go/configuration.md | 2 +- docs/guide/java/configuration.md | 2 +- docs/guide/python/configuration.md | 8 +++++--- docs/guide/rust/configuration.md | 5 +++-- 7 files changed, 22 insertions(+), 16 deletions(-) diff --git a/.agents/languages/csharp.md b/.agents/languages/csharp.md index 7d15edfaf5..882f912930 100644 --- a/.agents/languages/csharp.md +++ b/.agents/languages/csharp.md @@ -13,9 +13,9 @@ Load this file when changing `csharp/` or C# xlang behavior. - C# generated schema modules are source-file owners. Service companions must use that module's `ThreadSafeFory` and must not introduce namespace-owned aliases or duplicate serializer registration paths. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. - Root deserialization graph memory budget state belongs to `ReadContext`. C# public roots are - memory-backed today, so auto uses known input length. `ReadContext` may expose only raw byte - reservation and generic counted-byte arithmetic; concrete serializers and generated serializers - must compute list, array, map, struct, and object byte formulas before calling it. + memory-backed today, but the graph budget uses the same fixed default for every root shape. + `ReadContext` may expose only raw byte reservation; concrete serializers and generated + serializers must compute list, array, map, struct, and object byte formulas before calling it. - For C# graph budget formulas, distinguish inline value storage from reference storage: use cheap value-type size for `List`/`T[]` value paths and the 4-byte reference fallback for reference paths. Class/reference serializers reserve their own shallow self cost plus field storage when diff --git a/.agents/languages/swift.md b/.agents/languages/swift.md index 30182cd4ba..8371b107f0 100644 --- a/.agents/languages/swift.md +++ b/.agents/languages/swift.md @@ -6,6 +6,8 @@ Load this file when changing `swift/` or Swift xlang behavior. - Run Swift commands from within `swift/`. - Changes under `swift/` must pass lint and tests. +- SwiftLint is mandatory but not sufficient for readability; manually clean touched Swift source + formatting before committing because this repo does not run a Swift formatter in CI. - Swift code must compile without compiler warnings. Treat warnings as blockers, including warnings in generated Swift code. - Swift lint uses `swift/.swiftlint.yml`. - Swift formatting uses `swift/.swift-format`; do not rely on SwiftLint for indentation or source @@ -16,10 +18,10 @@ Load this file when changing `swift/` or Swift xlang behavior. - When temporal or public-type refactors touch generated Swift code, sweep message fields, union payloads, macros, xlang harnesses, and integration fixtures together. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. - Root deserialization graph memory budget state belongs to `ReadContext`. Swift public roots are - `Data` and `ByteBuffer`, so auto uses known root bytes; do not add stream bytes-read accounting or - serializer-local budget state. `ReadContext` may expose only raw byte reservation and generic - counted-byte arithmetic; array, set, map, struct, and object formulas belong in serializer and - field-codec owners. + `Data` and `ByteBuffer`, and both use the same fixed default graph budget; do not add stream + bytes-read accounting or serializer-local budget state. `ReadContext` may expose only raw byte + reservation; array, set, map, struct, and object formulas belong in serializer and field-codec + owners. - For Swift graph budget formulas, distinguish inline/value storage from reference storage: use `MemoryLayout.stride` for value arrays/lists/sets/maps and the 4-byte reference fallback for `Serializer.isRefType` / `FieldCodec.isRefType` paths. Class/reference paths reserve their own diff --git a/docs/guide/cpp/configuration.md b/docs/guide/cpp/configuration.md index 93900070e8..2c4ec2524b 100644 --- a/docs/guide/cpp/configuration.md +++ b/docs/guide/cpp/configuration.md @@ -246,8 +246,9 @@ Security-related configuration: - Register all structs and polymorphic implementations before deserializing untrusted payloads. - Use `check_struct_version(true)` with `compatible(false)` for intentional same-schema payloads. -- Leave `max_graph_memory_bytes(-1)` enabled for automatic root-size-based graph limits, or set a - positive value for a stricter trusted-workload envelope. +- Keep `max_graph_memory_bytes(...)` at the fixed `128 MiB` default for most inputs, or set a + positive value for a trusted workload that needs a different envelope. Avoid explicit + non-positive values for untrusted data because they disable graph-memory enforcement. - Keep `max_dyn_depth(...)` as low as your model permits to reject unexpectedly deep polymorphic graphs. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a diff --git a/docs/guide/go/configuration.md b/docs/guide/go/configuration.md index af21f22c29..201c508df6 100644 --- a/docs/guide/go/configuration.md +++ b/docs/guide/go/configuration.md @@ -52,7 +52,7 @@ f := fory.New( fory.WithXlang(true), fory.WithTrackRef(true), fory.WithMaxDepth(10), - fory.WithMaxGraphMemoryBytes(-1), + fory.WithMaxGraphMemoryBytes(128 * 1024 * 1024), fory.WithMaxTypeFields(512), fory.WithMaxTypeMetaBytes(4096), fory.WithMaxSchemaVersionsPerType(10), diff --git a/docs/guide/java/configuration.md b/docs/guide/java/configuration.md index 4aef1c5ba9..73827559b6 100644 --- a/docs/guide/java/configuration.md +++ b/docs/guide/java/configuration.md @@ -91,7 +91,7 @@ Keep class registration enabled for production and any untrusted payload source: Fory fory = Fory.builder() .requireClassRegistration(true) .withMaxDepth(50) - .withMaxGraphMemoryBytes(-1) + .withMaxGraphMemoryBytes(128L * 1024 * 1024) .build(); ``` diff --git a/docs/guide/python/configuration.md b/docs/guide/python/configuration.md index 601daab5af..f9bfd1ac96 100644 --- a/docs/guide/python/configuration.md +++ b/docs/guide/python/configuration.md @@ -40,7 +40,7 @@ class Fory: max_type_meta_bytes: int = 4096, max_schema_versions_per_type: int = 10, max_average_schema_versions_per_type: int = 3, - max_graph_memory_bytes: int = -1, + max_graph_memory_bytes: int = 128 * 1024 * 1024, policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, @@ -199,7 +199,7 @@ fory = pyfory.Fory( max_type_meta_bytes=4096, max_schema_versions_per_type=10, max_average_schema_versions_per_type=3, - max_graph_memory_bytes=-1, + max_graph_memory_bytes=128 * 1024 * 1024, ) fory.register(UserModel, name="example.User") @@ -287,7 +287,9 @@ unchanged. - Register all expected application types before deserialization. - Use `DeserializationPolicy` when `strict=False` is necessary. - Keep `max_depth` low enough to reject unexpectedly deep payloads. -- Keep `max_graph_memory_bytes=-1` unless a trusted workload needs a higher explicit limit. +- Keep `max_graph_memory_bytes` at the fixed `128 MiB` default for most inputs, or set a positive + explicit limit for trusted workloads with different legitimate object-graph sizes. Avoid + explicit non-positive values for untrusted data because they disable graph-memory enforcement. - Do not treat xlang/native mode choice as a security control. ## Related Topics diff --git a/docs/guide/rust/configuration.md b/docs/guide/rust/configuration.md index 2a14d4cace..7426f968f6 100644 --- a/docs/guide/rust/configuration.md +++ b/docs/guide/rust/configuration.md @@ -192,8 +192,9 @@ Security-related configuration: - Register application structs and trait-object implementations before deserializing untrusted payloads. - Use `max_dyn_depth(...)` to reject unexpectedly deep dynamic object graphs. -- Keep `max_graph_memory_bytes(-1)` for the default input-shaped graph budget, or set a positive - byte limit for trusted workloads with larger legitimate object graphs. +- Keep `max_graph_memory_bytes(...)` at the fixed `128 MiB` default for most inputs, or set a + positive byte limit for trusted workloads with different legitimate object-graph sizes. Avoid + explicit non-positive values for untrusted data because they disable graph-memory enforcement. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer concrete typed fields over `dyn Any` or broad trait-object fields for untrusted input. From d7b0a2ddb522bda55da2fdc73ba904c75ddbda90 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Thu, 2 Jul 2026 00:24:36 +0800 Subject: [PATCH 20/54] fix: align graph budget collection owners --- .agents/languages/dart.md | 5 +- .agents/languages/javascript.md | 5 +- csharp/src/Fory/CollectionSerializers.cs | 47 +++++++++--- .../Fory.Tests/GraphMemoryBudgetTests.cs | 31 ++++++++ .../serializer/collection_serializers.dart | 33 ++++----- .../fory/test/graph_memory_budget_test.dart | 21 ++++-- .../packages/core/lib/gen/collection.ts | 74 +++++++------------ javascript/test/graphMemoryBudget.test.ts | 6 +- .../scala/internal/ForySerializerMacros.scala | 3 +- .../scala/ForySerializerDerivationTest.scala | 2 +- .../Sources/Fory/CollectionSerializers.swift | 21 +++++- .../ForyTests/GraphMemoryBudgetTests.swift | 15 ++++ 12 files changed, 165 insertions(+), 98 deletions(-) diff --git a/.agents/languages/dart.md b/.agents/languages/dart.md index d7a5dea99e..47017e29e8 100644 --- a/.agents/languages/dart.md +++ b/.agents/languages/dart.md @@ -21,8 +21,9 @@ Load this file when changing `dart/`. map, array, struct, and object formulas belong in serializer owners. Reserve Dart list/set/object-array reference slots plus nonzero owner self cost, map key/value slots plus nonzero owner - self cost, compatible list-to-array inline storage, compatible array-to-list - materialization, and generated object reads before allocation. Object/struct + self cost, compatible array-to-list materialization, and generated object reads before + allocation. Compatible list-to-typed-array reads skip the dense primitive-array leaf owner while + preserving byte checks. Object/struct owners reserve nonzero shallow self memory plus shallow field storage. Skip only dedicated string, binary, primitive scalar, `BoolList`, and typed-array dense owner paths with byte checks. Do not add stream bytes-read accounting, diff --git a/.agents/languages/javascript.md b/.agents/languages/javascript.md index 56a14c630a..d7e24ca968 100644 --- a/.agents/languages/javascript.md +++ b/.agents/languages/javascript.md @@ -22,8 +22,9 @@ Load this file when changing `javascript/`. list/set/map/array/struct/object readers must reserve before allocation while preserving existing byte checks. Lists/sets/object arrays reserve nonzero owner self cost plus 4-byte reference slots, maps reserve nonzero owner self cost plus key/value reference storage, object/struct readers - reserve nonzero shallow self memory plus shallow field storage, and compatible - list-to-typed-array reads reserve typed inline storage. Keep dedicated string, binary, primitive + reserve nonzero shallow self memory plus shallow field storage, compatible array-to-list reads + reserve target list materialization, and compatible list-to-typed-array reads skip the dense + primitive-array leaf owner while preserving byte checks. Keep dedicated string, binary, primitive scalar, and dense typed-array leaf owners out of this budget. - Regenerated compatible read serializers are remote-schema-specific. After classification marks a field as direct, compatible scalar, or skip, generated JavaScript should emit straight-line remote-field-order code. Do not add an outer matched-id switch unless the current regenerated shape cannot preserve those semantics. - Compatible scalar codegen must decide the exact remote/local scalar pair before emitting source. Generate the concrete `reader.readXxx()` call plus inline trivial conversions such as boolean-to-string or numeric widening, and keep helpers only for semantic validation such as range checks, exactness checks, decimal parsing/formatting, and string-to-bool. Do not call a generic hot-path converter that redispatches on `remoteTypeId`, `localTypeId`, field descriptors, or field names. diff --git a/csharp/src/Fory/CollectionSerializers.cs b/csharp/src/Fory/CollectionSerializers.cs index 72e6169c57..516012ca1c 100644 --- a/csharp/src/Fory/CollectionSerializers.cs +++ b/csharp/src/Fory/CollectionSerializers.cs @@ -208,13 +208,20 @@ public static void WriteCollectionData( } } - public static List ReadCollectionData(Serializer elementSerializer, ReadContext context) + public static List ReadCollectionData( + Serializer elementSerializer, + ReadContext context, + bool reserveOwner = true) { TypeInfo elementTypeInfo = context.TypeResolver.GetTypeInfo(); int length = checked((int)context.Reader.ReadVarUInt32()); if (length == 0) { - ReserveElementStorage(context, length); + if (reserveOwner) + { + ReserveElementStorage(context, length); + } + return []; } @@ -227,7 +234,11 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea bool hasNull = (header & CollectionBits.HasNull) != 0; bool declared = (header & CollectionBits.DeclaredElementType) != 0; bool sameType = (header & CollectionBits.SameType) != 0; - ReserveElementStorage(context, length); + if (reserveOwner) + { + ReserveElementStorage(context, length); + } + context.Reader.CheckBound(length); List values = new(length); if (!sameType) @@ -569,7 +580,10 @@ public override void WriteData(WriteContext context, in HashSet value, bool h public override HashSet ReadData(ReadContext context) { - List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + List values = CollectionCodec.ReadCollectionData( + context.TypeResolver.GetSerializer(), + context, + reserveOwner: false); CollectionCodec.ReserveElementStorage(context, values.Count); return [.. values]; } @@ -587,7 +601,10 @@ public override void WriteData(WriteContext context, in SortedSet value, bool public override SortedSet ReadData(ReadContext context) { - List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + List values = CollectionCodec.ReadCollectionData( + context.TypeResolver.GetSerializer(), + context, + reserveOwner: false); CollectionCodec.ReserveElementStorage(context, values.Count); return [.. values]; } @@ -605,7 +622,10 @@ public override void WriteData(WriteContext context, in ImmutableHashSet valu public override ImmutableHashSet ReadData(ReadContext context) { - List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + List values = CollectionCodec.ReadCollectionData( + context.TypeResolver.GetSerializer(), + context, + reserveOwner: false); CollectionCodec.ReserveElementStorage(context, values.Count); return ImmutableHashSet.CreateRange(values); } @@ -623,7 +643,10 @@ public override void WriteData(WriteContext context, in LinkedList value, boo public override LinkedList ReadData(ReadContext context) { - List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + List values = CollectionCodec.ReadCollectionData( + context.TypeResolver.GetSerializer(), + context, + reserveOwner: false); CollectionCodec.ReserveElementStorage(context, values.Count); return new LinkedList(values); } @@ -641,7 +664,10 @@ public override void WriteData(WriteContext context, in Queue value, bool has public override Queue ReadData(ReadContext context) { - List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + List values = CollectionCodec.ReadCollectionData( + context.TypeResolver.GetSerializer(), + context, + reserveOwner: false); CollectionCodec.ReserveElementStorage(context, values.Count); Queue queue = new(values.Count); for (int i = 0; i < values.Count; i++) @@ -678,7 +704,10 @@ public override void WriteData(WriteContext context, in Stack value, bool has public override Stack ReadData(ReadContext context) { - List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); + List values = CollectionCodec.ReadCollectionData( + context.TypeResolver.GetSerializer(), + context, + reserveOwner: false); CollectionCodec.ReserveElementStorage(context, values.Count); Stack stack = new(values.Count); for (int i = 0; i < values.Count; i++) diff --git a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs index 2971dcf065..1d85e9079f 100644 --- a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs +++ b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs @@ -16,6 +16,7 @@ // under the License. using System.Buffers; +using System.Collections.Immutable; using System.Runtime.CompilerServices; using Apache.Fory; using ForyRuntime = Apache.Fory.Fory; @@ -267,6 +268,36 @@ public void GeneratedSchemaContainersAreCharged() Assert.Equal(map.Values, NewFory(mapRequired).Deserialize(mapBytes).Values); } + [Fact] + public void ConversionCollectionsAreChargedOnce() + { + long required = ListBudget(3); + + Check(new HashSet { 1, 2, 3 }, v => v.SetEquals([1, 2, 3])); + Check(new SortedSet { 1, 2, 3 }, v => v.SetEquals([1, 2, 3])); + Check(ImmutableHashSet.Create(1, 2, 3), v => v.SetEquals([1, 2, 3])); + Check(new LinkedList([1, 2, 3]), v => v.SequenceEqual([1, 2, 3])); + + Queue queue = new(); + queue.Enqueue(1); + queue.Enqueue(2); + queue.Enqueue(3); + Check(queue, v => v.SequenceEqual([1, 2, 3])); + + Stack stack = new(); + stack.Push(1); + stack.Push(2); + stack.Push(3); + Check(stack, v => v.SequenceEqual([3, 2, 1])); + + void Check(T value, Func assertValue) + { + byte[] bytes = Serialize(value); + Assert.Throws(() => NewFory(required - 1).Deserialize(bytes)); + Assert.True(assertValue(NewFory(required).Deserialize(bytes))); + } + } + [Fact] public void DenseStringBinaryAndPrimitiveArraysAreSkipped() { diff --git a/dart/packages/fory/lib/src/serializer/collection_serializers.dart b/dart/packages/fory/lib/src/serializer/collection_serializers.dart index 344d0f85f7..71b2d7dd27 100644 --- a/dart/packages/fory/lib/src/serializer/collection_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/collection_serializers.dart @@ -335,8 +335,13 @@ final class ListSerializer extends Serializer { ReadContext context, FieldType? elementFieldType, { bool hasPreservedRef = false, + bool reserveOwner = true, }) { - final state = _prepareListRead(context, elementFieldType); + final state = _prepareListRead( + context, + elementFieldType, + reserveOwner: reserveOwner, + ); context.buffer.checkReadableBytes(state.size); final result = List.filled(state.size, null, growable: false); if (hasPreservedRef) { @@ -384,10 +389,15 @@ final class SetSerializer extends Serializer { final values = ListSerializer.readPayload( context, elementFieldType, - hasPreservedRef: hasPreservedRef, + hasPreservedRef: false, + reserveOwner: false, ); context.reserveGraphMemory(_ownerBytes + values.length * _referenceBytes); - return Set.of(values); + final result = Set.of(values); + if (hasPreservedRef) { + context.reference(result); + } + return result; } } @@ -511,9 +521,6 @@ Object _readCompatibleListAsArrayField( String fieldName, ) { final size = context.buffer.readVarUint32(); - context.reserveGraphMemory( - _ownerBytes + size * _arrayElementBytes(arrayTypeId), - ); if (size == 0) { return _newArrayValue(arrayTypeId, 0); } @@ -576,20 +583,6 @@ int _compatibleArrayElementTypeId(int typeId) { }; } -int _arrayElementBytes(int arrayTypeId) { - return switch (arrayTypeId) { - TypeIds.boolArray || TypeIds.int8Array || TypeIds.uint8Array => 1, - TypeIds.int16Array || - TypeIds.uint16Array || - TypeIds.float16Array || - TypeIds.bfloat16Array => 2, - TypeIds.int32Array || TypeIds.uint32Array || TypeIds.float32Array => 4, - TypeIds.int64Array || TypeIds.uint64Array || TypeIds.float64Array => 8, - _ => - throw StateError('Unsupported compatible array field type $arrayTypeId.'), - }; -} - Object _newArrayValue(int arrayTypeId, int length) { return switch (arrayTypeId) { TypeIds.boolArray => BoolList(length), diff --git a/dart/packages/fory/test/graph_memory_budget_test.dart b/dart/packages/fory/test/graph_memory_budget_test.dart index 60318b5c1b..499e043168 100644 --- a/dart/packages/fory/test/graph_memory_budget_test.dart +++ b/dart/packages/fory/test/graph_memory_budget_test.dart @@ -206,6 +206,14 @@ void main() { expect(_readWithBudget(value, _mapGraphBytes(1)), equals(value)); }); + test('reserves generic set owner once', () { + final value = {'x', 7}; + final required = _listGraphBytes(value.length); + + expect(() => _readWithBudget(value, required - 1), _throwsGraphBudget); + expect(_readWithBudget(value, required), equals(value)); + }); + test('reserves generated list set and map reads', () { final writer = Fory(); _registerGenerated(writer); @@ -238,22 +246,22 @@ void main() { expect(roundTrip.counts, equals({'one': 1})); }); - test('reserves compatible list array materialization', () { + test('skips compatible list to typed array leaf', () { final listWriter = Fory(); _registerCompatibleList(listWriter); final listBytes = listWriter.serialize( BudgetCompatibleListEnvelope()..values = [1, 2, 3], ); - final required = _objectGraphBytes(1) + _objectBytes + 3 * 4; - final arrayFail = Fory(maxGraphMemoryBytes: required - 1); + final arrayRequired = _objectGraphBytes(1); + final arrayFail = Fory(maxGraphMemoryBytes: arrayRequired - 1); _registerCompatibleArray(arrayFail); expect( () => arrayFail.deserialize(listBytes), _throwsGraphBudget, ); - final arrayPass = Fory(maxGraphMemoryBytes: required); + final arrayPass = Fory(maxGraphMemoryBytes: arrayRequired); _registerCompatibleArray(arrayPass); expect( arrayPass @@ -270,14 +278,15 @@ void main() { ..values = Int32List.fromList([1, 2, 3]), ); - final listFail = Fory(maxGraphMemoryBytes: required - 1); + final listRequired = _objectGraphBytes(1) + _listGraphBytes(3); + final listFail = Fory(maxGraphMemoryBytes: listRequired - 1); _registerCompatibleList(listFail); expect( () => listFail.deserialize(arrayBytes), _throwsGraphBudget, ); - final listPass = Fory(maxGraphMemoryBytes: required); + final listPass = Fory(maxGraphMemoryBytes: listRequired); _registerCompatibleList(listPass); expect( listPass.deserialize(arrayBytes).values, diff --git a/javascript/packages/core/lib/gen/collection.ts b/javascript/packages/core/lib/gen/collection.ts index bca49665b7..d98a228ac0 100644 --- a/javascript/packages/core/lib/gen/collection.ts +++ b/javascript/packages/core/lib/gen/collection.ts @@ -103,30 +103,6 @@ function compatibleArrayCollectionExpr( } } -function compatibleArrayElementBytes(elementTypeId: number): number { - switch (elementTypeId) { - case TypeId.BOOL: - case TypeId.INT8: - case TypeId.UINT8: - return 1; - case TypeId.INT16: - case TypeId.UINT16: - case TypeId.FLOAT16: - case TypeId.BFLOAT16: - return 2; - case TypeId.INT32: - case TypeId.UINT32: - case TypeId.FLOAT32: - return 4; - case TypeId.INT64: - case TypeId.UINT64: - case TypeId.FLOAT64: - return 8; - default: - return 4; - } -} - function compatibleArrayPutAccessor( elementTypeId: number, result: string, @@ -178,9 +154,9 @@ class CollectionAnySerializer { } if (isSame) { if ( - serializer !== null && - serializer !== undefined && - current !== serializer + serializer !== null + && serializer !== undefined + && current !== serializer ) { isSame = false; } else { @@ -213,8 +189,8 @@ class CollectionAnySerializer { if (size === 0) { return; } - const { serializer, isSame, includeNone, trackingRef } = - this.writeElementsHeader(value); + const { serializer, isSame, includeNone, trackingRef } + = this.writeElementsHeader(value); if (isSame) { serializer!.writeTypeInfo(value); if (trackingRef) { @@ -240,8 +216,8 @@ class CollectionAnySerializer { } else { if (trackingRef) { for (const item of value) { - const serializer = - this.writeContext.typeResolver.getSerializerByData(item); + const serializer + = this.writeContext.typeResolver.getSerializerByData(item); serializer?.writeRef(item); } } else if (includeNone) { @@ -249,16 +225,16 @@ class CollectionAnySerializer { if (item === null || item === undefined) { this.writeContext.writer.writeInt8(RefFlags.NullFlag); } else { - const serializer = - this.writeContext.typeResolver.getSerializerByData(item); + const serializer + = this.writeContext.typeResolver.getSerializerByData(item); this.writeContext.writer.writeInt8(RefFlags.NotNullValueFlag); serializer!.writeNoRef(item); } } } else { for (const item of value) { - const serializer = - this.writeContext.typeResolver.getSerializerByData(item); + const serializer + = this.writeContext.typeResolver.getSerializerByData(item); serializer!.writeNoRef(item); } } @@ -391,12 +367,12 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera private isDeclaredElementType() { const innerTypeId = this.innerGenerator.getTypeId(); return ( - innerTypeId !== TypeId.STRUCT && - innerTypeId !== TypeId.COMPATIBLE_STRUCT && - innerTypeId !== TypeId.NAMED_STRUCT && - innerTypeId !== TypeId.NAMED_COMPATIBLE_STRUCT && - innerTypeId !== TypeId.EXT && - innerTypeId !== TypeId.NAMED_EXT + innerTypeId !== TypeId.STRUCT + && innerTypeId !== TypeId.COMPATIBLE_STRUCT + && innerTypeId !== TypeId.NAMED_STRUCT + && innerTypeId !== TypeId.NAMED_COMPATIBLE_STRUCT + && innerTypeId !== TypeId.EXT + && innerTypeId !== TypeId.NAMED_EXT ); } @@ -488,16 +464,16 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera ? compatibleArrayCollectionExpr(compatibleReadAction!.elementTypeId, len) : this.newCollection(len); const reserveMemory = compatibleListToArray - ? `${readContextName}.reserveGraphMemory(${COLLECTION_BYTES} + ${len} * ${compatibleArrayElementBytes(compatibleReadAction!.elementTypeId)});` + ? "" : `${readContextName}.reserveGraphMemory(${COLLECTION_BYTES} + ${len} * ${REFERENCE_BYTES});`; const putAccessor = (item: string, index: string) => compatibleListToArray ? compatibleArrayPutAccessor( - compatibleReadAction!.elementTypeId, - result, - item, - index, - ) + compatibleReadAction!.elementTypeId, + result, + item, + index, + ) : this.putAccessor(result, item, index); const rejectCompatiblePayload = compatibleListToArray ? ` @@ -524,8 +500,8 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera }; const readElementTypeInfo = useDeclaredStructElementReader ? this.innerGenerator - .readEmbed() - .readTypeInfo((expr: string) => `${elemSerializer} = ${expr};`) + .readEmbed() + .readTypeInfo((expr: string) => `${elemSerializer} = ${expr};`) : `${elemSerializer} = ${anyHelper}.detectSerializer(${readContextName});`; return ` const ${len} = ${this.builder.reader.readVarUint32Small7()}; diff --git a/javascript/test/graphMemoryBudget.test.ts b/javascript/test/graphMemoryBudget.test.ts index b7b21158ae..b86bc6fa34 100644 --- a/javascript/test/graphMemoryBudget.test.ts +++ b/javascript/test/graphMemoryBudget.test.ts @@ -238,7 +238,7 @@ describe("graph memory budget", () => { }); }); - test("reserves compatible typed arrays", () => { + test("skips compatible list to typed array leaf", () => { const writerType = Type.struct(9010, { values: Type.list(Type.int32({ encoding: "fixed" })).setId(1), }); @@ -249,11 +249,11 @@ describe("graph memory budget", () => { const bytes = writer.register(writerType).serialize({ values: [1, 2, 3] }); const passingReader = new Fory({ compatible: true, - maxGraphMemoryBytes: objectBytes(1) + OBJECT_BYTES + 12, + maxGraphMemoryBytes: objectBytes(1), }).register(readerType); const failingReader = new Fory({ compatible: true, - maxGraphMemoryBytes: objectBytes(1) + OBJECT_BYTES + 12 - 1, + maxGraphMemoryBytes: objectBytes(1) - 1, }).register(readerType); expect(() => failingReader.deserialize(bytes)).toThrow( diff --git a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala index c20ee9e2b0..cbefa8ccdb 100644 --- a/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala +++ b/scala/src/main/scala-3/org/apache/fory/scala/internal/ForySerializerMacros.scala @@ -122,8 +122,6 @@ object ForySerializerMacros { else 4L } - val objectGraphMemoryBytes: Long = 1L + fields.map(field => graphFieldBytes(field.sourceType)).sum - def classFor(tpe: TypeRepr): Expr[Class[?]] = { val normalized = peelAnnotations(tpe.widen)._1.dealias val fullName = normalized.typeSymbol.fullName @@ -222,6 +220,7 @@ object ForySerializerMacros { !privateField, constructorOwned || (field.flags.is(Flags.Mutable) && !privateField)) } + val objectGraphMemoryBytes: Long = 1L + fields.map(field => graphFieldBytes(field.sourceType)).sum val hasNestedCompatibleStructFields = fields.exists(field => hasNestedCompatibleStruct(field.sourceType)) diff --git a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala index 22e7044d23..7cf598f381 100644 --- a/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala +++ b/scala/src/test/scala-3/org/apache/fory/serializer/scala/ForySerializerDerivationTest.scala @@ -571,7 +571,7 @@ import org.apache.fory.scala.ForyScala buffer.readVarUInt32() shouldBe 0 buffer.readerIndex(0) val readContext = fory.getReadContext - readContext.prepare(buffer, null, false, buffer.remaining(), false) + readContext.prepare(buffer, null, false) try serializer.read(readContext) shouldBe value finally readContext.reset() } diff --git a/swift/Sources/Fory/CollectionSerializers.swift b/swift/Sources/Fory/CollectionSerializers.swift index 51efae5202..f286ea3950 100644 --- a/swift/Sources/Fory/CollectionSerializers.swift +++ b/swift/Sources/Fory/CollectionSerializers.swift @@ -623,11 +623,20 @@ extension Array: Serializer where Element: Serializer { } public static func foryReadData(_ context: ReadContext) throws -> [Element] { + try readData(context, reserveGraphStorage: true) + } + + fileprivate static func readData( + _ context: ReadContext, + reserveGraphStorage: Bool + ) throws -> [Element] { let buffer = context.buffer let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { - try reserveGraphArrayMemory(context, Element.self, count: length) + if reserveGraphStorage { + try reserveGraphArrayMemory(context, Element.self, count: length) + } return [] } @@ -637,7 +646,9 @@ extension Array: Serializer where Element: Serializer { let declared = (header & CollectionHeader.declaredElementType) != 0 let sameType = (header & CollectionHeader.sameType) != 0 if !sameType { - try reserveGraphArrayMemory(context, Element.self, count: length) + if reserveGraphStorage { + try reserveGraphArrayMemory(context, Element.self, count: length) + } try context.ensureRemainingBytes(length, label: "array") if trackRef { return try readArrayUninitialized(count: length) { destination in @@ -676,7 +687,9 @@ extension Array: Serializer where Element: Serializer { } let elementTypeInfo = declared ? nil : try Element.foryReadTypeInfo(context) - try reserveGraphArrayMemory(context, Element.self, count: length) + if reserveGraphStorage { + try reserveGraphArrayMemory(context, Element.self, count: length) + } try context.ensureRemainingBytes(length, label: "array") return try context.withTypeInfo(elementTypeInfo, for: Element.self) { if trackRef { @@ -735,7 +748,7 @@ extension Set: Serializer where Element: Serializer & Hashable { } public static func foryReadData(_ context: ReadContext) throws -> Set { - let values = try [Element].foryReadData(context) + let values = try [Element].readData(context, reserveGraphStorage: false) try reserveGraphArrayMemory(context, Element.self, count: values.count) return Set(values) } diff --git a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift index 89fa19a701..1c3f88b8f6 100644 --- a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift +++ b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift @@ -222,6 +222,21 @@ func referenceAndInlineValueArraysAreCharged() throws { #expect(try makeBudgetFory(maxGraphMemoryBytes: Int64(intBudget)).deserialize(intBytes) == ints) } +@Test +func setConversionOwnerChargedOnce() throws { + let values: Set = [1, 2, 3] + let bytes = try makeBudgetFory().serialize(values) + let required = ownerBytes(Set.self) + arrayBudget(Int32.self, count: values.count) + + expectInvalidData { + let _: Set = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: Set = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded == values) +} + @Test func stringBinaryAndPrimitiveDenseArrayOwnersAreSkipped() throws { let value = BudgetDenseHolder( From d373183e0c6fe240526a18e50ccff555acf2af94 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Thu, 2 Jul 2026 12:11:15 +0800 Subject: [PATCH 21/54] fix(csharp): clean graph budget read ownership --- .../src/Fory.Generator/ForyModelGenerator.cs | 156 +++- csharp/src/Fory/AnySerializer.cs | 14 +- csharp/src/Fory/CollectionSerializers.cs | 801 +++++++++++++++++- csharp/src/Fory/DictionarySerializers.cs | 13 +- csharp/src/Fory/FieldSkipper.cs | 28 +- csharp/src/Fory/Fory.cs | 67 +- csharp/src/Fory/NullableKeyDictionary.cs | 14 +- .../Fory/PrimitiveDictionarySerializers.cs | 7 +- csharp/src/Fory/ReadContext.cs | 120 ++- csharp/src/Fory/RefResolver.cs | 18 +- csharp/src/Fory/Serializer.cs | 20 +- csharp/src/Fory/TypeResolver.cs | 5 +- csharp/src/Fory/UnionSerializer.cs | 5 +- csharp/src/Fory/UnknownCaseSerializer.cs | 14 +- csharp/tests/Fory.Tests/ForyRuntimeTests.cs | 151 +++- .../Fory.Tests/GraphMemoryBudgetTests.cs | 32 + docs/guide/csharp/basic-serialization.md | 5 + docs/guide/csharp/references.md | 4 + 18 files changed, 1345 insertions(+), 129 deletions(-) diff --git a/csharp/src/Fory.Generator/ForyModelGenerator.cs b/csharp/src/Fory.Generator/ForyModelGenerator.cs index 68ba979c72..4bbb832afe 100644 --- a/csharp/src/Fory.Generator/ForyModelGenerator.cs +++ b/csharp/src/Fory.Generator/ForyModelGenerator.cs @@ -214,7 +214,16 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) $" private const bool __ForyAllFieldsBuiltIn = {BoolLiteral(model.SortedMembers.All(m => m.DynamicAnyKind == DynamicAnyKind.None && m.Classification.IsBuiltIn))};"); if (model.Kind == DeclKind.Class) { - sb.AppendLine($" private static readonly long __ForyGraphMemoryBytes = {ModelGraphMemoryExpr(model)};"); + string graphMemoryExpr = ModelGraphMemoryExpr(model); + bool constGraphMemory = IsConstGraphMemoryExpr(graphMemoryExpr); + string graphMemoryStorage = constGraphMemory ? "const" : "static readonly"; + string graphMemoryType = constGraphMemory ? "int" : "long"; + if (constGraphMemory) + { + graphMemoryExpr = graphMemoryExpr.Replace("L", string.Empty); + } + + sb.AppendLine($" private {graphMemoryStorage} {graphMemoryType} __ForyGraphMemoryBytes = {graphMemoryExpr};"); } sb.AppendLine( @@ -451,7 +460,19 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" }"); sb.AppendLine(); - sb.AppendLine($" private {model.TypeName} ReadDataWithoutTypeMeta(global::Apache.Fory.ReadContext context)"); + EmitReadDataWithoutTypeMeta(sb, model, "ReadDataWithoutTypeMeta", "context.ShouldStoreRef"); + EmitReadDataMethod(sb, model, "ReadData", "ReadDataWithoutTypeMeta", "context.ShouldStoreRef", "public"); + + sb.AppendLine("}"); + } + + private static void EmitReadDataWithoutTypeMeta( + StringBuilder sb, + TypeModel model, + string methodName, + string? storeRefCondition) + { + sb.AppendLine($" private {model.TypeName} {methodName}(global::Apache.Fory.ReadContext context)"); sb.AppendLine(" {"); if (model.Kind == DeclKind.Class) { @@ -459,10 +480,7 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) } sb.AppendLine($" {model.TypeName} valueNoTypeMeta = new {model.TypeName}();"); - if (model.Kind == DeclKind.Class) - { - sb.AppendLine(" context.StoreRef(valueNoTypeMeta);"); - } + EmitStoreRef(sb, model, storeRefCondition, "valueNoTypeMeta", 2); foreach (MemberModel member in model.SortedMembers) { @@ -480,7 +498,17 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" return valueNoTypeMeta;"); sb.AppendLine(" }"); sb.AppendLine(); - sb.AppendLine($" public override {model.TypeName} ReadData(global::Apache.Fory.ReadContext context)"); + } + + private static void EmitReadDataMethod( + StringBuilder sb, + TypeModel model, + string methodName, + string noTypeMetaMethodName, + string? storeRefCondition, + string accessibility) + { + sb.AppendLine($" {accessibility} override {model.TypeName} {methodName}(global::Apache.Fory.ReadContext context)"); sb.AppendLine(" {"); sb.AppendLine(" if (context.Compatible)"); sb.AppendLine(" {"); @@ -488,7 +516,7 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) $" global::Apache.Fory.TypeMeta? maybeTypeMeta = context.GetTypeMeta<{model.TypeName}>();"); sb.AppendLine(" if (maybeTypeMeta is null)"); sb.AppendLine(" {"); - sb.AppendLine(" return ReadDataWithoutTypeMeta(context);"); + sb.AppendLine($" return {noTypeMetaMethodName}(context);"); sb.AppendLine(" }"); sb.AppendLine(); sb.AppendLine(" global::Apache.Fory.TypeMeta typeMeta = maybeTypeMeta;"); @@ -498,10 +526,7 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) } sb.AppendLine($" {model.TypeName} value = new {model.TypeName}();"); - if (model.Kind == DeclKind.Class) - { - sb.AppendLine(" context.StoreRef(value);"); - } + EmitStoreRef(sb, model, storeRefCondition, "value", 3); sb.AppendLine(" bool __ForyExactTypeMeta = __ForyMatchesCachedTypeMeta(typeMeta, context.TrackRef, context.TypeResolver);"); sb.AppendLine(" if (__ForyAllFieldsBuiltIn && __ForyExactTypeMeta)"); @@ -613,10 +638,7 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) } sb.AppendLine($" {model.TypeName} valueSchema = new {model.TypeName}();"); - if (model.Kind == DeclKind.Class) - { - sb.AppendLine(" context.StoreRef(valueSchema);"); - } + EmitStoreRef(sb, model, storeRefCondition, "valueSchema", 2); foreach (MemberModel member in model.SortedMembers) { @@ -625,7 +647,32 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" return valueSchema;"); sb.AppendLine(" }"); - sb.AppendLine("}"); + sb.AppendLine(); + } + + private static void EmitStoreRef( + StringBuilder sb, + TypeModel model, + string? condition, + string valueName, + int indentLevel) + { + if (model.Kind != DeclKind.Class || condition is null) + { + return; + } + + string indent = new(' ', indentLevel * 4); + if (condition == "true") + { + sb.AppendLine($"{indent}context.StoreRef({valueName});"); + return; + } + + sb.AppendLine($"{indent}if ({condition})"); + sb.AppendLine($"{indent}{{"); + sb.AppendLine($"{indent} context.StoreRef({valueName});"); + sb.AppendLine($"{indent}}}"); } private static void EmitUnionSerializer(StringBuilder sb, TypeModel model) @@ -697,6 +744,7 @@ private static void EmitUnionSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(); sb.AppendLine($" public override {model.TypeName} ReadData(global::Apache.Fory.ReadContext context)"); sb.AppendLine(" {"); + sb.AppendLine(" uint __foryPausedRef = context.PauseRefPublication();"); sb.AppendLine(" uint rawCaseId = context.Reader.ReadVarUInt32();"); sb.AppendLine(" if (rawCaseId > int.MaxValue)"); sb.AppendLine(" {"); @@ -713,7 +761,9 @@ private static void EmitUnionSerializer(StringBuilder sb, TypeModel model) sb.AppendLine($" case {caseId}:"); sb.AppendLine(" {"); EmitReadUnionCasePayload(sb, unionCase, valueVar, 4); - sb.AppendLine($" return new {unionCase.TypeName}({valueVar});"); + sb.AppendLine($" {model.TypeName} __foryUnion = new {unionCase.TypeName}({valueVar});"); + sb.AppendLine(" context.ResumeRefPublication(__foryPausedRef);"); + sb.AppendLine(" return __foryUnion;"); sb.AppendLine(" }"); } @@ -725,7 +775,9 @@ private static void EmitUnionSerializer(StringBuilder sb, TypeModel model) } else { - sb.AppendLine($" return new {unknownCase.TypeName}(global::Apache.Fory.UnknownCaseSerializer.ReadPayload(context, caseId));"); + sb.AppendLine($" {model.TypeName} __foryUnion = new {unknownCase.TypeName}(global::Apache.Fory.UnknownCaseSerializer.ReadPayload(context, caseId));"); + sb.AppendLine(" context.ResumeRefPublication(__foryPausedRef);"); + sb.AppendLine(" return __foryUnion;"); } sb.AppendLine(" }"); @@ -1185,7 +1237,6 @@ private static void EmitReadCompatibleListArrayPayload( $"(typeof({elementTypeName}).IsValueType ? global::System.Runtime.CompilerServices.Unsafe.SizeOf<{elementTypeName}>() : 4)"; if (codec.CarrierKind == CarrierKind.Array) { - sb.AppendLine($"{indent}context.ReserveGraphMemory(1L + (long){lengthVar} * {elementBytesExpr});"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new {ElementTypeName(codec.TypeName)}[{lengthVar}];"); } else @@ -1824,6 +1875,8 @@ private static MemberModel NonNullableMember(MemberModel member) member.Group, member.IsCollection, member.UseDictionaryTypeInfoCache, + member.UsesReferenceStorage, + member.FixedValueBytes, member.IsRefType, member.NeedsFieldTypeInfo, member.DynamicAnyKind, @@ -1879,8 +1932,24 @@ private static string FieldGraphMemoryExpr(MemberModel member) return $"{member.Classification.PrimitiveSize}L"; } + if (member.UsesReferenceStorage) + { + return "4L"; + } + + if (member.FixedValueBytes > 0) + { + return $"{member.FixedValueBytes}L"; + } + string typeName = StripNullableForTypeOf(member.TypeName); - return $"(typeof({typeName}).IsValueType ? global::System.Runtime.CompilerServices.Unsafe.SizeOf<{typeName}>() : 4L)"; + return $"global::System.Runtime.CompilerServices.Unsafe.SizeOf<{typeName}>()"; + } + + private static bool IsConstGraphMemoryExpr(string expression) + { + return expression.IndexOf("typeof(", StringComparison.Ordinal) < 0 && + expression.IndexOf("Unsafe.", StringComparison.Ordinal) < 0; } private static string PackedArrayElementTypeName(uint typeId) @@ -3123,6 +3192,7 @@ private static ForyAttributeKind GetForyAttributeKind(INamedTypeSymbol typeSymbo } TypeClassification classification = resolution.Classification; + int fixedValueBytes = FixedGraphValueBytes(unwrappedType, classification); int group = classification.IsPrimitive ? (isOptional ? 2 : 1) : 3; @@ -3157,6 +3227,8 @@ memberType is INamedTypeSymbol nts && group, classification.IsCollection || classification.IsMap, classification.IsMap && !IsTypeSealed(unwrappedType), + !unwrappedType.IsValueType, + fixedValueBytes, !unwrappedType.IsValueType && classification.TypeId != 21, FieldNeedsTypeInfo(classification, dynamicAnyKind, unwrappedType), dynamicAnyKind == DynamicAnyKind.None ? DynamicAnyKind.None : dynamicAnyKind, @@ -3165,6 +3237,42 @@ memberType is INamedTypeSymbol nts && schemaType is not null); } + private static int FixedGraphValueBytes(ITypeSymbol type, TypeClassification classification) + { + if (classification.IsPrimitive && classification.PrimitiveSize > 0) + { + return classification.PrimitiveSize; + } + + if (type.TypeKind == TypeKind.Enum && + type is INamedTypeSymbol enumType && + enumType.EnumUnderlyingType is not null) + { + return SpecialTypeBytes(enumType.EnumUnderlyingType.SpecialType); + } + + return type.SpecialType == SpecialType.System_Decimal ? 16 : 0; + } + + private static int SpecialTypeBytes(SpecialType specialType) + { + return specialType switch + { + SpecialType.System_Boolean or + SpecialType.System_SByte or + SpecialType.System_Byte => 1, + SpecialType.System_Int16 or + SpecialType.System_UInt16 => 2, + SpecialType.System_Int32 or + SpecialType.System_UInt32 or + SpecialType.System_Single => 4, + SpecialType.System_Int64 or + SpecialType.System_UInt64 or + SpecialType.System_Double => 8, + _ => 0, + }; + } + private static TypeMetaFieldTypeModel BuildTypeMetaFieldTypeModel( ITypeSymbol memberType, bool nullable, @@ -4431,6 +4539,8 @@ public MemberModel( int group, bool isCollection, bool useDictionaryTypeInfoCache, + bool usesReferenceStorage, + int fixedValueBytes, bool isRefType, bool needsFieldTypeInfo, DynamicAnyKind dynamicAnyKind, @@ -4450,6 +4560,8 @@ public MemberModel( Group = group; IsCollection = isCollection; UseDictionaryTypeInfoCache = useDictionaryTypeInfoCache; + UsesReferenceStorage = usesReferenceStorage; + FixedValueBytes = fixedValueBytes; IsRefType = isRefType; NeedsFieldTypeInfo = needsFieldTypeInfo; DynamicAnyKind = dynamicAnyKind; @@ -4470,6 +4582,8 @@ public MemberModel( public int Group { get; } public bool IsCollection { get; } public bool UseDictionaryTypeInfoCache { get; } + public bool UsesReferenceStorage { get; } + public int FixedValueBytes { get; } public bool IsRefType { get; } public bool NeedsFieldTypeInfo { get; } public DynamicAnyKind DynamicAnyKind { get; } diff --git a/csharp/src/Fory/AnySerializer.cs b/csharp/src/Fory/AnySerializer.cs index 1be0c3e20c..80fe56b478 100644 --- a/csharp/src/Fory/AnySerializer.cs +++ b/csharp/src/Fory/AnySerializer.cs @@ -95,16 +95,10 @@ public override void Write(WriteContext context, in object? value, RefMode refMo { uint reservedRefId = context.RefReader.ReserveRefId(); context.SetReservedRefId(reservedRefId); - try - { - object? value = ReadNonNullDynamicAny(context, readTypeInfo); - context.StoreRef(value); - return value; - } - finally - { - context.ClearReservedRefId(); - } + object? value = ReadNonNullDynamicAny(context, readTypeInfo); + context.StoreRef(value); + context.ClearReservedRefId(); + return value; } case RefFlag.NotNullValue: break; diff --git a/csharp/src/Fory/CollectionSerializers.cs b/csharp/src/Fory/CollectionSerializers.cs index 516012ca1c..b0a08b96fc 100644 --- a/csharp/src/Fory/CollectionSerializers.cs +++ b/csharp/src/Fory/CollectionSerializers.cs @@ -36,7 +36,7 @@ internal static class CollectionCodec private const int ReferenceBytes = 4; [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static int ElementBytes() => typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + internal static int ElementBytes() => ElementStorage.Bytes; [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static void ReserveElementStorage(ReadContext context, int count) @@ -208,13 +208,53 @@ public static void WriteCollectionData( } } + private static class ElementStorage + { + internal static readonly int Bytes = typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + } + + internal readonly struct CollectionFrame + { + internal CollectionFrame(bool trackRef, bool hasNull, bool declared, bool sameType) + { + TrackRef = trackRef; + HasNull = hasNull; + Declared = declared; + SameType = sameType; + } + + internal bool TrackRef { get; } + internal bool HasNull { get; } + internal bool Declared { get; } + internal bool SameType { get; } + } + + internal static int ReadLengthAndReserve(ReadContext context) + { + int length = checked((int)context.Reader.ReadVarUInt32()); + ReserveElementStorage(context, length); + return length; + } + + internal static CollectionFrame ReadFrame(ReadContext context, int length) + { + byte header = context.Reader.ReadUInt8(); + context.Reader.CheckBound(length); + return new CollectionFrame( + (header & CollectionBits.TrackingRef) != 0, + (header & CollectionBits.HasNull) != 0, + (header & CollectionBits.DeclaredElementType) != 0, + (header & CollectionBits.SameType) != 0); + } + public static List ReadCollectionData( Serializer elementSerializer, ReadContext context, - bool reserveOwner = true) + bool reserveOwner = true, + bool storeOwnerRef = true) { - TypeInfo elementTypeInfo = context.TypeResolver.GetTypeInfo(); int length = checked((int)context.Reader.ReadVarUInt32()); + bool storeRef = storeOwnerRef && context.ShouldStoreRef; if (length == 0) { if (reserveOwner) @@ -222,7 +262,13 @@ public static List ReadCollectionData( ReserveElementStorage(context, length); } - return []; + List empty = []; + if (storeRef) + { + context.StoreRef(empty); + } + + return empty; } byte header = context.Reader.ReadUInt8(); @@ -241,6 +287,11 @@ public static List ReadCollectionData( context.Reader.CheckBound(length); List values = new(length); + if (storeRef) + { + context.StoreRef(values); + } + if (!sameType) { if (trackRef) @@ -333,6 +384,127 @@ public static List ReadCollectionData( return values; } + + public static T[] ReadArrayData(Serializer elementSerializer, ReadContext context) + { + int length = checked((int)context.Reader.ReadVarUInt32()); + if (length == 0) + { + ReserveElementStorage(context, length); + T[] empty = []; + if (context.ShouldStoreRef) + { + context.StoreRef(empty); + } + + return empty; + } + + byte header = context.Reader.ReadUInt8(); + bool trackRef = (header & CollectionBits.TrackingRef) != 0; + bool hasNull = (header & CollectionBits.HasNull) != 0; + bool declared = (header & CollectionBits.DeclaredElementType) != 0; + bool sameType = (header & CollectionBits.SameType) != 0; + ReserveElementStorage(context, length); + context.Reader.CheckBound(length); + T[] values = new T[length]; + if (context.ShouldStoreRef) + { + context.StoreRef(values); + } + + if (!sameType) + { + if (trackRef) + { + for (int i = 0; i < length; i++) + { + values[i] = elementSerializer.Read(context, RefMode.Tracking, true); + } + + return values; + } + + if (hasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + values[i] = (T)elementSerializer.DefaultObject!; + } + else if (refFlag == (sbyte)RefFlag.NotNullValue) + { + values[i] = elementSerializer.Read(context, RefMode.None, true); + } + else + { + throw new RefException($"invalid nullability flag {refFlag}"); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + values[i] = elementSerializer.Read(context, RefMode.None, true); + } + } + + return values; + } + + if (!declared) + { + context.TypeResolver.ReadTypeInfo(elementSerializer, context); + } + + if (trackRef) + { + for (int i = 0; i < length; i++) + { + values[i] = elementSerializer.Read(context, RefMode.Tracking, false); + } + + if (!declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } + + if (hasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + values[i] = (T)elementSerializer.DefaultObject!; + } + else + { + values[i] = elementSerializer.ReadData(context); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + values[i] = elementSerializer.ReadData(context); + } + } + + if (!declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } } internal static class DynamicContainerCodec @@ -416,7 +588,17 @@ public static bool TryWritePayload(object value, WriteContext context, bool hasG public static object ReadMapPayload(ReadContext context) { - NullableKeyDictionary map = context.TypeResolver.GetSerializer>().ReadData(context); + Serializer> serializer = + context.TypeResolver.GetSerializer>(); + bool storeRef = context.ShouldStoreRef; + NullableKeyDictionary map = serializer.ReadData(context); + if (storeRef) + { + // Dynamic tracked maps publish the nullable-key owner itself so + // nested references resolve to the same returned object. + return map; + } + if (map.HasNullKey) { return map; @@ -547,8 +729,7 @@ public override void WriteData(WriteContext context, in T[] value, bool hasGener public override T[] ReadData(ReadContext context) { - List values = CollectionCodec.ReadCollectionData(context.TypeResolver.GetSerializer(), context); - return values.ToArray(); + return CollectionCodec.ReadArrayData(context.TypeResolver.GetSerializer(), context); } } @@ -580,13 +761,130 @@ public override void WriteData(WriteContext context, in HashSet value, bool h public override HashSet ReadData(ReadContext context) { + if (context.ShouldStoreRef) + { + return ReadStoredSetData(context); + } + List values = CollectionCodec.ReadCollectionData( context.TypeResolver.GetSerializer(), context, - reserveOwner: false); + reserveOwner: false, + storeOwnerRef: false); CollectionCodec.ReserveElementStorage(context, values.Count); return [.. values]; } + + private static HashSet ReadStoredSetData(ReadContext context) + { + Serializer elementSerializer = context.TypeResolver.GetSerializer(); + int length = checked((int)context.Reader.ReadVarUInt32()); + CollectionCodec.ReserveElementStorage(context, length); + HashSet values = new(length); + context.StoreRef(values); + if (length == 0) + { + return values; + } + + byte header = context.Reader.ReadUInt8(); + bool trackRef = (header & CollectionBits.TrackingRef) != 0; + bool hasNull = (header & CollectionBits.HasNull) != 0; + bool declared = (header & CollectionBits.DeclaredElementType) != 0; + bool sameType = (header & CollectionBits.SameType) != 0; + context.Reader.CheckBound(length); + if (!sameType) + { + if (trackRef) + { + for (int i = 0; i < length; i++) + { + values.Add(elementSerializer.Read(context, RefMode.Tracking, true)); + } + + return values; + } + + if (hasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + values.Add((T)elementSerializer.DefaultObject!); + } + else if (refFlag == (sbyte)RefFlag.NotNullValue) + { + values.Add(elementSerializer.Read(context, RefMode.None, true)); + } + else + { + throw new RefException($"invalid nullability flag {refFlag}"); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + values.Add(elementSerializer.Read(context, RefMode.None, true)); + } + } + + return values; + } + + if (!declared) + { + context.TypeResolver.ReadTypeInfo(elementSerializer, context); + } + + if (trackRef) + { + for (int i = 0; i < length; i++) + { + values.Add(elementSerializer.Read(context, RefMode.Tracking, false)); + } + + if (!declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } + + if (hasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + values.Add((T)elementSerializer.DefaultObject!); + } + else + { + values.Add(elementSerializer.ReadData(context)); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + values.Add(elementSerializer.ReadData(context)); + } + } + + if (!declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } } public sealed class SortedSetSerializer : Serializer> where T : notnull @@ -601,31 +899,147 @@ public override void WriteData(WriteContext context, in SortedSet value, bool public override SortedSet ReadData(ReadContext context) { + if (context.ShouldStoreRef) + { + return ReadStoredSortedSetData(context); + } + + uint refId = context.PauseRefPublication(); List values = CollectionCodec.ReadCollectionData( context.TypeResolver.GetSerializer(), context, - reserveOwner: false); + reserveOwner: false, + storeOwnerRef: false); + context.ResumeRefPublication(refId); CollectionCodec.ReserveElementStorage(context, values.Count); return [.. values]; } -} - -public sealed class ImmutableHashSetSerializer : Serializer> where T : notnull -{ - public override ImmutableHashSet DefaultValue => null!; - public override void WriteData(WriteContext context, in ImmutableHashSet value, bool hasGenerics) + private static SortedSet ReadStoredSortedSetData(ReadContext context) { - ImmutableHashSet safe = value ?? ImmutableHashSet.Empty; - CollectionCodec.WriteCollectionData(safe, context.TypeResolver.GetSerializer(), context, hasGenerics); - } - - public override ImmutableHashSet ReadData(ReadContext context) + Serializer elementSerializer = context.TypeResolver.GetSerializer(); + int length = CollectionCodec.ReadLengthAndReserve(context); + SortedSet values = new(); + context.StoreRef(values); + if (length == 0) + { + return values; + } + + CollectionCodec.CollectionFrame frame = CollectionCodec.ReadFrame(context, length); + if (!frame.SameType) + { + if (frame.TrackRef) + { + for (int i = 0; i < length; i++) + { + values.Add(elementSerializer.Read(context, RefMode.Tracking, true)); + } + + return values; + } + + if (frame.HasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + values.Add((T)elementSerializer.DefaultObject!); + } + else if (refFlag == (sbyte)RefFlag.NotNullValue) + { + values.Add(elementSerializer.Read(context, RefMode.None, true)); + } + else + { + throw new RefException($"invalid nullability flag {refFlag}"); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + values.Add(elementSerializer.Read(context, RefMode.None, true)); + } + } + + return values; + } + + if (!frame.Declared) + { + context.TypeResolver.ReadTypeInfo(elementSerializer, context); + } + + if (frame.TrackRef) + { + for (int i = 0; i < length; i++) + { + values.Add(elementSerializer.Read(context, RefMode.Tracking, false)); + } + + if (!frame.Declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } + + if (frame.HasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + values.Add((T)elementSerializer.DefaultObject!); + } + else + { + values.Add(elementSerializer.ReadData(context)); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + values.Add(elementSerializer.ReadData(context)); + } + } + + if (!frame.Declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } +} + +public sealed class ImmutableHashSetSerializer : Serializer> where T : notnull +{ + public override ImmutableHashSet DefaultValue => null!; + + public override void WriteData(WriteContext context, in ImmutableHashSet value, bool hasGenerics) + { + ImmutableHashSet safe = value ?? ImmutableHashSet.Empty; + CollectionCodec.WriteCollectionData(safe, context.TypeResolver.GetSerializer(), context, hasGenerics); + } + + public override ImmutableHashSet ReadData(ReadContext context) { + uint refId = context.PauseRefPublication(); List values = CollectionCodec.ReadCollectionData( context.TypeResolver.GetSerializer(), context, - reserveOwner: false); + reserveOwner: false, + storeOwnerRef: false); + context.ResumeRefPublication(refId); CollectionCodec.ReserveElementStorage(context, values.Count); return ImmutableHashSet.CreateRange(values); } @@ -643,13 +1057,126 @@ public override void WriteData(WriteContext context, in LinkedList value, boo public override LinkedList ReadData(ReadContext context) { + if (context.ShouldStoreRef) + { + return ReadStoredLinkedListData(context); + } + + uint refId = context.PauseRefPublication(); List values = CollectionCodec.ReadCollectionData( context.TypeResolver.GetSerializer(), context, - reserveOwner: false); + reserveOwner: false, + storeOwnerRef: false); + context.ResumeRefPublication(refId); CollectionCodec.ReserveElementStorage(context, values.Count); return new LinkedList(values); } + + private static LinkedList ReadStoredLinkedListData(ReadContext context) + { + Serializer elementSerializer = context.TypeResolver.GetSerializer(); + int length = CollectionCodec.ReadLengthAndReserve(context); + LinkedList values = new(); + context.StoreRef(values); + if (length == 0) + { + return values; + } + + CollectionCodec.CollectionFrame frame = CollectionCodec.ReadFrame(context, length); + if (!frame.SameType) + { + if (frame.TrackRef) + { + for (int i = 0; i < length; i++) + { + values.AddLast(elementSerializer.Read(context, RefMode.Tracking, true)); + } + + return values; + } + + if (frame.HasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + values.AddLast((T)elementSerializer.DefaultObject!); + } + else if (refFlag == (sbyte)RefFlag.NotNullValue) + { + values.AddLast(elementSerializer.Read(context, RefMode.None, true)); + } + else + { + throw new RefException($"invalid nullability flag {refFlag}"); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + values.AddLast(elementSerializer.Read(context, RefMode.None, true)); + } + } + + return values; + } + + if (!frame.Declared) + { + context.TypeResolver.ReadTypeInfo(elementSerializer, context); + } + + if (frame.TrackRef) + { + for (int i = 0; i < length; i++) + { + values.AddLast(elementSerializer.Read(context, RefMode.Tracking, false)); + } + + if (!frame.Declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } + + if (frame.HasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + values.AddLast((T)elementSerializer.DefaultObject!); + } + else + { + values.AddLast(elementSerializer.ReadData(context)); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + values.AddLast(elementSerializer.ReadData(context)); + } + } + + if (!frame.Declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } } public sealed class QueueSerializer : Serializer> @@ -664,10 +1191,18 @@ public override void WriteData(WriteContext context, in Queue value, bool has public override Queue ReadData(ReadContext context) { + if (context.ShouldStoreRef) + { + return ReadStoredQueueData(context); + } + + uint refId = context.PauseRefPublication(); List values = CollectionCodec.ReadCollectionData( context.TypeResolver.GetSerializer(), context, - reserveOwner: false); + reserveOwner: false, + storeOwnerRef: false); + context.ResumeRefPublication(refId); CollectionCodec.ReserveElementStorage(context, values.Count); Queue queue = new(values.Count); for (int i = 0; i < values.Count; i++) @@ -677,6 +1212,111 @@ public override Queue ReadData(ReadContext context) return queue; } + + private static Queue ReadStoredQueueData(ReadContext context) + { + Serializer elementSerializer = context.TypeResolver.GetSerializer(); + int length = CollectionCodec.ReadLengthAndReserve(context); + Queue values = new(length); + context.StoreRef(values); + if (length == 0) + { + return values; + } + + CollectionCodec.CollectionFrame frame = CollectionCodec.ReadFrame(context, length); + if (!frame.SameType) + { + if (frame.TrackRef) + { + for (int i = 0; i < length; i++) + { + values.Enqueue(elementSerializer.Read(context, RefMode.Tracking, true)); + } + + return values; + } + + if (frame.HasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + values.Enqueue((T)elementSerializer.DefaultObject!); + } + else if (refFlag == (sbyte)RefFlag.NotNullValue) + { + values.Enqueue(elementSerializer.Read(context, RefMode.None, true)); + } + else + { + throw new RefException($"invalid nullability flag {refFlag}"); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + values.Enqueue(elementSerializer.Read(context, RefMode.None, true)); + } + } + + return values; + } + + if (!frame.Declared) + { + context.TypeResolver.ReadTypeInfo(elementSerializer, context); + } + + if (frame.TrackRef) + { + for (int i = 0; i < length; i++) + { + values.Enqueue(elementSerializer.Read(context, RefMode.Tracking, false)); + } + + if (!frame.Declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } + + if (frame.HasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + values.Enqueue((T)elementSerializer.DefaultObject!); + } + else + { + values.Enqueue(elementSerializer.ReadData(context)); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + values.Enqueue(elementSerializer.ReadData(context)); + } + } + + if (!frame.Declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } } public sealed class StackSerializer : Serializer> @@ -704,10 +1344,18 @@ public override void WriteData(WriteContext context, in Stack value, bool has public override Stack ReadData(ReadContext context) { + if (context.ShouldStoreRef) + { + return ReadStoredStackData(context); + } + + uint refId = context.PauseRefPublication(); List values = CollectionCodec.ReadCollectionData( context.TypeResolver.GetSerializer(), context, - reserveOwner: false); + reserveOwner: false, + storeOwnerRef: false); + context.ResumeRefPublication(refId); CollectionCodec.ReserveElementStorage(context, values.Count); Stack stack = new(values.Count); for (int i = 0; i < values.Count; i++) @@ -717,4 +1365,109 @@ public override Stack ReadData(ReadContext context) return stack; } + + private static Stack ReadStoredStackData(ReadContext context) + { + Serializer elementSerializer = context.TypeResolver.GetSerializer(); + int length = CollectionCodec.ReadLengthAndReserve(context); + Stack values = new(length); + context.StoreRef(values); + if (length == 0) + { + return values; + } + + CollectionCodec.CollectionFrame frame = CollectionCodec.ReadFrame(context, length); + if (!frame.SameType) + { + if (frame.TrackRef) + { + for (int i = 0; i < length; i++) + { + values.Push(elementSerializer.Read(context, RefMode.Tracking, true)); + } + + return values; + } + + if (frame.HasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + values.Push((T)elementSerializer.DefaultObject!); + } + else if (refFlag == (sbyte)RefFlag.NotNullValue) + { + values.Push(elementSerializer.Read(context, RefMode.None, true)); + } + else + { + throw new RefException($"invalid nullability flag {refFlag}"); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + values.Push(elementSerializer.Read(context, RefMode.None, true)); + } + } + + return values; + } + + if (!frame.Declared) + { + context.TypeResolver.ReadTypeInfo(elementSerializer, context); + } + + if (frame.TrackRef) + { + for (int i = 0; i < length; i++) + { + values.Push(elementSerializer.Read(context, RefMode.Tracking, false)); + } + + if (!frame.Declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } + + if (frame.HasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + values.Push((T)elementSerializer.DefaultObject!); + } + else + { + values.Push(elementSerializer.ReadData(context)); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + values.Push(elementSerializer.ReadData(context)); + } + } + + if (!frame.Declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } } diff --git a/csharp/src/Fory/DictionarySerializers.cs b/csharp/src/Fory/DictionarySerializers.cs index 3be9465a49..a1b13f6c5e 100644 --- a/csharp/src/Fory/DictionarySerializers.cs +++ b/csharp/src/Fory/DictionarySerializers.cs @@ -229,12 +229,23 @@ public override TDictionary ReadData(ReadContext context) if (totalLength == 0) { ReserveMapStorage(context, totalLength); - return CreateMap(0); + TDictionary empty = CreateMap(0); + if (context.ShouldStoreRef) + { + context.StoreRef(empty); + } + + return empty; } ReserveMapStorage(context, totalLength); context.Reader.CheckBound(totalLength); TDictionary map = CreateMap(totalLength); + if (context.ShouldStoreRef) + { + context.StoreRef(map); + } + bool keyDynamicType = keyTypeInfo.IsDynamicType; bool valueDynamicType = valueTypeInfo.IsDynamicType; int readCount = 0; diff --git a/csharp/src/Fory/FieldSkipper.cs b/csharp/src/Fory/FieldSkipper.cs index aaca46ec23..3120e16bf7 100644 --- a/csharp/src/Fory/FieldSkipper.cs +++ b/csharp/src/Fory/FieldSkipper.cs @@ -119,16 +119,10 @@ private static bool HasInlineTypeInfo(uint typeId) { uint reservedRefId = context.RefReader.ReserveRefId(); context.SetReservedRefId(reservedRefId); - try - { - object? value = ReadInlineTypedPayload(context); - context.StoreRef(value); - return value; - } - finally - { - context.ClearReservedRefId(); - } + object? value = ReadInlineTypedPayload(context); + context.StoreRef(value); + context.ClearReservedRefId(); + return value; } case RefFlag.NotNullValue: return ReadInlineTypedPayload(context); @@ -184,16 +178,10 @@ private static bool HasInlineTypeInfo(uint typeId) { uint reservedRefId = context.RefReader.ReserveRefId(); context.SetReservedRefId(reservedRefId); - try - { - object? value = context.TypeResolver.ReadAnyValue(typeInfo, context); - context.StoreRef(value); - return value; - } - finally - { - context.ClearReservedRefId(); - } + object? value = context.TypeResolver.ReadAnyValue(typeInfo, context); + context.StoreRef(value); + context.ClearReservedRefId(); + return value; } case RefFlag.NotNullValue: return context.TypeResolver.ReadAnyValue(typeInfo, context); diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs index 79bfa9dfa6..a1169201ee 100644 --- a/csharp/src/Fory/Fory.cs +++ b/csharp/src/Fory/Fory.cs @@ -28,12 +28,14 @@ namespace Apache.Fory; public sealed class Fory { private readonly TypeResolver _typeResolver; + private readonly bool _trackRef; private WriteContext _writeContext; private ReadContext _readContext; internal Fory(Config config) { Config = config; + _trackRef = config.TrackRef; _typeResolver = new TypeResolver(); _writeContext = new WriteContext( new ByteWriter(), @@ -282,19 +284,74 @@ private T DeserializeFromReader(ByteReader reader) Serializer serializer = _typeResolver.GetSerializer(); ReadContext readContext = _readContext; readContext.ResetFor(reader); - GraphMemory.ReserveRootValue(readContext); - RefMode refMode = Config.TrackRef ? RefMode.Tracking : RefMode.NullOnly; - T value = serializer.Read(readContext, refMode, true); - readContext.RefReader.Reset(); + if (typeof(T).IsValueType) + { + GraphMemory.ReserveRootValue(readContext); + } + + T value = _trackRef + ? serializer.Read(readContext, RefMode.Tracking, true) + : ReadRootNoRef(serializer, readContext); + if (_trackRef || readContext.RefReader.HasRefs) + { + readContext.RefReader.Reset(); + } + if (readContext._reservedRefIds.Count != 0) + { + readContext._reservedRefIds.Clear(); + } readContext._typeMetaType = null; readContext._typeMeta = null; readContext._typeMetaByType?.ClearKeys(); readContext._readTypeInfoByType.ClearKeys(); - readContext._reservedRefIds.Clear(); readContext._cachedTypeMetaType = null; readContext._cachedTypeMeta = null; readContext._currentDynamicReadDepth = 0; return value; } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static T ReadRootNoRef(Serializer serializer, ReadContext context) + { + RefFlag flag = (RefFlag)context.Reader.ReadInt8(); + if (flag == RefFlag.NotNullValue) + { + context.TypeResolver.ReadTypeInfo(serializer, context); + return serializer.ReadData(context); + } + + if (flag == RefFlag.Null) + { + return serializer.DefaultValue; + } + + return ReadRootRefFallback(serializer, context, flag); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private static T ReadRootRefFallback(Serializer serializer, ReadContext context, RefFlag flag) + { + switch (flag) + { + case RefFlag.Ref: + { + uint refId = context.RefReader.ReadRefId(context.Reader); + return context.RefReader.GetRef(refId); + } + case RefFlag.RefValue: + { + uint reservedRefId = context.RefReader.ReserveRefId(); + context.SetReservedRefId(reservedRefId); + context.TypeResolver.ReadTypeInfo(serializer, context); + T value = serializer.ReadData(context); + context.StoreRef(value); + context.ClearReservedRefId(); + context.RefReader.Reset(); + return value; + } + default: + throw new RefException($"invalid ref flag {(sbyte)flag}"); + } + } + } diff --git a/csharp/src/Fory/NullableKeyDictionary.cs b/csharp/src/Fory/NullableKeyDictionary.cs index cf479e6c81..00adde9aa4 100644 --- a/csharp/src/Fory/NullableKeyDictionary.cs +++ b/csharp/src/Fory/NullableKeyDictionary.cs @@ -548,16 +548,28 @@ public override NullableKeyDictionary ReadData(ReadContext context Serializer valueSerializer = context.TypeResolver.GetSerializer(); TypeInfo keyTypeInfo = context.TypeResolver.GetTypeInfo(); TypeInfo valueTypeInfo = context.TypeResolver.GetTypeInfo(); + bool storeRef = context.ShouldStoreRef; int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { ReserveMapStorage(context, totalLength); - return new NullableKeyDictionary(); + NullableKeyDictionary empty = new(); + if (storeRef) + { + context.StoreRef(empty); + } + + return empty; } ReserveMapStorage(context, totalLength); context.Reader.CheckBound(totalLength); NullableKeyDictionary map = new(totalLength); + if (storeRef) + { + context.StoreRef(map); + } + bool keyDynamicType = keyTypeInfo.IsDynamicType; bool valueDynamicType = valueTypeInfo.IsDynamicType; int readCount = 0; diff --git a/csharp/src/Fory/PrimitiveDictionarySerializers.cs b/csharp/src/Fory/PrimitiveDictionarySerializers.cs index 4d56772192..d9520b3525 100644 --- a/csharp/src/Fory/PrimitiveDictionarySerializers.cs +++ b/csharp/src/Fory/PrimitiveDictionarySerializers.cs @@ -668,7 +668,7 @@ internal static class PrimitiveDictionaryCodecReader private const int ReferenceBytes = 4; [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int ElementBytes() => typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + private static int ElementBytes() => ElementStorage.Bytes; [MethodImpl(MethodImplOptions.AggressiveInlining)] private static void ReserveMapStorage(ReadContext context, int count) @@ -676,6 +676,11 @@ private static void ReserveMapStorage(ReadContext context, int cou context.ReserveGraphMemory(MapBytes + count * ((long)ElementBytes() + ElementBytes())); } + private static class ElementStorage + { + internal static readonly int Bytes = typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; + } + public static TMap ReadMap(ReadContext context) where TKey : notnull where TKeyCodec : struct, IPrimitiveDictionaryCodec diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index 9cece92aa6..f82d329632 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -22,6 +22,7 @@ namespace Apache.Fory; public sealed class ReadContext { private const int MinRemoteTypeMetaLimit = 8192; + private const uint NoReservedRefId = uint.MaxValue; private readonly ReusableArray _typeMetaRefs = new(); private readonly UInt64Map _typeMetasByHeader = new(); @@ -39,6 +40,7 @@ public sealed class ReadContext internal Type? _cachedTypeMetaType; internal TypeMeta? _cachedTypeMeta; internal int _currentDynamicReadDepth; + private bool _hasReservedRefId; private readonly Dictionary _remoteSchemaVersionsByType = []; private readonly Config _config; private int _totalAcceptedSchemaVersions; @@ -72,6 +74,12 @@ public ReadContext( public bool CheckStructVersion { get; } + public bool ShouldStoreRef + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => _hasReservedRefId; + } + internal RefReader RefReader { get; } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -99,34 +107,57 @@ internal void InitGraphBudget() [MethodImpl(MethodImplOptions.AggressiveInlining)] public void ReserveGraphMemory(long bytes) { - if (bytes < 0) - { - ThrowGraphBudgetOverflow(); - } - if (_graphMemoryLimitBytes <= 0) + long remaining = _remainingGraphMemoryBytes; + if ((ulong)bytes > (ulong)remaining) { + ReserveGraphMemorySlow(bytes, remaining); return; } + + _remainingGraphMemoryBytes = remaining - bytes; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void ReserveGraphMemory(int bytes) + { long remaining = _remainingGraphMemoryBytes; - if (bytes > remaining) + if (bytes < 0 || bytes > remaining) { - ThrowGraphBudgetExceeded(bytes, remaining, _graphMemoryLimitBytes); + ReserveGraphMemorySlow(bytes, remaining); + return; } _remainingGraphMemoryBytes = remaining - bytes; } - [MethodImpl(MethodImplOptions.NoInlining)] - private static void ThrowGraphBudgetOverflow() + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void ReserveGraphMemory(uint bytes) { - throw new InvalidDataException("graph memory estimate overflows"); + long remaining = _remainingGraphMemoryBytes - bytes; + if (remaining < 0) + { + ReserveGraphMemorySlow(bytes, _remainingGraphMemoryBytes); + return; + } + + _remainingGraphMemoryBytes = remaining; } [MethodImpl(MethodImplOptions.NoInlining)] - private static void ThrowGraphBudgetExceeded(long bytes, long remaining, long limit) + private void ReserveGraphMemorySlow(long bytes, long remaining) { + if (bytes < 0) + { + throw new InvalidDataException("graph memory estimate overflows"); + } + + if (_graphMemoryLimitBytes <= 0) + { + return; + } + throw new InvalidDataException( - $"estimated graph memory request {bytes} bytes exceeds MaxGraphMemoryBytes remaining budget {remaining} bytes out of effective limit {limit} bytes"); + $"estimated graph memory request {bytes} bytes exceeds MaxGraphMemoryBytes remaining budget {remaining} bytes out of effective limit {_graphMemoryLimitBytes} bytes"); } internal void ResetFor(ByteReader reader) @@ -463,27 +494,83 @@ internal void ClearReadTypeInfo(Type type) _readTypeInfoByType.Remove(TypeMapKey.Get(type)); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] public void StoreRef(object? value) { - if (_reservedRefIds.Count == 0) + if (!_hasReservedRefId) { return; } - RefReader.StoreRefAt(_reservedRefIds[^1], value); + int index = _reservedRefIds.Count - 1; + if (index < 0) + { + _hasReservedRefId = false; + return; + } + + RefReader.StoreRefAt(_reservedRefIds[index], value); + _hasReservedRefId = false; } internal void SetReservedRefId(uint refId) { _reservedRefIds.Add(refId); + _hasReservedRefId = true; + } + + /// + /// Hides the current publishable ref id while a serializer reads a child or temporary owner. + /// + /// + /// The reserved slot stays on the stack so the outer owner can publish it after materialization. + /// This prevents immutable wrappers and conversion serializers from letting children consume the + /// parent ref id before the parent object exists. + /// + public uint PauseRefPublication() + { + if (!_hasReservedRefId) + { + return NoReservedRefId; + } + + int index = _reservedRefIds.Count - 1; + if (index < 0) + { + _hasReservedRefId = false; + return NoReservedRefId; + } + + _hasReservedRefId = false; + return _reservedRefIds[index]; + } + + /// Restores a ref id hidden by . + public void ResumeRefPublication(uint refId) + { + if (refId == NoReservedRefId) + { + return; + } + + int index = _reservedRefIds.Count - 1; + if (index < 0) + { + throw new RefException($"cannot resume ref publication for ref id {refId}"); + } + + _hasReservedRefId = true; } internal void ClearReservedRefId() { - if (_reservedRefIds.Count > 0) + int count = _reservedRefIds.Count; + if (count > 0) { - _reservedRefIds.RemoveAt(_reservedRefIds.Count - 1); + _reservedRefIds.RemoveAt(count - 1); } + + _hasReservedRefId = false; } internal void IncreaseReadDepth() @@ -512,6 +599,7 @@ internal void Reset() _typeMetaByType?.ClearKeys(); _readTypeInfoByType.ClearKeys(); _reservedRefIds.Clear(); + _hasReservedRefId = false; _cachedTypeMetaType = null; _cachedTypeMeta = null; _currentDynamicReadDepth = 0; diff --git a/csharp/src/Fory/RefResolver.cs b/csharp/src/Fory/RefResolver.cs index d4e6dbf574..98ccb81fff 100644 --- a/csharp/src/Fory/RefResolver.cs +++ b/csharp/src/Fory/RefResolver.cs @@ -60,6 +60,11 @@ public sealed class RefReader { private readonly List _refs = []; + internal bool HasRefs + { + get => _refs.Count != 0; + } + public RefFlag ReadRefFlag(ByteReader reader) { return (RefFlag)reader.ReadInt8(); @@ -107,11 +112,22 @@ public T GetRef(uint refId) throw new RefException($"ref_id out of range: {refId}"); } - return _refs[index]; + object? value = _refs[index]; + if (value is null) + { + throw new RefException($"ref_id {refId} has not been published"); + } + + return value; } public void Reset() { + if (_refs.Count == 0) + { + return; + } + _refs.Clear(); } } diff --git a/csharp/src/Fory/Serializer.cs b/csharp/src/Fory/Serializer.cs index d2bc727c20..33e874208e 100644 --- a/csharp/src/Fory/Serializer.cs +++ b/csharp/src/Fory/Serializer.cs @@ -114,21 +114,15 @@ public virtual T Read(ReadContext context, RefMode refMode, bool readTypeInfo) { uint reservedRefId = context.RefReader.ReserveRefId(); context.SetReservedRefId(reservedRefId); - try + if (readTypeInfo) { - if (readTypeInfo) - { - context.TypeResolver.ReadTypeInfo(this, context); - } - - T value = ReadData(context); - context.StoreRef(value); - return value; - } - finally - { - context.ClearReservedRefId(); + context.TypeResolver.ReadTypeInfo(this, context); } + + T value = ReadData(context); + context.StoreRef(value); + context.ClearReservedRefId(); + return value; } case RefFlag.NotNullValue: break; diff --git a/csharp/src/Fory/TypeResolver.cs b/csharp/src/Fory/TypeResolver.cs index a5cb888c4b..9bb5c7ff37 100644 --- a/csharp/src/Fory/TypeResolver.cs +++ b/csharp/src/Fory/TypeResolver.cs @@ -982,10 +982,7 @@ TypeId.CompatibleStruct or return actualWireType == expected; } - private object? ReadRegisteredValue( - TypeInfo typeInfo, - ReadContext context, - TypeMeta? typeMeta) + private object? ReadRegisteredValue(TypeInfo typeInfo, ReadContext context, TypeMeta? typeMeta) { if (typeMeta is not null) { diff --git a/csharp/src/Fory/UnionSerializer.cs b/csharp/src/Fory/UnionSerializer.cs index 2cd549068c..47197c9271 100644 --- a/csharp/src/Fory/UnionSerializer.cs +++ b/csharp/src/Fory/UnionSerializer.cs @@ -50,6 +50,7 @@ public override void WriteData(WriteContext context, in TUnion value, bool hasGe public override TUnion ReadData(ReadContext context) { + uint refId = context.PauseRefPublication(); uint rawCaseId = context.Reader.ReadVarUInt32(); if (rawCaseId > int.MaxValue) { @@ -67,7 +68,9 @@ public override TUnion ReadData(ReadContext context) caseValue = DynamicAnyCodec.ReadAny(context, RefMode.Tracking, true); } - return Factory(caseId, caseValue); + TUnion value = Factory(caseId, caseValue); + context.ResumeRefPublication(refId); + return value; } private static void CheckWireCaseId(int caseId) diff --git a/csharp/src/Fory/UnknownCaseSerializer.cs b/csharp/src/Fory/UnknownCaseSerializer.cs index f3866247c4..bd4c3c9597 100644 --- a/csharp/src/Fory/UnknownCaseSerializer.cs +++ b/csharp/src/Fory/UnknownCaseSerializer.cs @@ -51,16 +51,10 @@ public static UnknownCase ReadPayload(ReadContext context, int caseId) { uint reservedRefId = context.RefReader.ReserveRefId(); context.SetReservedRefId(reservedRefId); - try - { - (uint typeId, object? value) = ReadNonNullPayload(context); - context.StoreRef(value); - return UnknownCase.FromRuntime(caseId, typeId, value); - } - finally - { - context.ClearReservedRefId(); - } + (uint typeId, object? value) = ReadNonNullPayload(context); + context.StoreRef(value); + context.ClearReservedRefId(); + return UnknownCase.FromRuntime(caseId, typeId, value); } case RefFlag.NotNullValue: { diff --git a/csharp/tests/Fory.Tests/ForyRuntimeTests.cs b/csharp/tests/Fory.Tests/ForyRuntimeTests.cs index 6d93003948..650eccd3c4 100644 --- a/csharp/tests/Fory.Tests/ForyRuntimeTests.cs +++ b/csharp/tests/Fory.Tests/ForyRuntimeTests.cs @@ -62,6 +62,12 @@ public sealed class Node public Node? Next { get; set; } } +[ForyStruct] +public sealed class AnyNode +{ + public object? Next { get; set; } +} + [ForyStruct] public sealed class FieldOrder { @@ -606,6 +612,148 @@ public void ForyReusedContextsHandleSequentialCalls() } } + [Fact] + public void WireTrackingRefsWorkWithNoRefReader() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + writer.Register(953); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + reader.Register(953); + + Node source = new() { Value = 7 }; + source.Next = source; + + Node decoded = reader.Deserialize(writer.Serialize(source)); + + Assert.Equal(7, decoded.Value); + Assert.NotNull(decoded.Next); + Assert.Same(decoded, decoded.Next); + } + + [Fact] + public void WireTrackingRefsResetAfterNoRefRoot() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + writer.Register(954); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + reader.Register(954); + + Node source = new() { Value = 9 }; + source.Next = source; + _ = reader.Deserialize(writer.Serialize(source)); + + ByteWriter staleRefPayload = new(); + staleRefPayload.WriteUInt8(ForyHeaderFlag.IsXlang); + staleRefPayload.WriteInt8((sbyte)RefFlag.Ref); + staleRefPayload.WriteVarUInt32(0); + + Assert.Throws(() => reader.Deserialize(staleRefPayload.ToArray())); + } + + [Fact] + public void DynamicAnyPublishesTrackedRef() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + writer.Register(955); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + reader.Register(955); + + Node source = new() { Value = 11 }; + source.Next = source; + + object decodedObject = reader.Deserialize(writer.Serialize(source)); + Node decoded = Assert.IsType(decodedObject); + + Assert.Equal(11, decoded.Value); + Assert.NotNull(decoded.Next); + Assert.Same(decoded, decoded.Next); + } + + [Fact] + public void DynamicContainersPublishTrackedRefs() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + + List list = []; + list.Add(list); + object decodedListObject = reader.Deserialize(writer.Serialize(list)); + List decodedList = Assert.IsType>(decodedListObject); + Assert.Same(decodedList, decodedList[0]); + + HashSet set = []; + set.Add(set); + object decodedSetObject = reader.Deserialize(writer.Serialize(set)); + HashSet decodedSet = Assert.IsType>(decodedSetObject); + Assert.Contains(decodedSet, decodedSet); + + Dictionary map = []; + map["self"] = map; + object decodedMapObject = reader.Deserialize(writer.Serialize(map)); + NullableKeyDictionary decodedMap = + Assert.IsType>(decodedMapObject); + Assert.True(decodedMap.TryGetValue("self", out object? self)); + Assert.Same(decodedMap, self); + } + + [Fact] + public void ArraysAndMapsPublishRefs() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + + object?[] array = new object?[1]; + array[0] = array; + object?[] decodedArray = reader.Deserialize(writer.Serialize(array)); + Assert.Same(decodedArray, decodedArray[0]); + + Dictionary map = []; + map["self"] = map; + Dictionary decodedMap = + reader.Deserialize>(writer.Serialize(map)); + Assert.Same(decodedMap, decodedMap["self"]); + } + + [Fact] + public void MutableCollectionsPublishRefs() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + + LinkedList linkedList = []; + linkedList.AddLast(linkedList); + LinkedList decodedList = + reader.Deserialize>(writer.Serialize(linkedList)); + Assert.Same(decodedList, decodedList.First!.Value); + + Queue queue = []; + queue.Enqueue(queue); + Queue decodedQueue = reader.Deserialize>(writer.Serialize(queue)); + Assert.Same(decodedQueue, decodedQueue.Peek()); + + Stack stack = []; + stack.Push(stack); + Stack decodedStack = reader.Deserialize>(writer.Serialize(stack)); + Assert.Same(decodedStack, decodedStack.Peek()); + } + + [Fact] + public void UnionCycleRefFailsLoudly() + { + ForyRuntime writer = ForyRuntime.Builder().TrackRef(true).Build(); + writer.Register(956); + writer.Register(957); + ForyRuntime reader = ForyRuntime.Builder().TrackRef(false).Build(); + reader.Register(956); + reader.Register(957); + + AnyNode node = new(); + Union union = new(0, node); + node.Next = union; + + Assert.Throws(() => reader.Deserialize(writer.Serialize(union))); + } + [Fact] public void ThreadSafeForySupportsParallelPrimitiveRoundTrip() { @@ -2427,7 +2575,8 @@ public void GeneratedSerializerSupportsObjectKeyMap() }; DynamicAnyHolder decoded = fory.Deserialize(fory.Serialize(source)); - Dictionary dynamicMap = Assert.IsType>(decoded.AnyValue); + NullableKeyDictionary dynamicMap = + Assert.IsType>(decoded.AnyValue); Assert.Equal(9, dynamicMap["inner"]); Assert.Equal("ten", dynamicMap[10]); Assert.Equal(source.AnySet.Count, decoded.AnySet.Count); diff --git a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs index 1d85e9079f..26e5f12c9b 100644 --- a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs +++ b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs @@ -76,6 +76,20 @@ public sealed class GeneratedSchemaMapBudget public Dictionary Values { get; set; } = []; } +[ForyStruct] +public sealed class CompatibleBudgetList +{ + [ForyField(Type = typeof(S.List))] + public List Values { get; set; } = []; +} + +[ForyStruct] +public sealed class CompatibleBudgetArray +{ + [ForyField(Type = typeof(S.Array))] + public int[] Values { get; set; } = []; +} + public sealed class GraphMemoryBudgetTests { private const int ReferenceBytes = 4; @@ -142,6 +156,7 @@ public void DefaultFixedBudgetAndDisable() ReadContext disabled = new(new ByteReader([]), new TypeResolver(), NewFory(0).Config); disabled.InitGraphBudget(); disabled.ReserveGraphMemory(long.MaxValue); + Assert.Throws(() => disabled.ReserveGraphMemory(-1)); } [Fact] @@ -306,6 +321,23 @@ public void DenseStringBinaryAndPrimitiveArraysAreSkipped() Assert.Equal(new[] { 1, 2, 3 }, NewFory(1).Deserialize(Serialize(new[] { 1, 2, 3 }))); } + [Fact] + public void CompatibleListToDenseArrayIsSkipped() + { + ForyRuntime writer = ForyRuntime.Builder().Compatible(true).TrackRef(false).Build(); + writer.Register(1010); + byte[] bytes = writer.Serialize(new CompatibleBudgetList { Values = [1, 2, 3] }); + + ForyRuntime reader = ForyRuntime.Builder() + .Compatible(true) + .TrackRef(false) + .MaxGraphMemoryBytes(GeneratedGraphHolderBytes) + .Build(); + reader.Register(1010); + + Assert.Equal(new[] { 1, 2, 3 }, reader.Deserialize(bytes).Values); + } + [Fact] public void ByteAvailabilityCheckStillRejectsLargeLength() { diff --git a/docs/guide/csharp/basic-serialization.md b/docs/guide/csharp/basic-serialization.md index d112990ccd..5c88675bc8 100644 --- a/docs/guide/csharp/basic-serialization.md +++ b/docs/guide/csharp/basic-serialization.md @@ -104,6 +104,11 @@ byte[] payload = fory.Serialize(value); object? decoded = fory.Deserialize(payload); ``` +Dynamic maps normally decode as `Dictionary` when they have no +null key. If the payload uses reference tracking for the dynamic map itself, C# +returns `NullableKeyDictionary` so nested references and null +keys point to the decoded map owner. + ## Buffer Writer API Serialize directly into `IBufferWriter` targets. diff --git a/docs/guide/csharp/references.md b/docs/guide/csharp/references.md index 772b2535ca..6b6b2363c0 100644 --- a/docs/guide/csharp/references.md +++ b/docs/guide/csharp/references.md @@ -65,6 +65,10 @@ System.Diagnostics.Debug.Assert(object.ReferenceEquals(decoded, decoded.Next)); `TrackRef(false)` can be faster for tree-like, acyclic data where reference identity does not matter. +C# union wrappers are immutable and are created after their case payload is read. +Reference cycles from a union case payload back to the containing union are not +supported; Fory rejects unresolved refs instead of returning a partial union. + ## Related Topics - [Configuration](configuration.md) From 81d1ebc1e02ad94966e42c77064c591fd7b97024 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Thu, 2 Jul 2026 12:20:27 +0800 Subject: [PATCH 22/54] fix: address graph budget CI breaks --- cpp/fory/serialization/fory.h | 3 +- .../CompatibleDifferentSchemaExample.java | 2 +- javascript/package-lock.json | 277 +++++++++++++++--- javascript/package.json | 11 +- 4 files changed, 238 insertions(+), 55 deletions(-) diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index a0fb717ad2..d6e8cfc0e2 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -889,13 +889,12 @@ class Fory : public BaseFory { read_ctx_->attach(buffer); if constexpr (needs_graph_budget_v) { constexpr size_t root_owner_bytes = graph_value_owner_self_bytes(); - constexpr bool has_child_budget = has_graph_budget_children_v; if constexpr (root_owner_bytes != 0) { if (FORY_PREDICT_FALSE( !read_ctx_->template init_graph_budget())) { return Unexpected(read_ctx_->take_error()); } - } else if constexpr (has_child_budget) { + } else if constexpr (has_graph_budget_children_v) { if (FORY_PREDICT_FALSE(!read_ctx_->template init_graph_budget<>())) { return Unexpected(read_ctx_->take_error()); } diff --git a/integration_tests/graalvm_tests/src/main/java/org/apache/fory/graalvm/CompatibleDifferentSchemaExample.java b/integration_tests/graalvm_tests/src/main/java/org/apache/fory/graalvm/CompatibleDifferentSchemaExample.java index c70d016deb..41ddae5cda 100644 --- a/integration_tests/graalvm_tests/src/main/java/org/apache/fory/graalvm/CompatibleDifferentSchemaExample.java +++ b/integration_tests/graalvm_tests/src/main/java/org/apache/fory/graalvm/CompatibleDifferentSchemaExample.java @@ -89,7 +89,7 @@ private static Serializer readSerializerForTarget( MemoryBuffer buffer = MemoryUtils.wrap(bytes); buffer.readByte(); ReadContext readContext = fory.getReadContext(); - readContext.prepare(buffer, null, false, buffer.remaining(), false); + readContext.prepare(buffer, null, false); try { readContext.getRefReader().tryPreserveRefId(buffer); TypeInfo typeInfo = fory.getTypeResolver().readTypeInfo(readContext, targetClass); diff --git a/javascript/package-lock.json b/javascript/package-lock.json index 1674116f60..c095ab19a2 100644 --- a/javascript/package-lock.json +++ b/javascript/package-lock.json @@ -9,14 +9,13 @@ "packages/core" ], "devDependencies": { + "@stylistic/eslint-plugin": "^1.5.1", "@types/js-beautify": "^1.14.3", - "@types/node": "18.19.130", + "@types/node": "^18.19.68", "eslint": "^8.55.0", - "eslint-config-prettier": "^10.1.8", "jest": "^29.5.0", "jest-junit": "^17.0.0", "js-beautify": "^1.14.11", - "prettier": "^3.9.4", "ts-jest": "^29.0.2", "typescript": "^4.8.4" } @@ -1450,6 +1449,97 @@ "@sinonjs/commons": "^3.0.0" } }, + "node_modules/@stylistic/eslint-plugin": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/@stylistic/eslint-plugin/-/eslint-plugin-1.8.1.tgz", + "integrity": "sha512-64My6I7uCcmSQ//427Pfg2vjSf9SDzfsGIWohNFgISMLYdC5BzJqDo647iDDJzSxINh3WTC0Ql46ifiKuOoTyA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@stylistic/eslint-plugin-js": "1.8.1", + "@stylistic/eslint-plugin-jsx": "1.8.1", + "@stylistic/eslint-plugin-plus": "1.8.1", + "@stylistic/eslint-plugin-ts": "1.8.1", + "@types/eslint": "^8.56.10" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "peerDependencies": { + "eslint": ">=8.40.0" + } + }, + "node_modules/@stylistic/eslint-plugin-js": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/@stylistic/eslint-plugin-js/-/eslint-plugin-js-1.8.1.tgz", + "integrity": "sha512-c5c2C8Mos5tTQd+NWpqwEu7VT6SSRooAguFPMj1cp2RkTYl1ynKoXo8MWy3k4rkbzoeYHrqC2UlUzsroAN7wtQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/eslint": "^8.56.10", + "acorn": "^8.11.3", + "escape-string-regexp": "^4.0.0", + "eslint-visitor-keys": "^3.4.3", + "espree": "^9.6.1" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "peerDependencies": { + "eslint": ">=8.40.0" + } + }, + "node_modules/@stylistic/eslint-plugin-jsx": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/@stylistic/eslint-plugin-jsx/-/eslint-plugin-jsx-1.8.1.tgz", + "integrity": "sha512-k1Eb6rcjMP+mmjvj+vd9y5KUdWn1OBkkPLHXhsrHt5lCDFZxJEs0aVQzE5lpYrtVZVkpc5esTtss/cPJux0lfA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@stylistic/eslint-plugin-js": "^1.8.1", + "@types/eslint": "^8.56.10", + "estraverse": "^5.3.0", + "picomatch": "^4.0.2" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "peerDependencies": { + "eslint": ">=8.40.0" + } + }, + "node_modules/@stylistic/eslint-plugin-plus": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/@stylistic/eslint-plugin-plus/-/eslint-plugin-plus-1.8.1.tgz", + "integrity": "sha512-4+40H3lHYTN8OWz+US8CamVkO+2hxNLp9+CAjorI7top/lHqemhpJvKA1LD9Uh+WMY9DYWiWpL2+SZ2wAXY9fQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/eslint": "^8.56.10", + "@typescript-eslint/utils": "^6.21.0" + }, + "peerDependencies": { + "eslint": "*" + } + }, + "node_modules/@stylistic/eslint-plugin-ts": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/@stylistic/eslint-plugin-ts/-/eslint-plugin-ts-1.8.1.tgz", + "integrity": "sha512-/q1m+ZuO1JHfiSF16EATFzv7XSJkc5W6DocfvH5o9oB6WWYFMF77fVoBWnKT3wGptPOc2hkRupRKhmeFROdfWA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@stylistic/eslint-plugin-js": "1.8.1", + "@types/eslint": "^8.56.10", + "@typescript-eslint/utils": "^6.21.0" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "peerDependencies": { + "eslint": ">=8.40.0" + } + }, "node_modules/@types/babel__core": { "version": "7.20.5", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", @@ -1495,6 +1585,24 @@ "@babel/types": "^7.28.2" } }, + "node_modules/@types/eslint": { + "version": "8.56.12", + "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.56.12.tgz", + "integrity": "sha512-03ruubjWyOHlmljCVoxSuNDdmfZDzsrrz0P2LeJsOXr+ZwFQ+0yQIwNCwt/GYhV7Z31fgtXJTAEs+FYlEL851g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "*", + "@types/json-schema": "*" + } + }, + "node_modules/@types/estree": { + "version": "1.0.9", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.9.tgz", + "integrity": "sha512-GhdPgy1el4/ImP05X05Uw4cw2/M93BCUmnEvWZNStlCzEKME4Fkk+YpoA5OiHNQmoS7Cafb8Xa3Pya8m1Qrzeg==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/graceful-fs": { "version": "4.1.9", "resolved": "https://registry.npmjs.org/@types/graceful-fs/-/graceful-fs-4.1.9.tgz", @@ -1547,9 +1655,9 @@ "license": "MIT" }, "node_modules/@types/node": { - "version": "18.19.130", - "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.130.tgz", - "integrity": "sha512-GRaXQx6jGfL8sKfaIDD6OupbIHBr9jv7Jnaml9tB7l4v068PAOXqfcujMMo5PhbIs6ggR1XODELqahT2R8v0fg==", + "version": "18.19.68", + "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.68.tgz", + "integrity": "sha512-QGtpFH1vB99ZmTa63K4/FU8twThj4fuVSBkGddTp7uIL/cuoLWIUSL2RcOaigBhfR+hg5pgGkBnkoOxrTVBMKw==", "dev": true, "license": "MIT", "dependencies": { @@ -1857,6 +1965,24 @@ "url": "https://opencollective.com/typescript-eslint" } }, + "node_modules/@typescript-eslint/scope-manager": { + "version": "6.21.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-6.21.0.tgz", + "integrity": "sha512-OwLUIWZJry80O99zvqXVEioyniJMa+d2GrqpUTqi5/v5D5rOrppJVBPa0yKCblcigC0/aYAzxxqQ1B+DS2RYsg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "6.21.0", + "@typescript-eslint/visitor-keys": "6.21.0" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, "node_modules/@typescript-eslint/type-utils": { "version": "5.62.0", "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-5.62.0.tgz", @@ -2014,6 +2140,93 @@ "node": ">=4.0" } }, + "node_modules/@typescript-eslint/types": { + "version": "6.21.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-6.21.0.tgz", + "integrity": "sha512-1kFmZ1rOm5epu9NZEZm1kckCDGj5UJEf7P1kliH4LKu/RkwpsfqqGmY2OOcUs18lSlQBKLDYBOGxRVtrMN5lpg==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/@typescript-eslint/typescript-estree": { + "version": "6.21.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-6.21.0.tgz", + "integrity": "sha512-6npJTkZcO+y2/kr+z0hc4HwNfrrP4kNYh57ek7yCNlrBjWQ1Y0OS7jiZTkgumrvkX5HkEKXFZkkdFNkaW2wmUQ==", + "dev": true, + "license": "BSD-2-Clause", + "dependencies": { + "@typescript-eslint/types": "6.21.0", + "@typescript-eslint/visitor-keys": "6.21.0", + "debug": "^4.3.4", + "globby": "^11.1.0", + "is-glob": "^4.0.3", + "minimatch": "9.0.3", + "semver": "^7.5.4", + "ts-api-utils": "^1.0.1" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependenciesMeta": { + "typescript": { + "optional": true + } + } + }, + "node_modules/@typescript-eslint/utils": { + "version": "6.21.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-6.21.0.tgz", + "integrity": "sha512-NfWVaC8HP9T8cbKQxHcsJBY5YE1O33+jpMwN45qzWWaPDZgLIbo12toGMWnmhvCpd3sIxkpDw3Wv1B3dYrbDQQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.4.0", + "@types/json-schema": "^7.0.12", + "@types/semver": "^7.5.0", + "@typescript-eslint/scope-manager": "6.21.0", + "@typescript-eslint/types": "6.21.0", + "@typescript-eslint/typescript-estree": "6.21.0", + "semver": "^7.5.4" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^7.0.0 || ^8.0.0" + } + }, + "node_modules/@typescript-eslint/visitor-keys": { + "version": "6.21.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-6.21.0.tgz", + "integrity": "sha512-JJtkDduxLi9bivAB+cYOVMtbkqdPOhZ+ZI5LC47MIRrDV4Yn2o+ZnW10Nkmr28xRpSpdJ6Sm42Hjf2+REYXm0A==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "6.21.0", + "eslint-visitor-keys": "^3.4.1" + }, + "engines": { + "node": "^16.0.0 || >=18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, "node_modules/@ungap/structured-clone": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/@ungap/structured-clone/-/structured-clone-1.3.1.tgz", @@ -2968,22 +3181,6 @@ "url": "https://opencollective.com/eslint" } }, - "node_modules/eslint-config-prettier": { - "version": "10.1.8", - "resolved": "https://registry.npmjs.org/eslint-config-prettier/-/eslint-config-prettier-10.1.8.tgz", - "integrity": "sha512-82GZUjRS0p/jganf6q1rEO25VSoHH0hKPCTrgillPjdI/3bgBhAE1QzHrHTizjpRvy6pGAvKjDJtk2pF9NDq8w==", - "dev": true, - "license": "MIT", - "bin": { - "eslint-config-prettier": "bin/cli.js" - }, - "funding": { - "url": "https://opencollective.com/eslint-config-prettier" - }, - "peerDependencies": { - "eslint": ">=7.0.0" - } - }, "node_modules/eslint-scope": { "version": "7.2.2", "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-7.2.2.tgz", @@ -5591,22 +5788,6 @@ "node": ">= 0.8.0" } }, - "node_modules/prettier": { - "version": "3.9.4", - "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.9.4.tgz", - "integrity": "sha512-yWG/o/4oJfo036EKAfK6ACAoDOfHeRHx4tuxkfBZiauURiaSmYwlpOr5LQqKtIkRD2z1PLteme2WoxEnj4tHTg==", - "dev": true, - "license": "MIT", - "bin": { - "prettier": "bin/prettier.cjs" - }, - "engines": { - "node": ">=14" - }, - "funding": { - "url": "https://github.com/prettier/prettier?sponsor=1" - } - }, "node_modules/pretty-format": { "version": "29.7.0", "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-29.7.0.tgz", @@ -6293,6 +6474,19 @@ "node": ">=8.0" } }, + "node_modules/ts-api-utils": { + "version": "1.4.3", + "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-1.4.3.tgz", + "integrity": "sha512-i3eMG77UTMD0hZhgRS562pv83RC6ukSAC2GMNWc+9dieh/+jDM5u5YG+NHX6VNDRHQcHwmsTHctP9LhbC3WxVw==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=16" + }, + "peerDependencies": { + "typescript": ">=4.2.0" + } + }, "node_modules/ts-jest": { "version": "29.4.11", "resolved": "https://registry.npmjs.org/ts-jest/-/ts-jest-29.4.11.tgz", @@ -6757,13 +6951,6 @@ "undici-types": "~5.26.4" } }, - "packages/core/node_modules/undici-types": { - "version": "5.26.5", - "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", - "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==", - "dev": true, - "license": "MIT" - }, "packages/hps": { "name": "@apache-fory/hps", "version": "1.4.0-alpha.0", diff --git a/javascript/package.json b/javascript/package.json index bf4ae68e47..a856cd827a 100644 --- a/javascript/package.json +++ b/javascript/package.json @@ -4,10 +4,8 @@ "test": "npm run build && jest", "clear": "rm -rf ./packages/core/dist && rm -rf ./packages/hps/dist", "build": "npm run clear && npm run build -w packages/core -w packages/hps", - "lint": "npm run format-check", - "lint-fix": "npm run format", - "format": "prettier --write \"{packages,test}/**/*.ts\" && eslint . --fix", - "format-check": "prettier --check \"{packages,test}/**/*.ts\" && eslint ." + "lint": "eslint .", + "lint-fix": "eslint . --fix" }, "repository": "git@github.com:apache/fory.git", "workspaces": [ @@ -15,14 +13,13 @@ "packages/core" ], "devDependencies": { + "@stylistic/eslint-plugin": "^1.5.1", "@types/js-beautify": "^1.14.3", - "@types/node": "18.19.130", + "@types/node": "^18.19.68", "eslint": "^8.55.0", - "eslint-config-prettier": "^10.1.8", "jest": "^29.5.0", "jest-junit": "^17.0.0", "js-beautify": "^1.14.11", - "prettier": "^3.9.4", "ts-jest": "^29.0.2", "typescript": "^4.8.4" }, From 53e103df18435dba20810411e99620902d3437ed Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Thu, 2 Jul 2026 12:33:35 +0800 Subject: [PATCH 23/54] fix: unblock graph budget ci checks --- cpp/fory/serialization/fory.h | 4 +- cpp/fory/serialization/serializer_traits.h | 6 +- .../src/main/java/org/apache/fory/Fory.java | 100 +++++++++--------- .../org/apache/fory/context/ReadContext.java | 4 +- 4 files changed, 56 insertions(+), 58 deletions(-) diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index d6e8cfc0e2..75f1334685 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -112,8 +112,8 @@ class ForyBuilder { /// Set maximum estimated graph memory for one root deserialization. /// - /// Defaults to 128 MiB. Positive values are explicit byte limits; non-positive - /// values intentionally disable this protection. + /// Defaults to 128 MiB. Positive values are explicit byte limits; + /// non-positive values intentionally disable this protection. ForyBuilder &max_graph_memory_bytes(int64_t max_bytes) { config_.max_graph_memory_bytes = max_bytes; return *this; diff --git a/cpp/fory/serialization/serializer_traits.h b/cpp/fory/serialization/serializer_traits.h index 2ee4ae5723..38ff4db440 100644 --- a/cpp/fory/serialization/serializer_traits.h +++ b/cpp/fory/serialization/serializer_traits.h @@ -394,9 +394,9 @@ struct has_graph_budget_children>> { private: using Value = std::remove_cv_t>; - using FieldInfo = - decltype(::fory::meta::fory_field_info(std::declval())); - using Ptrs = typename FieldInfo::PtrsType; + using Ptrs = + decltype(::fory::meta::fory_field_info(std::declval()) + .ptrs()); public: static constexpr bool value = struct_has_graph_children_impl( diff --git a/java/fory-core/src/main/java/org/apache/fory/Fory.java b/java/fory-core/src/main/java/org/apache/fory/Fory.java index de76df2bd9..1db63df9b6 100644 --- a/java/fory-core/src/main/java/org/apache/fory/Fory.java +++ b/java/fory-core/src/main/java/org/apache/fory/Fory.java @@ -428,30 +428,6 @@ public T deserialize(MemoryBuffer buffer, Class type) { return deserializeRoot(buffer, type); } - private T deserializeRoot(MemoryBuffer buffer, Class type) { - ensureRegistrationFinished(); - byte bitmap = buffer.readByte(); - if (bitmap != headerBitmap) { - checkHeaderBitmapWithoutOutOfBand(bitmap); - } - readContext.prepare(buffer, null, false); - try { - try { - jitContext.lock(); - if (readContext.getDepth() > 0) { - throwDepthDeserializationException(); - } - return deserializeByType(buffer, type); - } finally { - jitContext.unlock(); - } - } catch (Throwable t) { - throw ExceptionUtils.handleReadFailed(this, t); - } finally { - readContext.reset(); - } - } - @Override public T deserialize(ForyInputStream inputStream, Class type) { try { @@ -494,6 +470,56 @@ public Object deserialize(MemoryBuffer buffer, Iterable outOfBandB return deserializeRoot(buffer, outOfBandBuffers); } + @Override + public Object deserialize(ForyInputStream inputStream) { + return deserialize(inputStream, (Iterable) null); + } + + @Override + public Object deserialize(ForyInputStream inputStream, Iterable outOfBandBuffers) { + try { + MemoryBuffer buf = inputStream.getBuffer(); + return deserializeRoot(buf, outOfBandBuffers); + } finally { + inputStream.shrinkBuffer(); + } + } + + @Override + public Object deserialize(ForyReadableChannel channel) { + return deserialize(channel, (Iterable) null); + } + + @Override + public Object deserialize(ForyReadableChannel channel, Iterable outOfBandBuffers) { + MemoryBuffer buf = channel.getBuffer(); + return deserializeRoot(buf, outOfBandBuffers); + } + + private T deserializeRoot(MemoryBuffer buffer, Class type) { + ensureRegistrationFinished(); + byte bitmap = buffer.readByte(); + if (bitmap != headerBitmap) { + checkHeaderBitmapWithoutOutOfBand(bitmap); + } + readContext.prepare(buffer, null, false); + try { + try { + jitContext.lock(); + if (readContext.getDepth() > 0) { + throwDepthDeserializationException(); + } + return deserializeByType(buffer, type); + } finally { + jitContext.unlock(); + } + } catch (Throwable t) { + throw ExceptionUtils.handleReadFailed(this, t); + } finally { + readContext.reset(); + } + } + private Object deserializeRoot(MemoryBuffer buffer, Iterable outOfBandBuffers) { ensureRegistrationFinished(); byte bitmap = buffer.readByte(); @@ -531,32 +557,6 @@ private Object deserializeRoot(MemoryBuffer buffer, Iterable outOf } } - @Override - public Object deserialize(ForyInputStream inputStream) { - return deserialize(inputStream, (Iterable) null); - } - - @Override - public Object deserialize(ForyInputStream inputStream, Iterable outOfBandBuffers) { - try { - MemoryBuffer buf = inputStream.getBuffer(); - return deserializeRoot(buf, outOfBandBuffers); - } finally { - inputStream.shrinkBuffer(); - } - } - - @Override - public Object deserialize(ForyReadableChannel channel) { - return deserialize(channel, (Iterable) null); - } - - @Override - public Object deserialize(ForyReadableChannel channel, Iterable outOfBandBuffers) { - MemoryBuffer buf = channel.getBuffer(); - return deserializeRoot(buf, outOfBandBuffers); - } - @SuppressWarnings("unchecked") private T deserializeByType(MemoryBuffer buffer, Class type) { readContext diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index e28b31fb25..6ca56c1a0b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -112,9 +112,7 @@ public ReadContext( * flag for one operation. */ public void prepare( - MemoryBuffer buffer, - Iterable outOfBandBuffers, - boolean peerOutOfBandEnabled) { + MemoryBuffer buffer, Iterable outOfBandBuffers, boolean peerOutOfBandEnabled) { this.buffer = buffer; this.peerOutOfBandEnabled = peerOutOfBandEnabled; this.outOfBandBuffers = outOfBandBuffers == null ? null : outOfBandBuffers.iterator(); From ad543b2dca9a4d2b3ad54b3b0dda9e1226e51e1e Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Thu, 2 Jul 2026 12:40:51 +0800 Subject: [PATCH 24/54] style(js): satisfy graph budget lint rules --- javascript/packages/core/lib/context.ts | 196 ++++++++++----------- javascript/packages/core/lib/fory.ts | 30 ++-- javascript/packages/core/lib/gen/ext.ts | 26 +-- javascript/packages/core/lib/gen/map.ts | 98 +++++------ javascript/packages/core/lib/gen/struct.ts | 156 ++++++++-------- javascript/packages/core/lib/type.ts | 12 +- 6 files changed, 259 insertions(+), 259 deletions(-) diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index 3fc57aa761..69c21f6a5e 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -551,8 +551,8 @@ export class ReadContext { private readonly maxGraphMemoryBytes: number; private effectiveGraphMemoryBytes = 0; private remainingGraphMemoryBytes = 0; - private remoteSchemaVersionsByType: Map | undefined = - undefined; + private remoteSchemaVersionsByType: Map | undefined + = undefined; constructor( readonly typeResolver: TypeResolverLike, @@ -571,8 +571,8 @@ export class ReadContext { this.metaStringReader.reset(); this.typeMeta = []; this._depth = 0; - this.effectiveGraphMemoryBytes = - this.maxGraphMemoryBytes > 0 ? this.maxGraphMemoryBytes : 0; + this.effectiveGraphMemoryBytes + = this.maxGraphMemoryBytes > 0 ? this.maxGraphMemoryBytes : 0; this.remainingGraphMemoryBytes = this.effectiveGraphMemoryBytes; } @@ -598,9 +598,9 @@ export class ReadContext { private throwGraphBudgetExceeded(bytes: number): never { throw new Error( - `maxGraphMemoryBytes exceeded: requested ${bytes} estimated graph bytes, ` + - `${this.remainingGraphMemoryBytes} remaining, effective limit ` + - `${this.effectiveGraphMemoryBytes}`, + `maxGraphMemoryBytes exceeded: requested ${bytes} estimated graph bytes, ` + + `${this.remainingGraphMemoryBytes} remaining, effective limit ` + + `${this.effectiveGraphMemoryBytes}`, ); } @@ -612,8 +612,8 @@ export class ReadContext { this._depth++; if (this._depth > this._maxDepth) { throw new Error( - `Deserialization depth limit exceeded: ${this._depth} > ${this._maxDepth}. ` + - "The data may be malicious, or increase maxDepth if needed.", + `Deserialization depth limit exceeded: ${this._depth} > ${this._maxDepth}. ` + + "The data may be malicious, or increase maxDepth if needed.", ); } } @@ -800,14 +800,14 @@ export class ReadContext { expectedTypeName: string, ) { if ( - typeMeta.getTypeId() !== expectedTypeId || - typeMeta.getNs() !== expectedNamespace || - typeMeta.getTypeName() !== expectedTypeName + typeMeta.getTypeId() !== expectedTypeId + || typeMeta.getNs() !== expectedNamespace + || typeMeta.getTypeName() !== expectedTypeName ) { throw new Error( - `TypeMeta mismatch: expected ${expectedNamespace}$${expectedTypeName} ` + - `type ${expectedTypeId}, got ${typeMeta.getNs()}$${typeMeta.getTypeName()} ` + - `type ${typeMeta.getTypeId()}`, + `TypeMeta mismatch: expected ${expectedNamespace}$${expectedTypeName} ` + + `type ${expectedTypeId}, got ${typeMeta.getNs()}$${typeMeta.getTypeName()} ` + + `type ${typeMeta.getTypeId()}`, ); } } @@ -855,8 +855,8 @@ export class ReadContext { } else { const localSerializer = original ?? this.serializerByTypeMeta(typeMeta); if ( - localSerializer === undefined && - !TypeId.structType(typeMeta.getTypeId()) + localSerializer === undefined + && !TypeId.structType(typeMeta.getTypeId()) ) { throw new Error( `can't find serializer for TypeMeta ${typeMeta.getNs()}$${typeMeta.getTypeName()}`, @@ -864,8 +864,8 @@ export class ReadContext { } const typeKey = this.checkRemoteTypeMetaLimit(typeMeta); if ( - localSerializer !== undefined && - TypeId.structType(typeMeta.getTypeId()) + localSerializer !== undefined + && TypeId.structType(typeMeta.getTypeId()) ) { const expectedHash = localHash ?? localSerializer.getHash(); if (expectedHash !== typeMeta.getHash()) { @@ -877,8 +877,8 @@ export class ReadContext { ); } } else if ( - localHash !== undefined && - localHash !== typeMeta.getHash() + localHash !== undefined + && localHash !== typeMeta.getHash() ) { this.ensureCompatibleReadSerializer( typeMeta, @@ -918,33 +918,33 @@ export class ReadContext { : typeMeta.getUserTypeId(); const versionsByType = this.remoteSchemaVersionsByType; const versionsForType = versionsByType?.get(typeKey) ?? 0; - const maxSchemaVersionsPerType = - this.typeResolver.config.maxSchemaVersionsPerType; + const maxSchemaVersionsPerType + = this.typeResolver.config.maxSchemaVersionsPerType; if (versionsForType >= maxSchemaVersionsPerType) { throw new Error( - `Remote schema version limit exceeded for type ${String(typeKey)}: ` + - `${versionsForType} >= ${maxSchemaVersionsPerType}. The data may ` + - "be malicious. If the data is not malicious, please increase " + - "maxSchemaVersionsPerType.", + `Remote schema version limit exceeded for type ${String(typeKey)}: ` + + `${versionsForType} >= ${maxSchemaVersionsPerType}. The data may ` + + "be malicious. If the data is not malicious, please increase " + + "maxSchemaVersionsPerType.", ); } - const acceptedTypeCount = - versionsForType === 0 + const acceptedTypeCount + = versionsForType === 0 ? (versionsByType?.size ?? 0) + 1 : versionsByType!.size; - const maxAverageSchemaVersionsPerType = - this.typeResolver.config.maxAverageSchemaVersionsPerType; + const maxAverageSchemaVersionsPerType + = this.typeResolver.config.maxAverageSchemaVersionsPerType; const globalLimit = Math.max( ReadContext.MIN_REMOTE_TYPE_META_LIMIT, acceptedTypeCount * maxAverageSchemaVersionsPerType, ); if (this.totalAcceptedSchemaVersions >= globalLimit) { throw new Error( - `Remote schema version limit exceeded: ${this.totalAcceptedSchemaVersions} ` + - `metadata versions for ${acceptedTypeCount} accepted remote types ` + - `exceeds the average limit ${maxAverageSchemaVersionsPerType}. ` + - "The data may be malicious. If the data is not malicious, please " + - "increase maxAverageSchemaVersionsPerType.", + `Remote schema version limit exceeded: ${this.totalAcceptedSchemaVersions} ` + + `metadata versions for ${acceptedTypeCount} accepted remote types ` + + `exceeds the average limit ${maxAverageSchemaVersionsPerType}. ` + + "The data may be malicious. If the data is not malicious, please " + + "increase maxAverageSchemaVersionsPerType.", ); } return typeKey; @@ -1043,16 +1043,16 @@ export class ReadContext { return false; } if ( - (remote.trackingRef === true) !== (local.trackingRef === true) || - (remote.nullable === true) !== (local.nullable === true) + (remote.trackingRef === true) !== (local.trackingRef === true) + || (remote.nullable === true) !== (local.nullable === true) ) { return false; } switch (remote.typeId) { case TypeId.MAP: return ( - this.fieldSchemasEqual(remote.options?.key, local.options?.key) && - this.fieldSchemasEqual(remote.options?.value, local.options?.value) + this.fieldSchemasEqual(remote.options?.key, local.options?.key) + && this.fieldSchemasEqual(remote.options?.value, local.options?.value) ); case TypeId.LIST: return this.fieldSchemasEqual( @@ -1083,24 +1083,24 @@ export class ReadContext { return compatible; } if ( - isCompatibleScalarType(fieldInfo.typeId) && - isCompatibleScalarType(fallbackTypeInfo.typeId) && - ((fieldInfo.trackingRef === true) !== - (fallbackTypeInfo.trackingRef === true) || - ((fieldInfo.trackingRef === true || - fallbackTypeInfo.trackingRef === true) && - (fieldInfo.typeId !== fallbackTypeInfo.typeId || - fieldInfo.nullable !== fallbackTypeInfo.nullable))) + isCompatibleScalarType(fieldInfo.typeId) + && isCompatibleScalarType(fallbackTypeInfo.typeId) + && ((fieldInfo.trackingRef === true) + !== (fallbackTypeInfo.trackingRef === true) + || ((fieldInfo.trackingRef === true + || fallbackTypeInfo.trackingRef === true) + && (fieldInfo.typeId !== fallbackTypeInfo.typeId + || fieldInfo.nullable !== fallbackTypeInfo.nullable))) ) { throw new Error( "unsupported compatible scalar tracking-ref schema mismatch", ); } if ( - isCompatibleScalarPair(fieldInfo.typeId, fallbackTypeInfo.typeId) && - fieldInfo.typeId !== fallbackTypeInfo.typeId && - (fieldInfo.trackingRef === true || - fallbackTypeInfo.trackingRef === true) + isCompatibleScalarPair(fieldInfo.typeId, fallbackTypeInfo.typeId) + && fieldInfo.typeId !== fallbackTypeInfo.typeId + && (fieldInfo.trackingRef === true + || fallbackTypeInfo.trackingRef === true) ) { throw new Error( "unsupported compatible scalar tracking-ref schema mismatch", @@ -1116,10 +1116,10 @@ export class ReadContext { throw new Error("unsupported compatible list/array schema mismatch"); } if ( - fieldInfo.typeId !== TypeId.UNKNOWN && - this.canonicalFieldTypeId(fallbackTypeInfo) !== TypeId.UNKNOWN && - this.canonicalTypeId(fieldInfo.typeId) !== - this.canonicalFieldTypeId(fallbackTypeInfo) + fieldInfo.typeId !== TypeId.UNKNOWN + && this.canonicalFieldTypeId(fallbackTypeInfo) !== TypeId.UNKNOWN + && this.canonicalTypeId(fieldInfo.typeId) + !== this.canonicalFieldTypeId(fallbackTypeInfo) ) { throw new Error("unsupported compatible field schema mismatch"); } @@ -1209,31 +1209,31 @@ export class ReadContext { return false; } if ( - this.schemaMatchTypeId(remote.typeId) !== - this.schemaMatchTypeId(this.typeResolver.computeTypeId(local)) + this.schemaMatchTypeId(remote.typeId) + !== this.schemaMatchTypeId(this.typeResolver.computeTypeId(local)) ) { return true; } const remoteTracksRef = remote.trackingRef === true; const localTracksRef = local.trackingRef === true; if ( - remoteTracksRef !== localTracksRef || - ((remoteTracksRef || localTracksRef) && - (remote.nullable === true) !== (local.nullable === true)) + remoteTracksRef !== localTracksRef + || ((remoteTracksRef || localTracksRef) + && (remote.nullable === true) !== (local.nullable === true)) ) { return true; } switch (remote.typeId) { case TypeId.MAP: return ( - local.options?.key === undefined || - local.options?.value === undefined || - this.hasNestedSchemaMismatch( + local.options?.key === undefined + || local.options?.value === undefined + || this.hasNestedSchemaMismatch( remote.options!.key!, local.options.key, false, - ) || - this.hasNestedSchemaMismatch( + ) + || this.hasNestedSchemaMismatch( remote.options!.value!, local.options.value, false, @@ -1241,8 +1241,8 @@ export class ReadContext { ); case TypeId.LIST: return ( - local.options?.inner === undefined || - this.hasNestedSchemaMismatch( + local.options?.inner === undefined + || this.hasNestedSchemaMismatch( remote.options!.inner!, local.options.inner, false, @@ -1250,8 +1250,8 @@ export class ReadContext { ); case TypeId.SET: return ( - local.options?.key === undefined || - this.hasNestedSchemaMismatch( + local.options?.key === undefined + || this.hasNestedSchemaMismatch( remote.options!.key!, local.options.key, false, @@ -1272,19 +1272,19 @@ export class ReadContext { ): TypeInfo | undefined { if (this.isByteSequenceRootPair(remote, local)) { if ( - (remote.nullable === true) !== (local.nullable === true) || - (remote.trackingRef === true) !== (local.trackingRef === true) + (remote.nullable === true) !== (local.nullable === true) + || (remote.trackingRef === true) !== (local.trackingRef === true) ) { return undefined; } return local.clone(); } if ( - this.isListArrayRootPair(remote, local) && - (remote.nullable === true || - local.nullable === true || - remote.trackingRef === true || - local.trackingRef === true) + this.isListArrayRootPair(remote, local) + && (remote.nullable === true + || local.nullable === true + || remote.trackingRef === true + || local.trackingRef === true) ) { return undefined; } @@ -1304,22 +1304,22 @@ export class ReadContext { } const remoteArrayElement = denseArrayElementTypeId(remote.typeId); if ( - remoteArrayElement !== undefined && - local.typeId === TypeId.LIST && - local.options?.inner && - compatibleArrayElementTypeId(local.options.inner.typeId) === - remoteArrayElement + remoteArrayElement !== undefined + && local.typeId === TypeId.LIST + && local.options?.inner + && compatibleArrayElementTypeId(local.options.inner.typeId) + === remoteArrayElement ) { return compatibleArrayToListTypeInfo(remoteArrayElement); } if ( - remote.trackingRef !== true && - local.trackingRef !== true && - !( - remote.typeId === local.typeId && - (remote.nullable === true) === (local.nullable === true) - ) && - isCompatibleScalarPair(remote.typeId, local.typeId) + remote.trackingRef !== true + && local.trackingRef !== true + && !( + remote.typeId === local.typeId + && (remote.nullable === true) === (local.nullable === true) + ) + && isCompatibleScalarPair(remote.typeId, local.typeId) ) { return markCompatibleScalarRead(local.clone(), { remoteTypeId: remote.typeId, @@ -1351,8 +1351,8 @@ export class ReadContext { remote.options!.key!, local.options?.key, false, - ) || - this.hasUnsupportedListArrayMismatch( + ) + || this.hasUnsupportedListArrayMismatch( remote.options!.value!, local.options?.value, false, @@ -1380,10 +1380,10 @@ export class ReadContext { local: TypeInfo, ): boolean { return ( - (remote.typeId === TypeId.LIST && - denseArrayElementTypeId(local.typeId) !== undefined) || - (denseArrayElementTypeId(remote.typeId) !== undefined && - local.typeId === TypeId.LIST) + (remote.typeId === TypeId.LIST + && denseArrayElementTypeId(local.typeId) !== undefined) + || (denseArrayElementTypeId(remote.typeId) !== undefined + && local.typeId === TypeId.LIST) ); } @@ -1392,9 +1392,9 @@ export class ReadContext { local: TypeInfo, ): boolean { return ( - (remote.typeId === TypeId.BINARY && - local.typeId === TypeId.UINT8_ARRAY) || - (remote.typeId === TypeId.UINT8_ARRAY && local.typeId === TypeId.BINARY) + (remote.typeId === TypeId.BINARY + && local.typeId === TypeId.UINT8_ARRAY) + || (remote.typeId === TypeId.UINT8_ARRAY && local.typeId === TypeId.BINARY) ); } diff --git a/javascript/packages/core/lib/fory.ts b/javascript/packages/core/lib/fory.ts index f17f5ab7c5..a7d3944a45 100644 --- a/javascript/packages/core/lib/fory.ts +++ b/javascript/packages/core/lib/fory.ts @@ -78,36 +78,36 @@ export default class Fory { `maxTypeFields must be a positive integer but got ${maxTypeFields}`, ); } - const maxTypeMetaBytes = - config?.maxTypeMetaBytes ?? DEFAULT_MAX_TYPE_META_BYTES; + const maxTypeMetaBytes + = config?.maxTypeMetaBytes ?? DEFAULT_MAX_TYPE_META_BYTES; if (!Number.isInteger(maxTypeMetaBytes) || maxTypeMetaBytes <= 0) { throw new Error( `maxTypeMetaBytes must be a positive integer but got ${maxTypeMetaBytes}`, ); } - const maxSchemaVersionsPerType = - config?.maxSchemaVersionsPerType ?? DEFAULT_MAX_SCHEMA_VERSIONS_PER_TYPE; + const maxSchemaVersionsPerType + = config?.maxSchemaVersionsPerType ?? DEFAULT_MAX_SCHEMA_VERSIONS_PER_TYPE; if ( - !Number.isInteger(maxSchemaVersionsPerType) || - maxSchemaVersionsPerType <= 0 + !Number.isInteger(maxSchemaVersionsPerType) + || maxSchemaVersionsPerType <= 0 ) { throw new Error( `maxSchemaVersionsPerType must be a positive integer but got ${maxSchemaVersionsPerType}`, ); } - const maxAverageSchemaVersionsPerType = - config?.maxAverageSchemaVersionsPerType ?? - DEFAULT_MAX_AVERAGE_SCHEMA_VERSIONS_PER_TYPE; + const maxAverageSchemaVersionsPerType + = config?.maxAverageSchemaVersionsPerType + ?? DEFAULT_MAX_AVERAGE_SCHEMA_VERSIONS_PER_TYPE; if ( - !Number.isInteger(maxAverageSchemaVersionsPerType) || - maxAverageSchemaVersionsPerType <= 0 + !Number.isInteger(maxAverageSchemaVersionsPerType) + || maxAverageSchemaVersionsPerType <= 0 ) { throw new Error( `maxAverageSchemaVersionsPerType must be a positive integer but got ${maxAverageSchemaVersionsPerType}`, ); } - const maxGraphMemoryBytes = - config?.maxGraphMemoryBytes ?? DEFAULT_MAX_GRAPH_MEMORY_BYTES; + const maxGraphMemoryBytes + = config?.maxGraphMemoryBytes ?? DEFAULT_MAX_GRAPH_MEMORY_BYTES; if (!Number.isSafeInteger(maxGraphMemoryBytes)) { throw new Error( `maxGraphMemoryBytes must be a safe integer but got ${maxGraphMemoryBytes}`, @@ -191,8 +191,8 @@ export default class Fory { } private throwInvalidRootHeader(bitmap: number): never { - const knownFlags = - ConfigFlags.isCrossLanguageFlag | ConfigFlags.isOutOfBandFlag; + const knownFlags + = ConfigFlags.isCrossLanguageFlag | ConfigFlags.isOutOfBandFlag; if ((bitmap & ~knownFlags) !== 0) { throw new Error( `unsupported root header bitmap 0x${bitmap.toString(16)}`, diff --git a/javascript/packages/core/lib/gen/ext.ts b/javascript/packages/core/lib/gen/ext.ts index 928e848275..11ec957d94 100644 --- a/javascript/packages/core/lib/gen/ext.ts +++ b/javascript/packages/core/lib/gen/ext.ts @@ -46,8 +46,8 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { private objectGraphBytes(): number { return ( - OBJECT_BYTES + - Object.keys(this.typeInfo.options?.props ?? {}).length * REFERENCE_BYTES + OBJECT_BYTES + + Object.keys(this.typeInfo.options?.props ?? {}).length * REFERENCE_BYTES ); } @@ -82,7 +82,7 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { ${this.readTypeInfo()} ${this.builder.getReadContextName()}.incReadDepth(); let ${result}; - ${this.read((v) => `${result} = ${v}`, refState)}; + ${this.read(v => `${result} = ${v}`, refState)}; ${this.builder.getReadContextName()}.decReadDepth(); ${assignStmt(result)}; `; @@ -132,12 +132,12 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { "ext_ser", TypeId.isNamedType(this.typeInfo.typeId) ? this.builder.typeResolver.getSerializerByName( - CodecBuilder.replaceBackslashAndQuote(this.typeInfo.named!), - ) + CodecBuilder.replaceBackslashAndQuote(this.typeInfo.named!), + ) : this.builder.typeResolver.getSerializerById( - this.typeInfo.typeId, - this.typeInfo.userTypeId, - ), + this.typeInfo.typeId, + this.typeInfo.userTypeId, + ), ); return accessor(`${name}.${prop}(${args.join(",")})`); }; @@ -156,12 +156,12 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { "ext_ser", TypeId.isNamedType(this.typeInfo.typeId) ? this.builder.typeResolver.getSerializerByName( - CodecBuilder.replaceBackslashAndQuote(this.typeInfo.named!), - ) + CodecBuilder.replaceBackslashAndQuote(this.typeInfo.named!), + ) : this.builder.typeResolver.getSerializerById( - this.typeInfo.typeId, - this.typeInfo.userTypeId, - ), + this.typeInfo.typeId, + this.typeInfo.userTypeId, + ), ); return `${name}.${prop}(${accessor})`; }; diff --git a/javascript/packages/core/lib/gen/map.ts b/javascript/packages/core/lib/gen/map.ts index ee402c3882..ff16c969ca 100644 --- a/javascript/packages/core/lib/gen/map.ts +++ b/javascript/packages/core/lib/gen/map.ts @@ -52,9 +52,9 @@ class ElementInfo { return false; } return ( - this.serializer === other.serializer && - this.isNull === other.isNull && - this.trackRef === other.trackRef + this.serializer === other.serializer + && this.isNull === other.isNull + && this.trackRef === other.trackRef ); } } @@ -131,10 +131,10 @@ class MapChunkWriter { } // max size of chunk is 255 if ( - this.chunkSize == 255 || - this.chunkOffset == 0 || - !keyInfo.equalTo(this.preKeyInfo) || - !valueInfo.equalTo(this.preValueInfo) + this.chunkSize == 255 + || this.chunkOffset == 0 + || !keyInfo.equalTo(this.preKeyInfo) + || !valueInfo.equalTo(this.preValueInfo) ) { // new chunk this.endChunk(); @@ -199,12 +199,12 @@ class MapAnySerializer { ); this.writeContext.writer.writeVarUint32Small7(value.size); for (const [k, v] of value.entries()) { - const keySerializer = - this.keySerializer !== null + const keySerializer + = this.keySerializer !== null ? this.keySerializer : this.writeContext.typeResolver.getSerializerByData(k); - const valueSerializer = - this.valueSerializer !== null + const valueSerializer + = this.valueSerializer !== null ? this.valueSerializer : this.writeContext.typeResolver.getSerializerByData(v); @@ -224,8 +224,8 @@ class MapAnySerializer { const valueHeader = header >> 3; if (mapChunkWriter.isFirst()) { if ( - !(keyHeader & MapFlags.HAS_NULL) && - !(valueHeader & MapFlags.HAS_NULL) + !(keyHeader & MapFlags.HAS_NULL) + && !(valueHeader & MapFlags.HAS_NULL) ) { if (!(keyHeader & MapFlags.DECL_ELEMENT_TYPE)) { keySerializer?.writeTypeInfo(null); @@ -236,8 +236,8 @@ class MapAnySerializer { } } - const includeNone = - keyHeader & MapFlags.HAS_NULL || valueHeader & MapFlags.HAS_NULL; + const includeNone + = keyHeader & MapFlags.HAS_NULL || valueHeader & MapFlags.HAS_NULL; if (!this.writeFlag(keyHeader, k)) { if (!includeNone) { keySerializer!.write(k); @@ -269,8 +269,8 @@ class MapAnySerializer { return null; } if (!trackingRef) { - serializer = - serializer == null + serializer + = serializer == null ? AnyHelper.detectSerializer(this.readContext) : serializer; return this.readSerializerWithDepth(serializer!, false); @@ -279,8 +279,8 @@ class MapAnySerializer { const flag = this.readContext.reader.readInt8(); switch (flag) { case RefFlags.RefValueFlag: - serializer = - serializer == null + serializer + = serializer == null ? AnyHelper.detectSerializer(this.readContext) : serializer; return this.readSerializerWithDepth(serializer!, true); @@ -291,8 +291,8 @@ class MapAnySerializer { case RefFlags.NullFlag: return null; case RefFlags.NotNullValueFlag: - serializer = - serializer == null + serializer + = serializer == null ? AnyHelper.detectSerializer(this.readContext) : serializer; return this.readSerializerWithDepth(serializer!, false); @@ -322,8 +322,8 @@ class MapAnySerializer { let valueSerializer = this.valueSerializer; if ( - !(keyHeader & MapFlags.HAS_NULL) && - !(valueHeader & MapFlags.HAS_NULL) + !(keyHeader & MapFlags.HAS_NULL) + && !(valueHeader & MapFlags.HAS_NULL) ) { if (!(keyHeader & MapFlags.DECL_ELEMENT_TYPE)) { keySerializer = AnyHelper.detectSerializer(this.readContext); @@ -369,10 +369,10 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { const keyTypeId = this.typeInfo.options?.key!.typeId; const valueTypeId = this.typeInfo.options?.value!.typeId; return ( - keyTypeId === TypeId.UNKNOWN || - valueTypeId === TypeId.UNKNOWN || - !TypeId.isBuiltin(keyTypeId!) || - !TypeId.isBuiltin(valueTypeId!) + keyTypeId === TypeId.UNKNOWN + || valueTypeId === TypeId.UNKNOWN + || !TypeId.isBuiltin(keyTypeId!) + || !TypeId.isBuiltin(valueTypeId!) ); } @@ -477,12 +477,12 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { "map_inner_ser", TypeId.isNamedType(innerTypeInfo.typeId) ? this.builder.typeResolver.getSerializerByName( - CodecBuilder.replaceBackslashAndQuote(innerTypeInfo.named!), - ) + CodecBuilder.replaceBackslashAndQuote(innerTypeInfo.named!), + ) : this.builder.typeResolver.getSerializerById( - innerTypeInfo.typeId, - innerTypeInfo.userTypeId, - ), + innerTypeInfo.typeId, + innerTypeInfo.userTypeId, + ), ); }; return `new (${anySerializer})(${this.builder.getWriteContextName()}, ${this.builder.getReadContextName()}, ${ @@ -572,12 +572,12 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { switch (flag) { case ${RefFlags.RefValueFlag}: if (${keyDeclaredType}) { - ${readKey((x) => `key = ${x}`, "true")} + ${readKey(x => `key = ${x}`, "true")} } else { if (!${keySerializer}) { ${keySerializer} = ${anyHelper}.detectSerializer(${readContextName}); } - ${readDynamic(keySerializer, (x) => `key = ${x}`, "true")} + ${readDynamic(keySerializer, x => `key = ${x}`, "true")} } break; case ${RefFlags.RefFlag}: @@ -588,23 +588,23 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { break; case ${RefFlags.NotNullValueFlag}: if (${keyDeclaredType}) { - ${readKey((x) => `key = ${x}`, "false")} + ${readKey(x => `key = ${x}`, "false")} } else { if (!${keySerializer}) { ${keySerializer} = ${anyHelper}.detectSerializer(${readContextName}); } - ${readDynamic(keySerializer, (x) => `key = ${x}`, "false")} + ${readDynamic(keySerializer, x => `key = ${x}`, "false")} } break; } } else { if (${keyDeclaredType}) { - ${readKey((x) => `key = ${x}`, "false")} + ${readKey(x => `key = ${x}`, "false")} } else { if (!${keySerializer}) { ${keySerializer} = ${anyHelper}.detectSerializer(${readContextName}); } - ${readDynamic(keySerializer, (x) => `key = ${x}`, "false")} + ${readDynamic(keySerializer, x => `key = ${x}`, "false")} } } @@ -615,12 +615,12 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { switch (flag) { case ${RefFlags.RefValueFlag}: if (${valueDeclaredType}) { - ${readValue((x) => `value = ${x}`, "true")} + ${readValue(x => `value = ${x}`, "true")} } else { if (!${valueSerializer}) { ${valueSerializer} = ${anyHelper}.detectSerializer(${readContextName}); } - ${readDynamic(valueSerializer, (x) => `value = ${x}`, "true")} + ${readDynamic(valueSerializer, x => `value = ${x}`, "true")} } break; case ${RefFlags.RefFlag}: @@ -631,23 +631,23 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { break; case ${RefFlags.NotNullValueFlag}: if (${valueDeclaredType}) { - ${readValue((x) => `value = ${x}`, "false")} + ${readValue(x => `value = ${x}`, "false")} } else { if (!${valueSerializer}) { ${valueSerializer} = ${anyHelper}.detectSerializer(${readContextName}); } - ${readDynamic(valueSerializer, (x) => `value = ${x}`, "false")} + ${readDynamic(valueSerializer, x => `value = ${x}`, "false")} } break; } } else { if (${valueDeclaredType}) { - ${readValue((x) => `value = ${x}`, "false")} + ${readValue(x => `value = ${x}`, "false")} } else { if (!${valueSerializer}) { ${valueSerializer} = ${anyHelper}.detectSerializer(${readContextName}); } - ${readDynamic(valueSerializer, (x) => `value = ${x}`, "false")} + ${readDynamic(valueSerializer, x => `value = ${x}`, "false")} } } @@ -672,12 +672,12 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { "map_inner_ser", TypeId.isNamedType(innerTypeInfo.typeId) ? this.builder.typeResolver.getSerializerByName( - CodecBuilder.replaceBackslashAndQuote(innerTypeInfo.named!), - ) + CodecBuilder.replaceBackslashAndQuote(innerTypeInfo.named!), + ) : this.builder.typeResolver.getSerializerById( - innerTypeInfo.typeId, - innerTypeInfo.userTypeId, - ), + innerTypeInfo.typeId, + innerTypeInfo.userTypeId, + ), ); }; return accessor( diff --git a/javascript/packages/core/lib/gen/struct.ts b/javascript/packages/core/lib/gen/struct.ts index bec05a6446..801a34f4e5 100644 --- a/javascript/packages/core/lib/gen/struct.ts +++ b/javascript/packages/core/lib/gen/struct.ts @@ -51,10 +51,10 @@ function isDepthFreeField(typeInfo: TypeInfo): boolean { const key = typeInfo.options?.key; const value = typeInfo.options?.value; return ( - !!key && - !!value && - TypeId.isLeafTypeId(key.typeId) && - TypeId.isLeafTypeId(value.typeId) + !!key + && !!value + && TypeId.isLeafTypeId(key.typeId) + && TypeId.isLeafTypeId(value.typeId) ); } return false; @@ -79,8 +79,8 @@ const sortProps = ( const props = typeInfo.options!.props; if (typeInfo.options!.preserveFieldOrder) { return ( - typeInfo.options!.fieldEntries ?? - Object.entries(props!).map(([key, fieldTypeInfo]) => ({ + typeInfo.options!.fieldEntries + ?? Object.entries(props!).map(([key, fieldTypeInfo]) => ({ key, typeInfo: fieldTypeInfo, })) @@ -128,18 +128,18 @@ function varInt32ObjectReadKind( typeResolver: CodecBuilder["resolver"], ): "number" | "bigint" | null { if ( - toRefMode(typeInfo.trackingRef, typeInfo.nullable) !== RefMode.NONE || - !typeResolver.isMonomorphic(typeInfo, typeInfo.dynamic) + toRefMode(typeInfo.trackingRef, typeInfo.nullable) !== RefMode.NONE + || !typeResolver.isMonomorphic(typeInfo, typeInfo.dynamic) ) { return null; } const scalarAction = getCompatibleScalarReadAction(typeInfo); if (scalarAction !== undefined) { - return scalarAction.remoteNullable !== true && - scalarAction.remoteTypeId === TypeId.VARINT32 && - (scalarAction.localTypeId === TypeId.INT64 || - scalarAction.localTypeId === TypeId.VARINT64 || - scalarAction.localTypeId === TypeId.TAGGED_INT64) + return scalarAction.remoteNullable !== true + && scalarAction.remoteTypeId === TypeId.VARINT32 + && (scalarAction.localTypeId === TypeId.INT64 + || scalarAction.localTypeId === TypeId.VARINT64 + || scalarAction.localTypeId === TypeId.TAGGED_INT64) ? "bigint" : null; } @@ -151,9 +151,9 @@ function directNumericFieldReadExpr( builder: CodecBuilder, ): string | null { if ( - toRefMode(typeInfo.trackingRef, typeInfo.nullable) !== RefMode.NONE || - !builder.resolver.isMonomorphic(typeInfo, typeInfo.dynamic) || - getCompatibleScalarReadAction(typeInfo) !== undefined + toRefMode(typeInfo.trackingRef, typeInfo.nullable) !== RefMode.NONE + || !builder.resolver.isMonomorphic(typeInfo, typeInfo.dynamic) + || getCompatibleScalarReadAction(typeInfo) !== undefined ) { return null; } @@ -468,19 +468,19 @@ function integerRangeFits(remoteTypeId: number, localTypeId: number): boolean { return remoteTypeId === TypeId.INT8 || remoteTypeId === TypeId.UINT8; case TypeId.INT32: return ( - remoteTypeId === TypeId.INT8 || - remoteTypeId === TypeId.INT16 || - remoteTypeId === TypeId.UINT8 || - remoteTypeId === TypeId.UINT16 + remoteTypeId === TypeId.INT8 + || remoteTypeId === TypeId.INT16 + || remoteTypeId === TypeId.UINT8 + || remoteTypeId === TypeId.UINT16 ); case TypeId.INT64: return ( - remoteTypeId === TypeId.INT8 || - remoteTypeId === TypeId.INT16 || - remoteTypeId === TypeId.INT32 || - remoteTypeId === TypeId.UINT8 || - remoteTypeId === TypeId.UINT16 || - remoteTypeId === TypeId.UINT32 + remoteTypeId === TypeId.INT8 + || remoteTypeId === TypeId.INT16 + || remoteTypeId === TypeId.INT32 + || remoteTypeId === TypeId.UINT8 + || remoteTypeId === TypeId.UINT16 + || remoteTypeId === TypeId.UINT32 ); case TypeId.UINT16: return remoteTypeId === TypeId.UINT8; @@ -488,9 +488,9 @@ function integerRangeFits(remoteTypeId: number, localTypeId: number): boolean { return remoteTypeId === TypeId.UINT8 || remoteTypeId === TypeId.UINT16; case TypeId.UINT64: return ( - remoteTypeId === TypeId.UINT8 || - remoteTypeId === TypeId.UINT16 || - remoteTypeId === TypeId.UINT32 + remoteTypeId === TypeId.UINT8 + || remoteTypeId === TypeId.UINT16 + || remoteTypeId === TypeId.UINT32 ); default: return false; @@ -555,19 +555,19 @@ function integerRangeFitsFloat( return remoteTypeId === TypeId.INT8 || remoteTypeId === TypeId.UINT8; case TypeId.FLOAT32: return ( - remoteTypeId === TypeId.INT8 || - remoteTypeId === TypeId.INT16 || - remoteTypeId === TypeId.UINT8 || - remoteTypeId === TypeId.UINT16 + remoteTypeId === TypeId.INT8 + || remoteTypeId === TypeId.INT16 + || remoteTypeId === TypeId.UINT8 + || remoteTypeId === TypeId.UINT16 ); case TypeId.FLOAT64: return ( - remoteTypeId === TypeId.INT8 || - remoteTypeId === TypeId.INT16 || - remoteTypeId === TypeId.INT32 || - remoteTypeId === TypeId.UINT8 || - remoteTypeId === TypeId.UINT16 || - remoteTypeId === TypeId.UINT32 + remoteTypeId === TypeId.INT8 + || remoteTypeId === TypeId.INT16 + || remoteTypeId === TypeId.INT32 + || remoteTypeId === TypeId.UINT8 + || remoteTypeId === TypeId.UINT16 + || remoteTypeId === TypeId.UINT32 ); default: return false; @@ -601,8 +601,8 @@ class StructSerializerGenerator extends BaseSerializerGenerator { private isDepthFreeStruct(): boolean { return ( - this.sortedProps.length > 0 && - this.sortedProps.every(({ typeInfo }) => isDepthFreeField(typeInfo)) + this.sortedProps.length > 0 + && this.sortedProps.every(({ typeInfo }) => isDepthFreeField(typeInfo)) ); } @@ -704,7 +704,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { const noneedWrite = this.scope.uniqueName("noneedWrite"); stmt = ` let ${noneedWrite} = false; - ${embedGenerator.writeRefOrNull(fieldAccessor, (expr) => `${noneedWrite} = ${expr}`)} + ${embedGenerator.writeRefOrNull(fieldAccessor, expr => `${noneedWrite} = ${expr}`)} if (!${noneedWrite}) { ${embedGenerator.write(fieldAccessor)} } @@ -754,21 +754,21 @@ class StructSerializerGenerator extends BaseSerializerGenerator { write(accessor: string): string { if ( - !this.typeInfo.options?.props || - Object.keys(this.typeInfo.options.props).length === 0 + !this.typeInfo.options?.props + || Object.keys(this.typeInfo.options.props).length === 0 ) { const hash = this.typeMeta.computeStructHash(); return `${!this.builder.resolver.isCompatible() ? this.builder.writer.writeInt32(hash) : ""}`; } const hash = this.typeMeta.computeStructHash(); const fieldWrites: string[] = []; - for (let i = 0; i < this.sortedProps.length; ) { + for (let i = 0; i < this.sortedProps.length;) { const current = this.sortedProps[i]; if (isDirectVarInt32Field(current.typeInfo, this.builder.resolver)) { let end = i + 1; while ( - end < this.sortedProps.length && - isDirectVarInt32Field( + end < this.sortedProps.length + && isDirectVarInt32Field( this.sortedProps[end].typeInfo, this.builder.resolver, ) @@ -880,8 +880,8 @@ class StructSerializerGenerator extends BaseSerializerGenerator { const result = this.scope.uniqueName("result"); const hash = this.typeMeta.computeStructHash(); if ( - !this.typeInfo.options?.props || - Object.keys(this.typeInfo.options.props).length === 0 + !this.typeInfo.options?.props + || Object.keys(this.typeInfo.options.props).length === 0 ) { return ` ${ @@ -966,7 +966,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { this.scope, ); return ` - ${this.readField(key, typeInfo, (expr) => `${result}${CodecBuilder.safePropAccessor(key)} = ${expr}`, innerGenerator.readEmbed())} + ${this.readField(key, typeInfo, expr => `${result}${CodecBuilder.safePropAccessor(key)} = ${expr}`, innerGenerator.readEmbed())} `; }) .join(";\n")} @@ -986,8 +986,8 @@ class StructSerializerGenerator extends BaseSerializerGenerator { return varInt32ObjectRead; } if ( - this.typeInfo.options!.withConstructor || - this.sortedProps.length === 0 + this.typeInfo.options!.withConstructor + || this.sortedProps.length === 0 ) { return null; } @@ -997,15 +997,15 @@ class StructSerializerGenerator extends BaseSerializerGenerator { return null; } const scalarAction = getCompatibleScalarReadAction(typeInfo); - const expr = - scalarAction?.remoteNullable === true + const expr + = scalarAction?.remoteNullable === true ? null : scalarAction ? compatibleScalarFieldReadExpr( - scalarAction.remoteTypeId, - scalarAction.localTypeId, - this.builder, - ) + scalarAction.remoteTypeId, + scalarAction.localTypeId, + this.builder, + ) : directNumericFieldReadExpr(typeInfo, this.builder); if (expr === null) { return null; @@ -1028,8 +1028,8 @@ class StructSerializerGenerator extends BaseSerializerGenerator { refState: string, ): string | null { if ( - this.typeInfo.options!.withConstructor || - this.sortedProps.length === 0 + this.typeInfo.options!.withConstructor + || this.sortedProps.length === 0 ) { return null; } @@ -1121,8 +1121,8 @@ class StructSerializerGenerator extends BaseSerializerGenerator { readWithDepth(assignStmt: (v: string) => string, refState: string): string { if ( - !this.typeInfo.options?.props || - Object.keys(this.typeInfo.options.props).length === 0 + !this.typeInfo.options?.props + || Object.keys(this.typeInfo.options.props).length === 0 ) { const result = this.scope.uniqueName("result"); return ` @@ -1138,11 +1138,11 @@ class StructSerializerGenerator extends BaseSerializerGenerator { readNoRef(assignStmt: (v: string) => string, refState: string): string { const result = this.scope.uniqueName("result"); if ( - !this.typeInfo.options?.props || - Object.keys(this.typeInfo.options.props).length === 0 + !this.typeInfo.options?.props + || Object.keys(this.typeInfo.options.props).length === 0 ) { return this.readTypeInfoThen( - (changedSerializer) => ` + changedSerializer => ` ${assignStmt(`${changedSerializer}.read(${refState})`)}; `, () => ` @@ -1156,25 +1156,25 @@ class StructSerializerGenerator extends BaseSerializerGenerator { } if (this.isDepthFreeStruct()) { return this.readTypeInfoThen( - (changedSerializer) => ` + changedSerializer => ` ${assignStmt(`${changedSerializer}.read(${refState})`)}; `, () => ` let ${result}; - ${this.read((v) => `${result} = ${v}`, refState)}; + ${this.read(v => `${result} = ${v}`, refState)}; ${assignStmt(result)}; `, true, ); } return this.readTypeInfoThen( - (changedSerializer) => ` + changedSerializer => ` ${assignStmt(`${changedSerializer}.read(${refState})`)}; `, () => ` ${this.builder.getReadContextName()}.incReadDepth(); let ${result}; - ${this.read((v) => `${result} = ${v}`, refState)}; + ${this.read(v => `${result} = ${v}`, refState)}; ${this.builder.getReadContextName()}.decReadDepth(); ${assignStmt(result)}; `, @@ -1254,13 +1254,13 @@ class StructSerializerGenerator extends BaseSerializerGenerator { const builder = this.builder; const internalTypeId = this.getInternalTypeId(); const serializer = builder.resolver.getSerializerByTypeInfo(this.typeInfo); - const canInlineCompatibleTypeInfo = - internalTypeId === TypeId.COMPATIBLE_STRUCT || - internalTypeId === TypeId.NAMED_COMPATIBLE_STRUCT || - (internalTypeId === TypeId.NAMED_STRUCT && - builder.resolver.isCompatible()); - const canUseHeaderCacheFastPath = - canInlineCompatibleTypeInfo && serializer?._initialized; + const canInlineCompatibleTypeInfo + = internalTypeId === TypeId.COMPATIBLE_STRUCT + || internalTypeId === TypeId.NAMED_COMPATIBLE_STRUCT + || (internalTypeId === TypeId.NAMED_STRUCT + && builder.resolver.isCompatible()); + const canUseHeaderCacheFastPath + = canInlineCompatibleTypeInfo && serializer?._initialized; const inlineCompatibleTypeInfo = ( onMetaChanged: (changedSerializer: string) => string, onMetaUnchanged: () => string, @@ -1297,7 +1297,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { const result = scope.uniqueName("result"); return ` ${inlineCompatibleTypeInfo( - (changedSerializer) => + changedSerializer => `${accessor(`${changedSerializer}.read(${refState})`)};`, () => ` ${builder.getReadContextName()}.incReadDepth(); @@ -1322,7 +1322,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { ${result} = ${builder.referenceResolver.getReadRef(builder.reader.readVarUInt32())}; } else { ${inlineCompatibleTypeInfo( - (changedSerializer) => + changedSerializer => `${result} = ${changedSerializer}.read(${refFlag} === ${RefFlags.RefValueFlag});`, () => ` ${builder.getReadContextName()}.incReadDepth(); diff --git a/javascript/packages/core/lib/type.ts b/javascript/packages/core/lib/type.ts index eb65815987..1a58b2a3e8 100644 --- a/javascript/packages/core/lib/type.ts +++ b/javascript/packages/core/lib/type.ts @@ -174,12 +174,12 @@ export const TypeId = { }, userDefinedType(id: number) { return ( - this.structType(id) || - this.extType(id) || - this.enumType(id) || - id == TypeId.UNION || - id == TypeId.TYPED_UNION || - id == TypeId.NAMED_UNION + this.structType(id) + || this.extType(id) + || this.enumType(id) + || id == TypeId.UNION + || id == TypeId.TYPED_UNION + || id == TypeId.NAMED_UNION ); }, isBuiltin(id: number) { From 29a3298e109474d1a67c7903e2096961c196838b Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Thu, 2 Jul 2026 12:51:35 +0800 Subject: [PATCH 25/54] style(python): apply graph budget formatter --- python/pyfory/_fory.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index 9082d00acd..27a5d56d8f 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -219,11 +219,7 @@ def __init__( raise ValueError("max_schema_versions_per_type must be a positive integer") if not isinstance(max_average_schema_versions_per_type, int) or max_average_schema_versions_per_type <= 0: raise ValueError("max_average_schema_versions_per_type must be a positive integer") - if ( - not isinstance(max_graph_memory_bytes, int) - or max_graph_memory_bytes > (1 << 63) - 1 - or max_graph_memory_bytes < -(1 << 63) - ): + if not isinstance(max_graph_memory_bytes, int) or max_graph_memory_bytes > (1 << 63) - 1 or max_graph_memory_bytes < -(1 << 63): raise ValueError("max_graph_memory_bytes must be a 63-bit integer") self.max_graph_memory_bytes = max_graph_memory_bytes self.config = Config( From e30de0395212bc8f29b06eb44baf80b2ec3253ac Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Thu, 2 Jul 2026 13:00:32 +0800 Subject: [PATCH 26/54] style(cpp): apply graph budget clang format --- .../serialization/collection_serializer.h | 20 ++++++------ cpp/fory/serialization/context.cc | 8 ++--- cpp/fory/serialization/context.h | 3 +- .../serialization/graph_memory_budget_test.cc | 32 +++++++++---------- cpp/fory/serialization/map_serializer.h | 16 ++++------ 5 files changed, 36 insertions(+), 43 deletions(-) diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index bf1e9cc46b..8bdd5bb7fc 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -394,18 +394,16 @@ template inline bool reserve_collection_storage(ReadContext &ctx, uint32_t length) { constexpr size_t kMaxLength = static_cast(std::numeric_limits::max()); - if constexpr (elem_bytes <= - std::numeric_limits::max() / kMaxLength) { + if constexpr (elem_bytes <= std::numeric_limits::max() / kMaxLength) { return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); } else { - if (FORY_PREDICT_FALSE( - elem_bytes != 0 && - static_cast(length) > - std::numeric_limits::max() / elem_bytes)) { + if (FORY_PREDICT_FALSE(elem_bytes != 0 && + static_cast(length) > + std::numeric_limits::max() / + elem_bytes)) { ctx.set_error(Error::invalid_data( - "graph memory estimate overflows: length=" + - std::to_string(length) + " elementBytes=" + - std::to_string(elem_bytes))); + "graph memory estimate overflows: length=" + std::to_string(length) + + " elementBytes=" + std::to_string(elem_bytes))); return false; } return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); @@ -444,8 +442,8 @@ inline bool reserve_collection(std::vector &result, if (FORY_PREDICT_FALSE(length_bytes > std::numeric_limits::max() - 7)) { ctx.set_error(Error::invalid_data( - "graph memory estimate overflows: length=" + - std::to_string(length) + " elementBytes=1")); + "graph memory estimate overflows: length=" + std::to_string(length) + + " elementBytes=1")); return false; } const size_t packed_bytes = (length_bytes + 7) / 8; diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index 2342a2d219..0cec51704a 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -435,10 +435,10 @@ ReadContext::ReadContext(const Config &config, std::unique_ptr type_resolver) : buffer_(nullptr), config_(&config), type_resolver_(std::move(type_resolver)), current_dyn_depth_(0), - graph_memory_limit_bytes_(config.max_graph_memory_bytes > 0 - ? static_cast( - config.max_graph_memory_bytes) - : size_t{0}) {} + graph_memory_limit_bytes_( + config.max_graph_memory_bytes > 0 + ? static_cast(config.max_graph_memory_bytes) + : size_t{0}) {} ReadContext::~ReadContext() = default; diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index 601ffd7285..cc09617458 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -525,8 +525,7 @@ class ReadContext { FORY_ALWAYS_INLINE bool reserve_graph_memory(size_t bytes) { const size_t remaining = remaining_graph_memory_bytes_; - if (FORY_PREDICT_FALSE(remaining == - std::numeric_limits::max())) { + if (FORY_PREDICT_FALSE(remaining == std::numeric_limits::max())) { return true; } if (FORY_PREDICT_FALSE(bytes > remaining)) { diff --git a/cpp/fory/serialization/graph_memory_budget_test.cc b/cpp/fory/serialization/graph_memory_budget_test.cc index b8edb86abc..f6c69d6b0b 100644 --- a/cpp/fory/serialization/graph_memory_budget_test.cc +++ b/cpp/fory/serialization/graph_memory_budget_test.cc @@ -140,7 +140,8 @@ TEST(GraphMemoryBudgetTest, FixedDefaultBudgetAndDisable) { disabled_config.max_graph_memory_bytes = 0; ReadContext disabled(disabled_config, std::make_unique()); ASSERT_TRUE(disabled.init_graph_budget()); - ASSERT_TRUE(disabled.reserve_graph_memory(std::numeric_limits::max())); + ASSERT_TRUE( + disabled.reserve_graph_memory(std::numeric_limits::max())); } TEST(GraphMemoryBudgetTest, RootKindsShareConfiguredBudget) { @@ -149,33 +150,30 @@ TEST(GraphMemoryBudgetTest, RootKindsShareConfiguredBudget) { auto bytes = serialize_value(value); const size_t required = nested_empty_budget(count); - auto byte_result = with_fory(static_cast(required - 1), - [&](Fory &fory) { - return fory.deserialize< - std::vector>>( - bytes); - }); + auto byte_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); ASSERT_FALSE(byte_result.ok()); EXPECT_EQ(byte_result.error().code(), ErrorCode::InvalidData); std::string input(reinterpret_cast(bytes.data()), bytes.size()); std::istringstream source(input); StdInputStream stream(source, 8); - auto stream_result = with_fory(static_cast(required - 1), - [&](Fory &fory) { - return fory.deserialize< - std::vector>>( - stream); - }); + auto stream_result = + with_fory(static_cast(required - 1), [&](Fory &fory) { + return fory.deserialize>>(stream); + }); ASSERT_FALSE(stream_result.ok()); EXPECT_EQ(stream_result.error().code(), ErrorCode::InvalidData); std::istringstream exact_source(input); StdInputStream exact_stream(exact_source, 8); - auto exact_result = with_fory(static_cast(required), [&](Fory &fory) { - return fory.deserialize>>( - exact_stream); - }); + auto exact_result = + with_fory(static_cast(required), [&](Fory &fory) { + return fory.deserialize>>( + exact_stream); + }); ASSERT_TRUE(exact_result.ok()) << exact_result.error().to_string(); EXPECT_EQ(exact_result.value(), value); } diff --git a/cpp/fory/serialization/map_serializer.h b/cpp/fory/serialization/map_serializer.h index 9d95d5c042..12fcdac775 100644 --- a/cpp/fory/serialization/map_serializer.h +++ b/cpp/fory/serialization/map_serializer.h @@ -86,18 +86,16 @@ template inline bool reserve_map_storage(ReadContext &ctx, uint32_t length) { constexpr size_t kMaxLength = static_cast(std::numeric_limits::max()); - if constexpr (elem_bytes <= - std::numeric_limits::max() / kMaxLength) { + if constexpr (elem_bytes <= std::numeric_limits::max() / kMaxLength) { return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); } else { - if (FORY_PREDICT_FALSE( - elem_bytes != 0 && - static_cast(length) > - std::numeric_limits::max() / elem_bytes)) { + if (FORY_PREDICT_FALSE(elem_bytes != 0 && + static_cast(length) > + std::numeric_limits::max() / + elem_bytes)) { ctx.set_error(Error::invalid_data( - "graph memory estimate overflows: length=" + - std::to_string(length) + " elementBytes=" + - std::to_string(elem_bytes))); + "graph memory estimate overflows: length=" + std::to_string(length) + + " elementBytes=" + std::to_string(elem_bytes))); return false; } return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); From 3cc6a02ed07e2ae29e3847f448824518fb58c03b Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Thu, 2 Jul 2026 23:55:22 +0800 Subject: [PATCH 27/54] style(swift): apply swift-format after rebase --- swift/Sources/Fory/AnySerializer.swift | 1236 +++---- .../Sources/Fory/CollectionSerializers.swift | 1880 +++++----- swift/Sources/Fory/FieldCodecs.swift | 3182 +++++++++-------- swift/Sources/Fory/FieldSkipper.swift | 696 ++-- swift/Sources/Fory/Fory.swift | 1083 +++--- swift/Sources/Fory/ReadContext.swift | 1344 +++---- .../ForyObjectMacroReadGeneration.swift | 1106 +++--- swift/Tests/ForyTests/ForySwiftTests.swift | 2522 ++++++------- .../ForyTests/GraphMemoryBudgetTests.swift | 526 +-- 9 files changed, 6800 insertions(+), 6775 deletions(-) diff --git a/swift/Sources/Fory/AnySerializer.swift b/swift/Sources/Fory/AnySerializer.swift index b1b0fd0782..a3926eed53 100644 --- a/swift/Sources/Fory/AnySerializer.swift +++ b/swift/Sources/Fory/AnySerializer.swift @@ -22,754 +22,760 @@ private let anyArrayOwnerBytes = max(1, MemoryLayout<[Any]>.stride) @inline(never) private func throwAnyGraphMemoryOverflow() throws -> Never { - throw ForyError.invalidData("graph memory estimate overflows") + throw ForyError.invalidData("graph memory estimate overflows") } @inline(__always) private func reserveAnyReferenceArrayMemory(_ context: ReadContext, count: Int) throws { - let (slotBytes, overflow) = count.multipliedReportingOverflow(by: anyReferenceBytes) - if overflow { - try throwAnyGraphMemoryOverflow() - } - let (bytes, addOverflow) = anyArrayOwnerBytes.addingReportingOverflow(slotBytes) - if addOverflow { - try throwAnyGraphMemoryOverflow() - } - try context.reserveGraphMemory(bytes) + let (slotBytes, overflow) = count.multipliedReportingOverflow(by: anyReferenceBytes) + if overflow { + try throwAnyGraphMemoryOverflow() + } + let (bytes, addOverflow) = anyArrayOwnerBytes.addingReportingOverflow(slotBytes) + if addOverflow { + try throwAnyGraphMemoryOverflow() + } + try context.reserveGraphMemory(bytes) } @inline(__always) -private func reserveAnyReferenceMapMemory(_ context: ReadContext, _ type: Map.Type, count: Int) - throws +private func reserveAnyReferenceMapMemory( + _ context: ReadContext, _ type: Map.Type, count: Int +) + throws { - let (slotBytes, overflow) = count.multipliedReportingOverflow(by: 2 * anyReferenceBytes) - if overflow { - try throwAnyGraphMemoryOverflow() - } - let ownerBytes = max(1, MemoryLayout.stride) - let (bytes, addOverflow) = ownerBytes.addingReportingOverflow(slotBytes) - if addOverflow { - try throwAnyGraphMemoryOverflow() - } - try context.reserveGraphMemory(bytes) + let (slotBytes, overflow) = count.multipliedReportingOverflow(by: 2 * anyReferenceBytes) + if overflow { + try throwAnyGraphMemoryOverflow() + } + let ownerBytes = max(1, MemoryLayout.stride) + let (bytes, addOverflow) = ownerBytes.addingReportingOverflow(slotBytes) + if addOverflow { + try throwAnyGraphMemoryOverflow() + } + try context.reserveGraphMemory(bytes) } public struct ForyAnyNullValue: Serializer { - public init() {} + public init() {} - public static func foryDefault() -> ForyAnyNullValue { - ForyAnyNullValue() - } + public static func foryDefault() -> ForyAnyNullValue { + ForyAnyNullValue() + } - public static var staticTypeId: TypeId { - .none - } + public static var staticTypeId: TypeId { + .none + } - public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { - context.writeStaticTypeInfo(staticTypeId) - } + public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { + context.writeStaticTypeInfo(staticTypeId) + } - public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try context.readStaticTypeInfo(staticTypeId) - } + public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try context.readStaticTypeInfo(staticTypeId) + } - public var foryIsNone: Bool { - true - } + public var foryIsNone: Bool { + true + } - public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { - _ = context - _ = hasGenerics - } + public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { + _ = context + _ = hasGenerics + } - public static func foryReadData(_ context: ReadContext) throws -> ForyAnyNullValue { - _ = context - return ForyAnyNullValue() - } + public static func foryReadData(_ context: ReadContext) throws -> ForyAnyNullValue { + _ = context + return ForyAnyNullValue() + } } extension AnyHashable: Serializer { - public static func foryDefault() -> AnyHashable { - AnyHashable(Int32(0)) - } - - public static var staticTypeId: TypeId { - .unknown - } - - public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { - try writeAnyPayload(base, context: context, hasGenerics: hasGenerics) - } - - public static func foryReadData(_ context: ReadContext) throws -> AnyHashable { - _ = context - throw ForyError.invalidData( - "dynamic AnyHashable key read requires type info; foryReadData should not be called directly" - ) - } - - public static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws - -> AnyHashable - { - let typeInfo = remoteTypeInfo - if typeInfo.typeID == .none { - throw ForyError.invalidData("dynamic AnyHashable key cannot be null") - } - let decoded = try context.readAnyValue(typeInfo: typeInfo) - return try toAnyHashableKey(decoded) - } - - public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { - _ = context - throw ForyError.invalidData("dynamic AnyHashable key type info is runtime-only") - } - - public func foryWriteTypeInfo(_ context: WriteContext) throws { - try writeAnyTypeInfo(base, context: context) - } - - public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try context.readTypeInfo() - } - - public func foryWrite( - _ context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool, - hasGenerics: Bool - ) throws { - if refMode != .none { - context.buffer.writeInt8(RefFlag.notNullValue.rawValue) + public static func foryDefault() -> AnyHashable { + AnyHashable(Int32(0)) + } + + public static var staticTypeId: TypeId { + .unknown } - if writeTypeInfo { - try foryWriteTypeInfo(context) + + public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { + try writeAnyPayload(base, context: context, hasGenerics: hasGenerics) + } + + public static func foryReadData(_ context: ReadContext) throws -> AnyHashable { + _ = context + throw ForyError.invalidData( + "dynamic AnyHashable key read requires type info; foryReadData should not be called directly" + ) + } + + public static func foryReadCompatibleData( + _ context: ReadContext, remoteTypeInfo: TypeInfo + ) throws + -> AnyHashable + { + let typeInfo = remoteTypeInfo + if typeInfo.typeID == .none { + throw ForyError.invalidData("dynamic AnyHashable key cannot be null") + } + let decoded = try context.readAnyValue(typeInfo: typeInfo) + return try toAnyHashableKey(decoded) + } + + public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { + _ = context + throw ForyError.invalidData("dynamic AnyHashable key type info is runtime-only") + } + + public func foryWriteTypeInfo(_ context: WriteContext) throws { + try writeAnyTypeInfo(base, context: context) + } + + public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try context.readTypeInfo() + } + + public func foryWrite( + _ context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool, + hasGenerics: Bool + ) throws { + if refMode != .none { + context.buffer.writeInt8(RefFlag.notNullValue.rawValue) + } + if writeTypeInfo { + try foryWriteTypeInfo(context) + } + try foryWriteData(context, hasGenerics: hasGenerics) } - try foryWriteData(context, hasGenerics: hasGenerics) - } } private protocol OptionalTypeMarker { - static var noneValue: Self { get } + static var noneValue: Self { get } } extension Optional: OptionalTypeMarker { - static var noneValue: Wrapped? { nil } + static var noneValue: Wrapped? { nil } } struct SerializableAny: Serializer { - var value: Any = ForyAnyNullValue() - - init(_ value: Any) { - self.value = value - } - - static func foryDefault() -> SerializableAny { - SerializableAny(ForyAnyNullValue()) - } - - static var staticTypeId: TypeId { - .unknown - } + var value: Any = ForyAnyNullValue() - static var isNullableType: Bool { - true - } - - static var isRefType: Bool { - true - } + init(_ value: Any) { + self.value = value + } - var foryIsNone: Bool { - value is ForyAnyNullValue - } + static func foryDefault() -> SerializableAny { + SerializableAny(ForyAnyNullValue()) + } - static func wrapped(_ value: Any?) -> SerializableAny { - guard let value else { - return .foryDefault() + static var staticTypeId: TypeId { + .unknown } - guard let unwrapped = unwrapOptionalAny(value) else { - return .foryDefault() + + static var isNullableType: Bool { + true } - if unwrapped is NSNull { - return .foryDefault() + + static var isRefType: Bool { + true } - return SerializableAny(unwrapped) - } - func anyValue() -> Any? { - foryIsNone ? nil : value - } + var foryIsNone: Bool { + value is ForyAnyNullValue + } - func anyValueForCollection() -> Any { - foryIsNone ? NSNull() : value - } + static func wrapped(_ value: Any?) -> SerializableAny { + guard let value else { + return .foryDefault() + } + guard let unwrapped = unwrapOptionalAny(value) else { + return .foryDefault() + } + if unwrapped is NSNull { + return .foryDefault() + } + return SerializableAny(unwrapped) + } - func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { - if foryIsNone { - return + func anyValue() -> Any? { + foryIsNone ? nil : value } - try writeAnyPayload(value, context: context, hasGenerics: hasGenerics) - } - static func foryReadData(_ context: ReadContext) throws -> SerializableAny { - _ = context - throw ForyError.invalidData( - "dynamic Any read requires type info; foryReadData should not be called directly" - ) - } + func anyValueForCollection() -> Any { + foryIsNone ? NSNull() : value + } - static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws - -> SerializableAny - { - let typeInfo = remoteTypeInfo - if typeInfo.typeID == .none { - return .foryDefault() + func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { + if foryIsNone { + return + } + try writeAnyPayload(value, context: context, hasGenerics: hasGenerics) } - return SerializableAny(try context.readAnyValue(typeInfo: typeInfo)) - } - static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { - _ = context - throw ForyError.invalidData("dynamic Any value type info is runtime-only") - } + static func foryReadData(_ context: ReadContext) throws -> SerializableAny { + _ = context + throw ForyError.invalidData( + "dynamic Any read requires type info; foryReadData should not be called directly" + ) + } - func foryWriteTypeInfo(_ context: WriteContext) throws { - if foryIsNone { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.none.rawValue)) - return + static func foryReadCompatibleData( + _ context: ReadContext, remoteTypeInfo: TypeInfo + ) throws + -> SerializableAny + { + let typeInfo = remoteTypeInfo + if typeInfo.typeID == .none { + return .foryDefault() + } + return SerializableAny(try context.readAnyValue(typeInfo: typeInfo)) } - try writeAnyTypeInfo(value, context: context) - } - static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try context.readTypeInfo() - } + static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { + _ = context + throw ForyError.invalidData("dynamic Any value type info is runtime-only") + } - func foryWrite( - _ context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool, - hasGenerics: Bool - ) throws { - if refMode != .none { - if foryIsNone { - context.buffer.writeInt8(RefFlag.null.rawValue) - return - } - context.buffer.writeInt8(RefFlag.notNullValue.rawValue) + func foryWriteTypeInfo(_ context: WriteContext) throws { + if foryIsNone { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.none.rawValue)) + return + } + try writeAnyTypeInfo(value, context: context) } - if writeTypeInfo { - try foryWriteTypeInfo(context) + static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try context.readTypeInfo() } - try foryWriteData(context, hasGenerics: hasGenerics) - } - static func foryRead( - _ context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> SerializableAny { - @inline(__always) - func requireDynamicTypeInfo() throws -> TypeInfo { - if readTypeInfo { - guard let remoteTypeInfo = try foryReadTypeInfo(context) else { - throw ForyError.invalidData("dynamic Any value requires type info") + func foryWrite( + _ context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool, + hasGenerics: Bool + ) throws { + if refMode != .none { + if foryIsNone { + context.buffer.writeInt8(RefFlag.null.rawValue) + return + } + context.buffer.writeInt8(RefFlag.notNullValue.rawValue) } - return remoteTypeInfo - } - guard let remoteTypeInfo = context.getTypeInfo(for: Self.self) else { - throw ForyError.invalidData("dynamic Any value requires type info") - } - return remoteTypeInfo - } - - if refMode != .none { - let rawFlag = try context.buffer.readInt8() - guard let flag = RefFlag(rawValue: rawFlag) else { - throw ForyError.refError("invalid ref flag \(rawFlag)") - } - - switch flag { - case .null: - return .foryDefault() - case .ref: - let refID = try context.buffer.readVarUInt32() - let referenced = try context.refReader.readRefValue(refID) - if let value = referenced as? SerializableAny { - return value + + if writeTypeInfo { + try foryWriteTypeInfo(context) } - if referenced is NSNull { - return .foryDefault() + try foryWriteData(context, hasGenerics: hasGenerics) + } + + static func foryRead( + _ context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> SerializableAny { + @inline(__always) + func requireDynamicTypeInfo() throws -> TypeInfo { + if readTypeInfo { + guard let remoteTypeInfo = try foryReadTypeInfo(context) else { + throw ForyError.invalidData("dynamic Any value requires type info") + } + return remoteTypeInfo + } + guard let remoteTypeInfo = context.getTypeInfo(for: Self.self) else { + throw ForyError.invalidData("dynamic Any value requires type info") + } + return remoteTypeInfo } - return SerializableAny(referenced) - case .refValue: - let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil - let remoteTypeInfo = try requireDynamicTypeInfo() - let value = try foryReadCompatibleData(context, remoteTypeInfo: remoteTypeInfo) - if let reservedRefID { - if let object = value.value as AnyObject? { - context.refReader.storeRef(object, at: reservedRefID) - } else { - context.refReader.storeRef(value, at: reservedRefID) - } + + if refMode != .none { + let rawFlag = try context.buffer.readInt8() + guard let flag = RefFlag(rawValue: rawFlag) else { + throw ForyError.refError("invalid ref flag \(rawFlag)") + } + + switch flag { + case .null: + return .foryDefault() + case .ref: + let refID = try context.buffer.readVarUInt32() + let referenced = try context.refReader.readRefValue(refID) + if let value = referenced as? SerializableAny { + return value + } + if referenced is NSNull { + return .foryDefault() + } + return SerializableAny(referenced) + case .refValue: + let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil + let remoteTypeInfo = try requireDynamicTypeInfo() + let value = try foryReadCompatibleData(context, remoteTypeInfo: remoteTypeInfo) + if let reservedRefID { + if let object = value.value as AnyObject? { + context.refReader.storeRef(object, at: reservedRefID) + } else { + context.refReader.storeRef(value, at: reservedRefID) + } + } + return value + case .notNullValue: + break + } } - return value - case .notNullValue: - break - } - } - return try foryReadCompatibleData(context, remoteTypeInfo: requireDynamicTypeInfo()) - } + return try foryReadCompatibleData(context, remoteTypeInfo: requireDynamicTypeInfo()) + } } private func unwrapOptionalAny(_ value: Any) -> Any? { - let mirror = Mirror(reflecting: value) - guard mirror.displayStyle == .optional else { - return value - } - guard let (_, child) = mirror.children.first else { - return nil - } - return child + let mirror = Mirror(reflecting: value) + guard mirror.displayStyle == .optional else { + return value + } + guard let (_, child) = mirror.children.first else { + return nil + } + return child } private func toAnyHashableKey(_ value: Any) throws -> AnyHashable { - if let anyHashable = value as? AnyHashable { - return anyHashable - } - if value is ForyAnyNullValue { - throw ForyError.invalidData("dynamic AnyHashable key cannot be null") - } - guard let hashableValue = value as? any Hashable else { - throw ForyError.invalidData("dynamic AnyHashable key must be Hashable, got \(type(of: value))") - } - return AnyHashable(hashableValue) + if let anyHashable = value as? AnyHashable { + return anyHashable + } + if value is ForyAnyNullValue { + throw ForyError.invalidData("dynamic AnyHashable key cannot be null") + } + guard let hashableValue = value as? any Hashable else { + throw ForyError.invalidData("dynamic AnyHashable key must be Hashable, got \(type(of: value))") + } + return AnyHashable(hashableValue) } @inline(never) private func hasExactRuntimeType(_ value: Any, _: T.Type) -> Bool { - Swift.type(of: value) == T.self + Swift.type(of: value) == T.self } @inline(never) private func writePrimitiveArrayAnyTypeInfo(_ value: Any, context: WriteContext) -> Bool { - if hasExactRuntimeType(value, [Bool].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.boolArray.rawValue)) - return true - } - if hasExactRuntimeType(value, [Int8].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int8Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [Int16].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int16Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [Int32].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int32Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [Int64].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int64Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [UInt8].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint8Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [UInt16].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint16Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [UInt32].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint32Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [UInt64].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint64Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [Float16].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.float16Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [BFloat16].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.bfloat16Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [Float].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.float32Array.rawValue)) - return true - } - if hasExactRuntimeType(value, [Double].self) { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.float64Array.rawValue)) - return true - } - return false + if hasExactRuntimeType(value, [Bool].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.boolArray.rawValue)) + return true + } + if hasExactRuntimeType(value, [Int8].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int8Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [Int16].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int16Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [Int32].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int32Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [Int64].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.int64Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [UInt8].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint8Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [UInt16].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint16Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [UInt32].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint32Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [UInt64].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.uint64Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [Float16].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.float16Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [BFloat16].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.bfloat16Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [Float].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.float32Array.rawValue)) + return true + } + if hasExactRuntimeType(value, [Double].self) { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.float64Array.rawValue)) + return true + } + return false } @inline(never) private func writePrimitiveArrayAnyPayload(_ value: Any, context: WriteContext) -> Bool { - if hasExactRuntimeType(value, [Bool].self), let array = value as? [Bool] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Int8].self), let array = value as? [Int8] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Int16].self), let array = value as? [Int16] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Int32].self), let array = value as? [Int32] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Int64].self), let array = value as? [Int64] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [UInt8].self), let array = value as? [UInt8] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [UInt16].self), let array = value as? [UInt16] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [UInt32].self), let array = value as? [UInt32] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [UInt64].self), let array = value as? [UInt64] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Float16].self), let array = value as? [Float16] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [BFloat16].self), let array = value as? [BFloat16] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Float].self), let array = value as? [Float] { - writePrimitiveArray(array, context: context) - return true - } - if hasExactRuntimeType(value, [Double].self), let array = value as? [Double] { - writePrimitiveArray(array, context: context) - return true - } - return false + if hasExactRuntimeType(value, [Bool].self), let array = value as? [Bool] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Int8].self), let array = value as? [Int8] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Int16].self), let array = value as? [Int16] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Int32].self), let array = value as? [Int32] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Int64].self), let array = value as? [Int64] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [UInt8].self), let array = value as? [UInt8] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [UInt16].self), let array = value as? [UInt16] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [UInt32].self), let array = value as? [UInt32] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [UInt64].self), let array = value as? [UInt64] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Float16].self), let array = value as? [Float16] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [BFloat16].self), let array = value as? [BFloat16] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Float].self), let array = value as? [Float] { + writePrimitiveArray(array, context: context) + return true + } + if hasExactRuntimeType(value, [Double].self), let array = value as? [Double] { + writePrimitiveArray(array, context: context) + return true + } + return false } private func writeAnyTypeInfo(_ value: Any, context: WriteContext) throws { - if writePrimitiveArrayAnyTypeInfo(value, context: context) { - return - } - - if let serializer = value as? any Serializer { - try serializer.foryWriteTypeInfo(context) - return - } - - if value is [Any] { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.list.rawValue)) - return - } - if value is [String: Any] || value is [Int32: Any] || value is [AnyHashable: Any] { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.map.rawValue)) - return - } - - throw ForyError.invalidData("unsupported dynamic Any runtime type \(type(of: value))") -} + if writePrimitiveArrayAnyTypeInfo(value, context: context) { + return + } -private func writeAnyPayload(_ value: Any, context: WriteContext, hasGenerics: Bool) throws { - try context.enterDynamicAnyDepth() - defer { context.leaveDynamicAnyDepth() } + if let serializer = value as? any Serializer { + try serializer.foryWriteTypeInfo(context) + return + } - if writePrimitiveArrayAnyPayload(value, context: context) { - return - } + if value is [Any] { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.list.rawValue)) + return + } + if value is [String: Any] || value is [Int32: Any] || value is [AnyHashable: Any] { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.map.rawValue)) + return + } - if let serializer = value as? any Serializer { - if type(of: serializer).isRefType { - try serializer.foryWrite( - context, - refMode: .tracking, - writeTypeInfo: false, - hasGenerics: hasGenerics - ) - } else { - try serializer.foryWriteData(context, hasGenerics: hasGenerics) - } - return - } - if let list = value as? [Any] { - try writeListOfAny(list, context: context, refMode: .none, hasGenerics: hasGenerics) - return - } - if let map = value as? [String: Any] { - // Always include key type info for dynamic map payload. - try writeMapStringToAny(map, context: context, refMode: .none, hasGenerics: false) - return - } - if let map = value as? [Int32: Any] { - // Always include key type info for dynamic map payload. - try writeMapInt32ToAny(map, context: context, refMode: .none, hasGenerics: false) - return - } - if let map = value as? [AnyHashable: Any] { - // Always include key type info for dynamic map payload. - try writeMapAnyHashableToAny(map, context: context, refMode: .none, hasGenerics: false) - return - } - throw ForyError.invalidData("unsupported dynamic Any runtime type \(type(of: value))") + throw ForyError.invalidData("unsupported dynamic Any runtime type \(type(of: value))") } -public func castAnyDynamicValue(_ value: Any?, to type: T.Type) throws -> T { - _ = type - func castNilSentinel(_ sentinel: Any) throws -> T { - guard let casted = sentinel as? T else { - throw ForyError.invalidData("cannot cast dynamic Any value to \(type)") +private func writeAnyPayload(_ value: Any, context: WriteContext, hasGenerics: Bool) throws { + try context.enterDynamicAnyDepth() + defer { context.leaveDynamicAnyDepth() } + + if writePrimitiveArrayAnyPayload(value, context: context) { + return } - return casted - } - if value == nil { - if T.self == Any.self { - return try castNilSentinel(ForyAnyNullValue()) + if let serializer = value as? any Serializer { + if type(of: serializer).isRefType { + try serializer.foryWrite( + context, + refMode: .tracking, + writeTypeInfo: false, + hasGenerics: hasGenerics + ) + } else { + try serializer.foryWriteData(context, hasGenerics: hasGenerics) + } + return } - if T.self == AnyObject.self { - return try castNilSentinel(NSNull()) + if let list = value as? [Any] { + try writeListOfAny(list, context: context, refMode: .none, hasGenerics: hasGenerics) + return } - if T.self == (any Serializer).self { - return try castNilSentinel(ForyAnyNullValue()) + if let map = value as? [String: Any] { + // Always include key type info for dynamic map payload. + try writeMapStringToAny(map, context: context, refMode: .none, hasGenerics: false) + return } - if let optionalType = T.self as? any OptionalTypeMarker.Type { - return try castNilSentinel(optionalType.noneValue) + if let map = value as? [Int32: Any] { + // Always include key type info for dynamic map payload. + try writeMapInt32ToAny(map, context: context, refMode: .none, hasGenerics: false) + return + } + if let map = value as? [AnyHashable: Any] { + // Always include key type info for dynamic map payload. + try writeMapAnyHashableToAny(map, context: context, refMode: .none, hasGenerics: false) + return + } + throw ForyError.invalidData("unsupported dynamic Any runtime type \(type(of: value))") +} + +public func castAnyDynamicValue(_ value: Any?, to type: T.Type) throws -> T { + _ = type + func castNilSentinel(_ sentinel: Any) throws -> T { + guard let casted = sentinel as? T else { + throw ForyError.invalidData("cannot cast dynamic Any value to \(type)") + } + return casted + } + + if value == nil { + if T.self == Any.self { + return try castNilSentinel(ForyAnyNullValue()) + } + if T.self == AnyObject.self { + return try castNilSentinel(NSNull()) + } + if T.self == (any Serializer).self { + return try castNilSentinel(ForyAnyNullValue()) + } + if let optionalType = T.self as? any OptionalTypeMarker.Type { + return try castNilSentinel(optionalType.noneValue) + } } - } - guard let typed = value as? T else { - throw ForyError.invalidData("cannot cast dynamic Any value to \(type)") - } - return typed + guard let typed = value as? T else { + throw ForyError.invalidData("cannot cast dynamic Any value to \(type)") + } + return typed } public func writeAny( - _ value: Any?, - context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool = true, - hasGenerics: Bool = false + _ value: Any?, + context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool = true, + hasGenerics: Bool = false ) throws { - try SerializableAny.wrapped(value).foryWrite( - context, - refMode: refMode, - writeTypeInfo: writeTypeInfo, - hasGenerics: hasGenerics - ) + try SerializableAny.wrapped(value).foryWrite( + context, + refMode: refMode, + writeTypeInfo: writeTypeInfo, + hasGenerics: hasGenerics + ) } public func readAny( - context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool = true + context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool = true ) throws -> Any? { - try SerializableAny.foryRead(context, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() + try SerializableAny.foryRead(context, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() } public func writeListOfAny( - _ value: [Any]?, - context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool = false, - hasGenerics: Bool = true + _ value: [Any]?, + context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool = false, + hasGenerics: Bool = true ) throws { - let wrapped = value?.map { SerializableAny.wrapped($0) } - try wrapped.foryWrite( - context, - refMode: refMode, - writeTypeInfo: writeTypeInfo, - hasGenerics: hasGenerics - ) + let wrapped = value?.map { SerializableAny.wrapped($0) } + try wrapped.foryWrite( + context, + refMode: refMode, + writeTypeInfo: writeTypeInfo, + hasGenerics: hasGenerics + ) } public func readListOfAny( - context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool = false + context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool = false ) throws -> [Any]? { - let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead( - context, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceArrayMemory(context, count: wrapped.count) - return wrapped.map { $0.anyValueForCollection() } + let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead( + context, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceArrayMemory(context, count: wrapped.count) + return wrapped.map { $0.anyValueForCollection() } } public func writeMapStringToAny( - _ value: [String: Any]?, - context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool = false, - hasGenerics: Bool = true + _ value: [String: Any]?, + context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool = false, + hasGenerics: Bool = true ) throws { - let wrapped = value?.reduce(into: [String: SerializableAny]()) { result, pair in - result[pair.key] = SerializableAny.wrapped(pair.value) - } - try wrapped.foryWrite( - context, - refMode: refMode, - writeTypeInfo: writeTypeInfo, - hasGenerics: hasGenerics - ) + let wrapped = value?.reduce(into: [String: SerializableAny]()) { result, pair in + result[pair.key] = SerializableAny.wrapped(pair.value) + } + try wrapped.foryWrite( + context, + refMode: refMode, + writeTypeInfo: writeTypeInfo, + hasGenerics: hasGenerics + ) } public func readMapStringToAny( - context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool = false + context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool = false ) throws -> [String: Any]? { - let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead( - context, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceMapMemory(context, [String: Any].self, count: wrapped.count) - var map: [String: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead( + context, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceMapMemory(context, [String: Any].self, count: wrapped.count) + var map: [String: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map } public func writeMapInt32ToAny( - _ value: [Int32: Any]?, - context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool = false, - hasGenerics: Bool = true + _ value: [Int32: Any]?, + context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool = false, + hasGenerics: Bool = true ) throws { - let wrapped = value?.reduce(into: [Int32: SerializableAny]()) { result, pair in - result[pair.key] = SerializableAny.wrapped(pair.value) - } - try wrapped.foryWrite( - context, - refMode: refMode, - writeTypeInfo: writeTypeInfo, - hasGenerics: hasGenerics - ) + let wrapped = value?.reduce(into: [Int32: SerializableAny]()) { result, pair in + result[pair.key] = SerializableAny.wrapped(pair.value) + } + try wrapped.foryWrite( + context, + refMode: refMode, + writeTypeInfo: writeTypeInfo, + hasGenerics: hasGenerics + ) } public func readMapInt32ToAny( - context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool = false + context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool = false ) throws -> [Int32: Any]? { - let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead( - context, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceMapMemory(context, [Int32: Any].self, count: wrapped.count) - var map: [Int32: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead( + context, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceMapMemory(context, [Int32: Any].self, count: wrapped.count) + var map: [Int32: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map } public func writeMapAnyHashableToAny( - _ value: [AnyHashable: Any]?, - context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool = false, - hasGenerics: Bool = true + _ value: [AnyHashable: Any]?, + context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool = false, + hasGenerics: Bool = true ) throws { - let wrapped = value?.reduce(into: [AnyHashable: SerializableAny]()) { result, pair in - result[pair.key] = SerializableAny.wrapped(pair.value) - } - try wrapped.foryWrite( - context, - refMode: refMode, - writeTypeInfo: writeTypeInfo, - hasGenerics: hasGenerics - ) + let wrapped = value?.reduce(into: [AnyHashable: SerializableAny]()) { result, pair in + result[pair.key] = SerializableAny.wrapped(pair.value) + } + try wrapped.foryWrite( + context, + refMode: refMode, + writeTypeInfo: writeTypeInfo, + hasGenerics: hasGenerics + ) } public func readMapAnyHashableToAny( - context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool = false + context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool = false ) throws -> [AnyHashable: Any]? { - let wrapped: [AnyHashable: SerializableAny]? = try [AnyHashable: SerializableAny]?.foryRead( - context, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceMapMemory(context, [AnyHashable: Any].self, count: wrapped.count) - var map: [AnyHashable: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + let wrapped: [AnyHashable: SerializableAny]? = try [AnyHashable: SerializableAny]?.foryRead( + context, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceMapMemory(context, [AnyHashable: Any].self, count: wrapped.count) + var map: [AnyHashable: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map } func readDynamicAnyMapValue(context: ReadContext) throws -> Any { - let map = try readMapAnyHashableToAny(context: context, refMode: .none) ?? [:] - if map.isEmpty { - try reserveAnyReferenceMapMemory(context, [String: Any].self, count: 0) - return [String: Any]() - } - try reserveAnyReferenceMapMemory(context, [String: Any].self, count: map.count) - var stringMap: [String: Any] = [:] - stringMap.reserveCapacity(map.count) - for pair in map { - guard let key = pair.key.base as? String else { - stringMap.removeAll(keepingCapacity: false) - break - } - stringMap[key] = pair.value - } - if stringMap.count == map.count { - return stringMap - } - - try reserveAnyReferenceMapMemory(context, [Int32: Any].self, count: map.count) - var int32Map: [Int32: Any] = [:] - int32Map.reserveCapacity(map.count) - for pair in map { - guard let key = pair.key.base as? Int32 else { - return map - } - int32Map[key] = pair.value - } - if int32Map.count == map.count { - return int32Map - } - - return map + let map = try readMapAnyHashableToAny(context: context, refMode: .none) ?? [:] + if map.isEmpty { + try reserveAnyReferenceMapMemory(context, [String: Any].self, count: 0) + return [String: Any]() + } + try reserveAnyReferenceMapMemory(context, [String: Any].self, count: map.count) + var stringMap: [String: Any] = [:] + stringMap.reserveCapacity(map.count) + for pair in map { + guard let key = pair.key.base as? String else { + stringMap.removeAll(keepingCapacity: false) + break + } + stringMap[key] = pair.value + } + if stringMap.count == map.count { + return stringMap + } + + try reserveAnyReferenceMapMemory(context, [Int32: Any].self, count: map.count) + var int32Map: [Int32: Any] = [:] + int32Map.reserveCapacity(map.count) + for pair in map { + guard let key = pair.key.base as? Int32 else { + return map + } + int32Map[key] = pair.value + } + if int32Map.count == map.count { + return int32Map + } + + return map } diff --git a/swift/Sources/Fory/CollectionSerializers.swift b/swift/Sources/Fory/CollectionSerializers.swift index f286ea3950..137b9fb37b 100644 --- a/swift/Sources/Fory/CollectionSerializers.swift +++ b/swift/Sources/Fory/CollectionSerializers.swift @@ -18,1112 +18,1112 @@ import Foundation enum CollectionHeader { - static let trackingRef: UInt8 = 0b0000_0001 - static let hasNull: UInt8 = 0b0000_0010 - static let declaredElementType: UInt8 = 0b0000_0100 - static let sameType: UInt8 = 0b0000_1000 + static let trackingRef: UInt8 = 0b0000_0001 + static let hasNull: UInt8 = 0b0000_0010 + static let declaredElementType: UInt8 = 0b0000_0100 + static let sameType: UInt8 = 0b0000_1000 } enum MapHeader { - static let trackingKeyRef: UInt8 = 0b0000_0001 - static let keyNull: UInt8 = 0b0000_0010 - static let declaredKeyType: UInt8 = 0b0000_0100 + static let trackingKeyRef: UInt8 = 0b0000_0001 + static let keyNull: UInt8 = 0b0000_0010 + static let declaredKeyType: UInt8 = 0b0000_0100 - static let trackingValueRef: UInt8 = 0b0000_1000 - static let valueNull: UInt8 = 0b0001_0000 - static let declaredValueType: UInt8 = 0b0010_0000 + static let trackingValueRef: UInt8 = 0b0000_1000 + static let valueNull: UInt8 = 0b0001_0000 + static let declaredValueType: UInt8 = 0b0010_0000 } private let storedReferenceBytes = 4 @inline(__always) private func storedElementBytes(_ type: Element.Type) -> Int { - type.isRefType ? storedReferenceBytes : max(1, MemoryLayout.stride) + type.isRefType ? storedReferenceBytes : max(1, MemoryLayout.stride) } @inline(__always) private func reserveGraphStorage( - _ context: ReadContext, - count: Int, - elementBytes: Int + _ context: ReadContext, + count: Int, + elementBytes: Int ) throws { - if count < 0 || elementBytes < 0 { - throw ForyError.invalidData("graph memory estimate overflows") - } - let (bytes, overflow) = count.multipliedReportingOverflow(by: elementBytes) - if overflow { - throw ForyError.invalidData("graph memory estimate overflows") - } - try context.reserveGraphMemory(bytes) + if count < 0 || elementBytes < 0 { + throw ForyError.invalidData("graph memory estimate overflows") + } + let (bytes, overflow) = count.multipliedReportingOverflow(by: elementBytes) + if overflow { + throw ForyError.invalidData("graph memory estimate overflows") + } + try context.reserveGraphMemory(bytes) } @inline(__always) private func reserveGraphArrayMemory( - _ context: ReadContext, - _ type: Element.Type, - count: Int + _ context: ReadContext, + _ type: Element.Type, + count: Int ) throws { - try reserveGraphStorage(context, count: count, elementBytes: storedElementBytes(type)) + try reserveGraphStorage(context, count: count, elementBytes: storedElementBytes(type)) } @inline(__always) private func reserveGraphMapMemory( - _ context: ReadContext, - key: Key.Type, - value: Value.Type, - count: Int + _ context: ReadContext, + key: Key.Type, + value: Value.Type, + count: Int ) throws { - let keyBytes = storedElementBytes(key) - let valueBytes = storedElementBytes(value) - let (elementBytes, overflow) = keyBytes.addingReportingOverflow(valueBytes) - if overflow { - throw ForyError.invalidData("graph memory estimate overflows") - } - try reserveGraphStorage(context, count: count, elementBytes: elementBytes) + let keyBytes = storedElementBytes(key) + let valueBytes = storedElementBytes(value) + let (elementBytes, overflow) = keyBytes.addingReportingOverflow(valueBytes) + if overflow { + throw ForyError.invalidData("graph memory estimate overflows") + } + try reserveGraphStorage(context, count: count, elementBytes: elementBytes) } private func primitiveArrayTypeID(for _: Element.Type) -> TypeId? { - if Element.self == UInt8.self { return .uint8Array } - if Element.self == Bool.self { return .boolArray } - if Element.self == Int8.self { return .int8Array } - if Element.self == Int16.self { return .int16Array } - if Element.self == Int32.self { return .int32Array } - if Element.self == Int64.self { return .int64Array } - if Element.self == UInt16.self { return .uint16Array } - if Element.self == UInt32.self { return .uint32Array } - if Element.self == UInt64.self { return .uint64Array } - if Element.self == Float16.self { return .float16Array } - if Element.self == BFloat16.self { return .bfloat16Array } - if Element.self == Float.self { return .float32Array } - if Element.self == Double.self { return .float64Array } - return nil + if Element.self == UInt8.self { return .uint8Array } + if Element.self == Bool.self { return .boolArray } + if Element.self == Int8.self { return .int8Array } + if Element.self == Int16.self { return .int16Array } + if Element.self == Int32.self { return .int32Array } + if Element.self == Int64.self { return .int64Array } + if Element.self == UInt16.self { return .uint16Array } + if Element.self == UInt32.self { return .uint32Array } + if Element.self == UInt64.self { return .uint64Array } + if Element.self == Float16.self { return .float16Array } + if Element.self == BFloat16.self { return .bfloat16Array } + if Element.self == Float.self { return .float32Array } + if Element.self == Double.self { return .float64Array } + return nil } private let hostIsLittleEndian = Int(littleEndian: 1) == 1 @inline(__always) private func uncheckedArrayCast(_ array: [From], to _: To.Type) -> [To] { - assert(From.self == To.self) - return unsafeBitCast(array, to: [To].self) + assert(From.self == To.self) + return unsafeBitCast(array, to: [To].self) } @inline(__always) private func readArrayUninitialized( - count: Int, - _ initializer: (UnsafeMutablePointer) throws -> Void + count: Int, + _ initializer: (UnsafeMutablePointer) throws -> Void ) rethrows -> [Element] { - try [Element](unsafeUninitializedCapacity: count) { destination, initializedCount in - if count > 0 { - try initializer(destination.baseAddress!) + try [Element](unsafeUninitializedCapacity: count) { destination, initializedCount in + if count > 0 { + try initializer(destination.baseAddress!) + } + initializedCount = count } - initializedCount = count - } } func writePrimitiveArray(_ value: [Element], context: WriteContext) { - if Element.self == UInt8.self { - let bytes = uncheckedArrayCast(value, to: UInt8.self) - context.buffer.writeVarUInt32(UInt32(bytes.count)) - context.buffer.writeBytes(bytes) - return - } - - if Element.self == Bool.self { - let bools = uncheckedArrayCast(value, to: Bool.self) - context.buffer.writeVarUInt32(UInt32(bools.count)) - for item in bools { - context.buffer.writeUInt8(item ? 1 : 0) + if Element.self == UInt8.self { + let bytes = uncheckedArrayCast(value, to: UInt8.self) + context.buffer.writeVarUInt32(UInt32(bytes.count)) + context.buffer.writeBytes(bytes) + return } - return - } - - if Element.self == Int8.self { - let values = uncheckedArrayCast(value, to: Int8.self) - context.buffer.writeVarUInt32(UInt32(values.count)) - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) + + if Element.self == Bool.self { + let bools = uncheckedArrayCast(value, to: Bool.self) + context.buffer.writeVarUInt32(UInt32(bools.count)) + for item in bools { + context.buffer.writeUInt8(item ? 1 : 0) + } + return } - return - } - if Element.self == Int16.self { - let values = uncheckedArrayCast(value, to: Int16.self) - context.buffer.writeVarUInt32(UInt32(values.count * 2)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeInt16(item) - } + if Element.self == Int8.self { + let values = uncheckedArrayCast(value, to: Int8.self) + context.buffer.writeVarUInt32(UInt32(values.count)) + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + return } - return - } - if Element.self == Int32.self { - let values = uncheckedArrayCast(value, to: Int32.self) - context.buffer.writeVarUInt32(UInt32(values.count * 4)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeInt32(item) - } + if Element.self == Int16.self { + let values = uncheckedArrayCast(value, to: Int16.self) + context.buffer.writeVarUInt32(UInt32(values.count * 2)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeInt16(item) + } + } + return } - return - } - if Element.self == UInt32.self { - let values = uncheckedArrayCast(value, to: UInt32.self) - context.buffer.writeVarUInt32(UInt32(values.count * 4)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeUInt32(item) - } + if Element.self == Int32.self { + let values = uncheckedArrayCast(value, to: Int32.self) + context.buffer.writeVarUInt32(UInt32(values.count * 4)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeInt32(item) + } + } + return } - return - } - if Element.self == Int64.self { - let values = uncheckedArrayCast(value, to: Int64.self) - context.buffer.writeVarUInt32(UInt32(values.count * 8)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeInt64(item) - } + if Element.self == UInt32.self { + let values = uncheckedArrayCast(value, to: UInt32.self) + context.buffer.writeVarUInt32(UInt32(values.count * 4)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeUInt32(item) + } + } + return } - return - } - if Element.self == UInt64.self { - let values = uncheckedArrayCast(value, to: UInt64.self) - context.buffer.writeVarUInt32(UInt32(values.count * 8)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeUInt64(item) - } + if Element.self == Int64.self { + let values = uncheckedArrayCast(value, to: Int64.self) + context.buffer.writeVarUInt32(UInt32(values.count * 8)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeInt64(item) + } + } + return } - return - } - if Element.self == UInt16.self { - let values = uncheckedArrayCast(value, to: UInt16.self) - context.buffer.writeVarUInt32(UInt32(values.count * 2)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeUInt16(item) - } + if Element.self == UInt64.self { + let values = uncheckedArrayCast(value, to: UInt64.self) + context.buffer.writeVarUInt32(UInt32(values.count * 8)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeUInt64(item) + } + } + return } - return - } - - if Element.self == Float16.self { - let values = uncheckedArrayCast(value, to: Float16.self) - context.buffer.writeVarUInt32(UInt32(values.count * 2)) - for item in values { - context.buffer.writeUInt16(item.bitPattern) + + if Element.self == UInt16.self { + let values = uncheckedArrayCast(value, to: UInt16.self) + context.buffer.writeVarUInt32(UInt32(values.count * 2)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeUInt16(item) + } + } + return } - return - } - - if Element.self == BFloat16.self { - let values = uncheckedArrayCast(value, to: BFloat16.self) - context.buffer.writeVarUInt32(UInt32(values.count * 2)) - for item in values { - context.buffer.writeUInt16(item.rawValue) + + if Element.self == Float16.self { + let values = uncheckedArrayCast(value, to: Float16.self) + context.buffer.writeVarUInt32(UInt32(values.count * 2)) + for item in values { + context.buffer.writeUInt16(item.bitPattern) + } + return } - return - } - if Element.self == Float.self { - let values = uncheckedArrayCast(value, to: Float.self) - context.buffer.writeVarUInt32(UInt32(values.count * 4)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) - } - } else { - for item in values { - context.buffer.writeFloat32(item) - } + if Element.self == BFloat16.self { + let values = uncheckedArrayCast(value, to: BFloat16.self) + context.buffer.writeVarUInt32(UInt32(values.count * 2)) + for item in values { + context.buffer.writeUInt16(item.rawValue) + } + return } - return - } - - let values = uncheckedArrayCast(value, to: Double.self) - context.buffer.writeVarUInt32(UInt32(values.count * 8)) - if hostIsLittleEndian { - values.withUnsafeBytes { rawBytes in - context.buffer.writeBytes(rawBytes) + + if Element.self == Float.self { + let values = uncheckedArrayCast(value, to: Float.self) + context.buffer.writeVarUInt32(UInt32(values.count * 4)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeFloat32(item) + } + } + return } - } else { - for item in values { - context.buffer.writeFloat64(item) + + let values = uncheckedArrayCast(value, to: Double.self) + context.buffer.writeVarUInt32(UInt32(values.count * 8)) + if hostIsLittleEndian { + values.withUnsafeBytes { rawBytes in + context.buffer.writeBytes(rawBytes) + } + } else { + for item in values { + context.buffer.writeFloat64(item) + } } - } } @inline(__always) private func preparePrimitiveArray( - _ context: ReadContext, - reserveGraphStorage: Bool, - type: Element.Type, - count: Int, - label: String + _ context: ReadContext, + reserveGraphStorage: Bool, + type: Element.Type, + count: Int, + label: String ) throws { - try context.ensureCollectionLength(count, label: label) - if reserveGraphStorage { - try reserveGraphArrayMemory(context, type, count: count) - } + try context.ensureCollectionLength(count, label: label) + if reserveGraphStorage { + try reserveGraphArrayMemory(context, type, count: count) + } } func readPrimitiveArray( - _ context: ReadContext, - reserveGraphStorage: Bool = false + _ context: ReadContext, + reserveGraphStorage: Bool = false ) throws -> [Element] { - let byteSize = Int(try context.buffer.readVarUInt32()) - try context.ensureRemainingBytes(byteSize, label: "primitive_array_bytes") - - if Element.self == UInt8.self { - try preparePrimitiveArray( - context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: byteSize, - label: "uint8_array") - let bytes = try context.buffer.readBytes(count: byteSize) - return uncheckedArrayCast(bytes, to: Element.self) - } - - if Element.self == Bool.self { - try preparePrimitiveArray( - context, reserveGraphStorage: reserveGraphStorage, type: Element.self, count: byteSize, - label: "bool_array") - let out = try readArrayUninitialized(count: byteSize) { destination in - for index in 0.. [Element] { - [] - } - - public static var staticTypeId: TypeId { - .list - } - - public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { - context.buffer.writeUInt8(UInt8(truncatingIfNeeded: staticTypeId.rawValue)) - } - - public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - let rawTypeID = try context.buffer.readVarUInt32() - guard let actualTypeID = TypeId(rawValue: rawTypeID) else { - throw ForyError.invalidData("unknown type id \(rawTypeID)") - } - - let expectedTypeID = staticTypeId - if actualTypeID != expectedTypeID { - throw ForyError.typeMismatch(expected: expectedTypeID.rawValue, actual: rawTypeID) + public static func foryDefault() -> [Element] { + [] } - return nil - } - public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { - let buffer = context.buffer - buffer.writeVarUInt32(UInt32(self.count)) - if self.isEmpty { - return + public static var staticTypeId: TypeId { + .list } - let hasNull = Element.isNullableType && self.contains(where: { $0.foryIsNone }) - let trackRef = context.trackRef && Element.isRefType - let declaredElementType = hasGenerics && !TypeId.needsTypeInfoForField(Element.staticTypeId) - let dynamicElementType = Element.staticTypeId == .unknown - - var header: UInt8 = dynamicElementType ? 0 : CollectionHeader.sameType - if trackRef { - header |= CollectionHeader.trackingRef - } - if hasNull { - header |= CollectionHeader.hasNull - } - if declaredElementType { - header |= CollectionHeader.declaredElementType + public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { + context.buffer.writeUInt8(UInt8(truncatingIfNeeded: staticTypeId.rawValue)) } - buffer.writeUInt8(header) - if !dynamicElementType && !declaredElementType { - try Element.foryWriteStaticTypeInfo(context) - } + public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + let rawTypeID = try context.buffer.readVarUInt32() + guard let actualTypeID = TypeId(rawValue: rawTypeID) else { + throw ForyError.invalidData("unknown type id \(rawTypeID)") + } - if dynamicElementType { - let refMode: RefMode - if trackRef { - refMode = .tracking - } else if hasNull { - refMode = .nullOnly - } else { - refMode = .none - } - for element in self { - try element.foryWrite( - context, refMode: refMode, writeTypeInfo: true, hasGenerics: hasGenerics) - } - return + let expectedTypeID = staticTypeId + if actualTypeID != expectedTypeID { + throw ForyError.typeMismatch(expected: expectedTypeID.rawValue, actual: rawTypeID) + } + return nil } - if trackRef { - for element in self { - try element.foryWrite( - context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) - } - } else if hasNull { - for element in self { - if element.foryIsNone { - buffer.writeInt8(RefFlag.null.rawValue) - } else { - buffer.writeInt8(RefFlag.notNullValue.rawValue) - try element.foryWriteData(context, hasGenerics: hasGenerics) + public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { + let buffer = context.buffer + buffer.writeVarUInt32(UInt32(self.count)) + if self.isEmpty { + return } - } - } else { - for element in self { - try element.foryWriteData(context, hasGenerics: hasGenerics) - } - } - } - public static func foryReadData(_ context: ReadContext) throws -> [Element] { - try readData(context, reserveGraphStorage: true) - } + let hasNull = Element.isNullableType && self.contains(where: { $0.foryIsNone }) + let trackRef = context.trackRef && Element.isRefType + let declaredElementType = hasGenerics && !TypeId.needsTypeInfoForField(Element.staticTypeId) + let dynamicElementType = Element.staticTypeId == .unknown - fileprivate static func readData( - _ context: ReadContext, - reserveGraphStorage: Bool - ) throws -> [Element] { - let buffer = context.buffer - let length = Int(try buffer.readVarUInt32()) - try context.ensureCollectionLength(length, label: "array") - if length == 0 { - if reserveGraphStorage { - try reserveGraphArrayMemory(context, Element.self, count: length) - } - return [] - } + var header: UInt8 = dynamicElementType ? 0 : CollectionHeader.sameType + if trackRef { + header |= CollectionHeader.trackingRef + } + if hasNull { + header |= CollectionHeader.hasNull + } + if declaredElementType { + header |= CollectionHeader.declaredElementType + } - let header = try buffer.readUInt8() - let trackRef = (header & CollectionHeader.trackingRef) != 0 - let hasNull = (header & CollectionHeader.hasNull) != 0 - let declared = (header & CollectionHeader.declaredElementType) != 0 - let sameType = (header & CollectionHeader.sameType) != 0 - if !sameType { - if reserveGraphStorage { - try reserveGraphArrayMemory(context, Element.self, count: length) - } - try context.ensureRemainingBytes(length, label: "array") - if trackRef { - return try readArrayUninitialized(count: length) { destination in - for index in 0.. [Element] { + try readData(context, reserveGraphStorage: true) } - try context.ensureRemainingBytes(length, label: "array") - return try context.withTypeInfo(elementTypeInfo, for: Element.self) { - if trackRef { - return try readArrayUninitialized(count: length) { destination in - for index in 0.. [Element] { + let buffer = context.buffer + let length = Int(try buffer.readVarUInt32()) + try context.ensureCollectionLength(length, label: "array") + if length == 0 { + if reserveGraphStorage { + try reserveGraphArrayMemory(context, Element.self, count: length) + } + return [] } - } - - if hasNull { - return try readArrayUninitialized(count: length) { destination in - for index in 0.. Set { [] } + public static func foryDefault() -> Set { [] } - public static var staticTypeId: TypeId { .set } + public static var staticTypeId: TypeId { .set } - public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { - context.writeStaticTypeInfo(staticTypeId) - } + public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { + context.writeStaticTypeInfo(staticTypeId) + } - public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try context.readStaticTypeInfo(staticTypeId) - } + public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try context.readStaticTypeInfo(staticTypeId) + } - public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { - try Array(self).foryWriteData(context, hasGenerics: hasGenerics) - } + public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { + try Array(self).foryWriteData(context, hasGenerics: hasGenerics) + } - public static func foryReadData(_ context: ReadContext) throws -> Set { - let values = try [Element].readData(context, reserveGraphStorage: false) - try reserveGraphArrayMemory(context, Element.self, count: values.count) - return Set(values) - } + public static func foryReadData(_ context: ReadContext) throws -> Set { + let values = try [Element].readData(context, reserveGraphStorage: false) + try reserveGraphArrayMemory(context, Element.self, count: values.count) + return Set(values) + } } extension Dictionary: Serializer where Key: Serializer & Hashable, Value: Serializer { - public static func foryDefault() -> [Key: Value] { [:] } + public static func foryDefault() -> [Key: Value] { [:] } - public static var staticTypeId: TypeId { .map } + public static var staticTypeId: TypeId { .map } - public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { - context.writeStaticTypeInfo(staticTypeId) - } - - public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try context.readStaticTypeInfo(staticTypeId) - } + public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { + context.writeStaticTypeInfo(staticTypeId) + } - public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { - context.buffer.writeVarUInt32(UInt32(self.count)) - if self.isEmpty { - return + public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try context.readStaticTypeInfo(staticTypeId) } - let trackKeyRef = context.trackRef && Key.isRefType - let trackValueRef = context.trackRef && Value.isRefType - let keyDeclared = hasGenerics && !TypeId.needsTypeInfoForField(Key.staticTypeId) - let valueDeclared = hasGenerics && !TypeId.needsTypeInfoForField(Value.staticTypeId) - let keyDynamicType = Key.staticTypeId == .unknown - let valueDynamicType = Value.staticTypeId == .unknown - - if keyDynamicType || valueDynamicType { - for pair in self { - let keyIsNil = pair.key.foryIsNone - let valueIsNil = pair.value.foryIsNone - var header: UInt8 = 0 - if trackKeyRef { - header |= MapHeader.trackingKeyRef - } - if trackValueRef { - header |= MapHeader.trackingValueRef - } - if keyIsNil { - header |= MapHeader.keyNull - } else if !keyDynamicType && keyDeclared { - header |= MapHeader.declaredKeyType + public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { + context.buffer.writeVarUInt32(UInt32(self.count)) + if self.isEmpty { + return } - if valueIsNil { - header |= MapHeader.valueNull - } else if !valueDynamicType && valueDeclared { - header |= MapHeader.declaredValueType - } - context.buffer.writeUInt8(header) - if keyIsNil && valueIsNil { - continue - } - if keyIsNil { - if !valueDeclared { - if valueDynamicType { - try pair.value.foryWriteTypeInfo(context) - } else { - try Value.foryWriteStaticTypeInfo(context) + let trackKeyRef = context.trackRef && Key.isRefType + let trackValueRef = context.trackRef && Value.isRefType + let keyDeclared = hasGenerics && !TypeId.needsTypeInfoForField(Key.staticTypeId) + let valueDeclared = hasGenerics && !TypeId.needsTypeInfoForField(Value.staticTypeId) + let keyDynamicType = Key.staticTypeId == .unknown + let valueDynamicType = Value.staticTypeId == .unknown + + if keyDynamicType || valueDynamicType { + for pair in self { + let keyIsNil = pair.key.foryIsNone + let valueIsNil = pair.value.foryIsNone + var header: UInt8 = 0 + if trackKeyRef { + header |= MapHeader.trackingKeyRef + } + if trackValueRef { + header |= MapHeader.trackingValueRef + } + if keyIsNil { + header |= MapHeader.keyNull + } else if !keyDynamicType && keyDeclared { + header |= MapHeader.declaredKeyType + } + if valueIsNil { + header |= MapHeader.valueNull + } else if !valueDynamicType && valueDeclared { + header |= MapHeader.declaredValueType + } + context.buffer.writeUInt8(header) + + if keyIsNil && valueIsNil { + continue + } + if keyIsNil { + if !valueDeclared { + if valueDynamicType { + try pair.value.foryWriteTypeInfo(context) + } else { + try Value.foryWriteStaticTypeInfo(context) + } + } + if trackValueRef { + try pair.value.foryWrite( + context, + refMode: .tracking, + writeTypeInfo: false, + hasGenerics: hasGenerics + ) + } else { + try pair.value.foryWriteData(context, hasGenerics: hasGenerics) + } + continue + } + + if valueIsNil { + if !keyDeclared { + if keyDynamicType { + try pair.key.foryWriteTypeInfo(context) + } else { + try Key.foryWriteStaticTypeInfo(context) + } + } + if trackKeyRef { + try pair.key.foryWrite( + context, + refMode: .tracking, + writeTypeInfo: false, + hasGenerics: hasGenerics + ) + } else { + try pair.key.foryWriteData(context, hasGenerics: hasGenerics) + } + continue + } + + context.buffer.writeUInt8(1) + + if !keyDeclared { + if keyDynamicType { + try pair.key.foryWriteTypeInfo(context) + } else { + try Key.foryWriteStaticTypeInfo(context) + } + } + if !valueDeclared { + if valueDynamicType { + try pair.value.foryWriteTypeInfo(context) + } else { + try Value.foryWriteStaticTypeInfo(context) + } + } + + if trackKeyRef { + try pair.key.foryWrite( + context, + refMode: .tracking, + writeTypeInfo: false, + hasGenerics: hasGenerics + ) + } else { + try pair.key.foryWriteData(context, hasGenerics: hasGenerics) + } + if trackValueRef { + try pair.value.foryWrite( + context, + refMode: .tracking, + writeTypeInfo: false, + hasGenerics: hasGenerics + ) + } else { + try pair.value.foryWriteData(context, hasGenerics: hasGenerics) + } } - } - if trackValueRef { - try pair.value.foryWrite( - context, - refMode: .tracking, - writeTypeInfo: false, - hasGenerics: hasGenerics - ) - } else { - try pair.value.foryWriteData(context, hasGenerics: hasGenerics) - } - continue + return } - if valueIsNil { - if !keyDeclared { - if keyDynamicType { - try pair.key.foryWriteTypeInfo(context) - } else { - try Key.foryWriteStaticTypeInfo(context) + var iterator = makeIterator() + var pendingPair = iterator.next() + + while let pair = pendingPair { + let keyIsNil = pair.key.foryIsNone + let valueIsNil = pair.value.foryIsNone + + if keyIsNil || valueIsNil { + var header: UInt8 = 0 + if trackKeyRef { + header |= MapHeader.trackingKeyRef + } + if trackValueRef { + header |= MapHeader.trackingValueRef + } + if keyIsNil { header |= MapHeader.keyNull } + if valueIsNil { header |= MapHeader.valueNull } + if !keyIsNil && keyDeclared { header |= MapHeader.declaredKeyType } + if !valueIsNil && valueDeclared { header |= MapHeader.declaredValueType } + + context.buffer.writeUInt8(header) + if !keyIsNil { + if !keyDeclared { + try Key.foryWriteStaticTypeInfo(context) + } + if trackKeyRef { + try pair.key.foryWrite( + context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) + } else { + try pair.key.foryWriteData(context, hasGenerics: hasGenerics) + } + } + if !valueIsNil { + if !valueDeclared { + try Value.foryWriteStaticTypeInfo(context) + } + if trackValueRef { + try pair.value.foryWrite( + context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) + } else { + try pair.value.foryWriteData(context, hasGenerics: hasGenerics) + } + } + pendingPair = iterator.next() + continue } - } - if trackKeyRef { - try pair.key.foryWrite( - context, - refMode: .tracking, - writeTypeInfo: false, - hasGenerics: hasGenerics - ) - } else { - try pair.key.foryWriteData(context, hasGenerics: hasGenerics) - } - continue - } - context.buffer.writeUInt8(1) + var header: UInt8 = 0 + if trackKeyRef { header |= MapHeader.trackingKeyRef } + if trackValueRef { header |= MapHeader.trackingValueRef } + if keyDeclared { header |= MapHeader.declaredKeyType } + if valueDeclared { header |= MapHeader.declaredValueType } - if !keyDeclared { - if keyDynamicType { - try pair.key.foryWriteTypeInfo(context) - } else { - try Key.foryWriteStaticTypeInfo(context) - } - } - if !valueDeclared { - if valueDynamicType { - try pair.value.foryWriteTypeInfo(context) - } else { - try Value.foryWriteStaticTypeInfo(context) - } - } + context.buffer.writeUInt8(header) + let chunkSizeOffset = context.buffer.count + context.buffer.writeUInt8(0) - if trackKeyRef { - try pair.key.foryWrite( - context, - refMode: .tracking, - writeTypeInfo: false, - hasGenerics: hasGenerics - ) - } else { - try pair.key.foryWriteData(context, hasGenerics: hasGenerics) - } - if trackValueRef { - try pair.value.foryWrite( - context, - refMode: .tracking, - writeTypeInfo: false, - hasGenerics: hasGenerics - ) - } else { - try pair.value.foryWriteData(context, hasGenerics: hasGenerics) - } - } - return - } - - var iterator = makeIterator() - var pendingPair = iterator.next() - - while let pair = pendingPair { - let keyIsNil = pair.key.foryIsNone - let valueIsNil = pair.value.foryIsNone + if !keyDeclared { + try Key.foryWriteStaticTypeInfo(context) + } + if !valueDeclared { + try Value.foryWriteStaticTypeInfo(context) + } - if keyIsNil || valueIsNil { - var header: UInt8 = 0 - if trackKeyRef { - header |= MapHeader.trackingKeyRef - } - if trackValueRef { - header |= MapHeader.trackingValueRef - } - if keyIsNil { header |= MapHeader.keyNull } - if valueIsNil { header |= MapHeader.valueNull } - if !keyIsNil && keyDeclared { header |= MapHeader.declaredKeyType } - if !valueIsNil && valueDeclared { header |= MapHeader.declaredValueType } - - context.buffer.writeUInt8(header) - if !keyIsNil { - if !keyDeclared { - try Key.foryWriteStaticTypeInfo(context) - } - if trackKeyRef { - try pair.key.foryWrite( - context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) - } else { - try pair.key.foryWriteData(context, hasGenerics: hasGenerics) - } - } - if !valueIsNil { - if !valueDeclared { - try Value.foryWriteStaticTypeInfo(context) - } - if trackValueRef { - try pair.value.foryWrite( - context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) - } else { - try pair.value.foryWriteData(context, hasGenerics: hasGenerics) - } - } - pendingPair = iterator.next() - continue - } - - var header: UInt8 = 0 - if trackKeyRef { header |= MapHeader.trackingKeyRef } - if trackValueRef { header |= MapHeader.trackingValueRef } - if keyDeclared { header |= MapHeader.declaredKeyType } - if valueDeclared { header |= MapHeader.declaredValueType } - - context.buffer.writeUInt8(header) - let chunkSizeOffset = context.buffer.count - context.buffer.writeUInt8(0) - - if !keyDeclared { - try Key.foryWriteStaticTypeInfo(context) - } - if !valueDeclared { - try Value.foryWriteStaticTypeInfo(context) - } - - var chunkSize: UInt8 = 0 - while chunkSize < UInt8.max, let current = pendingPair { - if current.key.foryIsNone || current.value.foryIsNone { - break - } - if trackKeyRef { - try current.key.foryWrite( - context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) - } else { - try current.key.foryWriteData(context, hasGenerics: hasGenerics) - } - if trackValueRef { - try current.value.foryWrite( - context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) - } else { - try current.value.foryWriteData(context, hasGenerics: hasGenerics) + var chunkSize: UInt8 = 0 + while chunkSize < UInt8.max, let current = pendingPair { + if current.key.foryIsNone || current.value.foryIsNone { + break + } + if trackKeyRef { + try current.key.foryWrite( + context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) + } else { + try current.key.foryWriteData(context, hasGenerics: hasGenerics) + } + if trackValueRef { + try current.value.foryWrite( + context, refMode: .tracking, writeTypeInfo: false, hasGenerics: hasGenerics) + } else { + try current.value.foryWriteData(context, hasGenerics: hasGenerics) + } + chunkSize &+= 1 + pendingPair = iterator.next() + } + context.buffer.setByte(at: chunkSizeOffset, to: chunkSize) } - chunkSize &+= 1 - pendingPair = iterator.next() - } - context.buffer.setByte(at: chunkSizeOffset, to: chunkSize) - } - } - - public static func foryReadData(_ context: ReadContext) throws -> [Key: Value] { - let totalLength = Int(try context.buffer.readVarUInt32()) - try context.ensureCollectionLength(totalLength, label: "map") - if totalLength == 0 { - try reserveGraphMapMemory(context, key: Key.self, value: Value.self, count: totalLength) - return [:] } - try reserveGraphMapMemory(context, key: Key.self, value: Value.self, count: totalLength) - try context.ensureRemainingBytes(totalLength, label: "map") - var map: [Key: Value] = [:] - map.reserveCapacity(totalLength) - let keyDynamicType = Key.staticTypeId == .unknown - let valueDynamicType = Value.staticTypeId == .unknown - if keyDynamicType || valueDynamicType { - var dynamicReadCount = 0 - while dynamicReadCount < totalLength { - let header = try context.buffer.readUInt8() - let trackKeyRef = (header & MapHeader.trackingKeyRef) != 0 - let keyNull = (header & MapHeader.keyNull) != 0 - let keyDeclared = (header & MapHeader.declaredKeyType) != 0 - - let trackValueRef = (header & MapHeader.trackingValueRef) != 0 - let valueNull = (header & MapHeader.valueNull) != 0 - let valueDeclared = (header & MapHeader.declaredValueType) != 0 - - if keyNull && valueNull { - map[Key.foryDefault()] = Value.foryDefault() - dynamicReadCount += 1 - continue + public static func foryReadData(_ context: ReadContext) throws -> [Key: Value] { + let totalLength = Int(try context.buffer.readVarUInt32()) + try context.ensureCollectionLength(totalLength, label: "map") + if totalLength == 0 { + try reserveGraphMapMemory(context, key: Key.self, value: Value.self, count: totalLength) + return [:] } - if keyNull { - let value = try Value.foryRead( - context, - refMode: trackValueRef ? .tracking : .none, - readTypeInfo: valueDynamicType || !valueDeclared - ) - map[Key.foryDefault()] = value - dynamicReadCount += 1 - continue + try reserveGraphMapMemory(context, key: Key.self, value: Value.self, count: totalLength) + try context.ensureRemainingBytes(totalLength, label: "map") + var map: [Key: Value] = [:] + map.reserveCapacity(totalLength) + let keyDynamicType = Key.staticTypeId == .unknown + let valueDynamicType = Value.staticTypeId == .unknown + if keyDynamicType || valueDynamicType { + var dynamicReadCount = 0 + while dynamicReadCount < totalLength { + let header = try context.buffer.readUInt8() + let trackKeyRef = (header & MapHeader.trackingKeyRef) != 0 + let keyNull = (header & MapHeader.keyNull) != 0 + let keyDeclared = (header & MapHeader.declaredKeyType) != 0 + + let trackValueRef = (header & MapHeader.trackingValueRef) != 0 + let valueNull = (header & MapHeader.valueNull) != 0 + let valueDeclared = (header & MapHeader.declaredValueType) != 0 + + if keyNull && valueNull { + map[Key.foryDefault()] = Value.foryDefault() + dynamicReadCount += 1 + continue + } + + if keyNull { + let value = try Value.foryRead( + context, + refMode: trackValueRef ? .tracking : .none, + readTypeInfo: valueDynamicType || !valueDeclared + ) + map[Key.foryDefault()] = value + dynamicReadCount += 1 + continue + } + + if valueNull { + let key = try Key.foryRead( + context, + refMode: trackKeyRef ? .tracking : .none, + readTypeInfo: keyDynamicType || !keyDeclared + ) + map[key] = Value.foryDefault() + dynamicReadCount += 1 + continue + } + + let chunkSize = Int(try context.buffer.readUInt8()) + if chunkSize > (totalLength - dynamicReadCount) { + throw ForyError.invalidData("map dynamic chunk size exceeds remaining entries") + } + let keyTypeInfo = keyDeclared ? nil : try Key.foryReadTypeInfo(context) + let valueTypeInfo = valueDeclared ? nil : try Value.foryReadTypeInfo(context) + for _ in 0.. (totalLength - dynamicReadCount) { - throw ForyError.invalidData("map dynamic chunk size exceeds remaining entries") - } - let keyTypeInfo = keyDeclared ? nil : try Key.foryReadTypeInfo(context) - let valueTypeInfo = valueDeclared ? nil : try Value.foryReadTypeInfo(context) - for _ in 0.. (totalLength - readCount) { - throw ForyError.invalidData("map chunk size exceeds remaining entries") - } - let keyTypeInfo = keyDeclared ? nil : try Key.foryReadTypeInfo(context) - let valueTypeInfo = valueDeclared ? nil : try Value.foryReadTypeInfo(context) - for _ in 0.. (totalLength - readCount) { + throw ForyError.invalidData("map chunk size exceeds remaining entries") + } + let keyTypeInfo = keyDeclared ? nil : try Key.foryReadTypeInfo(context) + let valueTypeInfo = valueDeclared ? nil : try Value.foryReadTypeInfo(context) + for _ in 0..(_ codec: ElementCodec.Type) -> Int { - codec.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) + codec.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) } @inline(__always) private func serializerElementBytes(_ type: Element.Type) -> Int { - type.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) + type.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) } @inline(__always) private func reserveFieldStorage( - _ context: ReadContext, - count: Int, - elementBytes: Int + _ context: ReadContext, + count: Int, + elementBytes: Int ) throws { - if count < 0 || elementBytes < 0 { - throw ForyError.invalidData("graph memory estimate overflows") - } - let (bytes, overflow) = count.multipliedReportingOverflow(by: elementBytes) - if overflow { - throw ForyError.invalidData("graph memory estimate overflows") - } - try context.reserveGraphMemory(bytes) + if count < 0 || elementBytes < 0 { + throw ForyError.invalidData("graph memory estimate overflows") + } + let (bytes, overflow) = count.multipliedReportingOverflow(by: elementBytes) + if overflow { + throw ForyError.invalidData("graph memory estimate overflows") + } + try context.reserveGraphMemory(bytes) } @inline(__always) private func reserveFieldArrayStorage( - _ context: ReadContext, - _ codec: ElementCodec.Type, - count: Int + _ context: ReadContext, + _ codec: ElementCodec.Type, + count: Int ) throws { - try reserveFieldStorage(context, count: count, elementBytes: fieldElementBytes(codec)) + try reserveFieldStorage(context, count: count, elementBytes: fieldElementBytes(codec)) } @inline(__always) private func reserveSerializerArrayMemory( - _ context: ReadContext, - _ type: Element.Type, - count: Int + _ context: ReadContext, + _ type: Element.Type, + count: Int ) throws { - try reserveFieldStorage(context, count: count, elementBytes: serializerElementBytes(type)) + try reserveFieldStorage(context, count: count, elementBytes: serializerElementBytes(type)) } @inline(__always) private func reserveFieldMapStorage( - _ context: ReadContext, - key: KeyCodec.Type, - value: ValueCodec.Type, - count: Int + _ context: ReadContext, + key: KeyCodec.Type, + value: ValueCodec.Type, + count: Int ) throws { - let keyBytes = fieldElementBytes(key) - let valueBytes = fieldElementBytes(value) - let (elementBytes, overflow) = keyBytes.addingReportingOverflow(valueBytes) - if overflow { - throw ForyError.invalidData("graph memory estimate overflows") - } - try reserveFieldStorage(context, count: count, elementBytes: elementBytes) + let keyBytes = fieldElementBytes(key) + let valueBytes = fieldElementBytes(value) + let (elementBytes, overflow) = keyBytes.addingReportingOverflow(valueBytes) + if overflow { + throw ForyError.invalidData("graph memory estimate overflows") + } + try reserveFieldStorage(context, count: count, elementBytes: elementBytes) } public protocol FieldCodec { - associatedtype Value - - static var typeId: TypeId { get } - static var defaultValue: Value { get } - static var isNullableType: Bool { get } - static var isRefType: Bool { get } - - static func isNone(_ value: Value) -> Bool - static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType - static func writePayload(_ value: Value, _ context: WriteContext) throws - static func readPayload(_ context: ReadContext) throws -> Value - static func writeStaticTypeInfo(_ context: WriteContext) throws - static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? - static func withTypeInfo(_ typeInfo: TypeInfo?, _ context: ReadContext, _ body: () throws -> R) - rethrows -> R - static func readCompatibleField( - _ context: ReadContext, - remoteFieldType: TypeMeta.FieldType, - refMode: RefMode - ) throws -> Value + associatedtype Value + + static var typeId: TypeId { get } + static var defaultValue: Value { get } + static var isNullableType: Bool { get } + static var isRefType: Bool { get } + + static func isNone(_ value: Value) -> Bool + static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType + static func writePayload(_ value: Value, _ context: WriteContext) throws + static func readPayload(_ context: ReadContext) throws -> Value + static func writeStaticTypeInfo(_ context: WriteContext) throws + static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? + static func withTypeInfo( + _ typeInfo: TypeInfo?, _ context: ReadContext, _ body: () throws -> R + ) + rethrows -> R + static func readCompatibleField( + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Value } extension FieldCodec { - public static var isNullableType: Bool { false } - public static var isRefType: Bool { false } + public static var isNullableType: Bool { false } + public static var isRefType: Bool { false } - public static func isNone(_: Value) -> Bool { false } + public static func isNone(_: Value) -> Bool { false } - public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - TypeMeta.FieldType(typeID: typeId.rawValue, nullable: nullable, trackRef: trackRef) - } + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + TypeMeta.FieldType(typeID: typeId.rawValue, nullable: nullable, trackRef: trackRef) + } + + public static func writeStaticTypeInfo(_ context: WriteContext) throws { + context.writeStaticTypeInfo(typeId) + } - public static func writeStaticTypeInfo(_ context: WriteContext) throws { - context.writeStaticTypeInfo(typeId) - } + public static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try context.readStaticTypeInfo(typeId) + } - public static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try context.readStaticTypeInfo(typeId) - } + public static func withTypeInfo( + _ typeInfo: TypeInfo?, + _ context: ReadContext, + _ body: () throws -> R + ) rethrows -> R { + _ = typeInfo + _ = context + return try body() + } - public static func withTypeInfo( - _ typeInfo: TypeInfo?, - _ context: ReadContext, - _ body: () throws -> R - ) rethrows -> R { - _ = typeInfo - _ = context - return try body() - } - - public static func readCompatibleField( - _ context: ReadContext, - remoteFieldType: TypeMeta.FieldType, - refMode: RefMode - ) throws -> Value { - try read( - context, - refMode: refMode, - readTypeInfo: TypeId.needsTypeInfoForField( - TypeId(rawValue: remoteFieldType.typeID) ?? .unknown) - ) - } + public static func readCompatibleField( + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Value { + try read( + context, + refMode: refMode, + readTypeInfo: TypeId.needsTypeInfoForField( + TypeId(rawValue: remoteFieldType.typeID) ?? .unknown) + ) + } - public static func write( - _ value: Value, - _ context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool - ) throws { - if refMode != .none { - if refMode == .tracking, isRefType, let object = value as AnyObject? { - if context.refWriter.tryWriteRef(buffer: context.buffer, object: object) { - return + public static func write( + _ value: Value, + _ context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool + ) throws { + if refMode != .none { + if refMode == .tracking, isRefType, let object = value as AnyObject? { + if context.refWriter.tryWriteRef(buffer: context.buffer, object: object) { + return + } + } else { + context.buffer.writeInt8(RefFlag.notNullValue.rawValue) + } } - } else { - context.buffer.writeInt8(RefFlag.notNullValue.rawValue) - } - } - if writeTypeInfo { - try writeStaticTypeInfo(context) + if writeTypeInfo { + try writeStaticTypeInfo(context) + } + try writePayload(value, context) } - try writePayload(value, context) - } - public static func read( - _ context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> Value { - switch refMode { - case .none: - if readTypeInfo { - let typeInfo = try Self.readTypeInfo(context) - return try withTypeInfo(typeInfo, context) { - try readPayload(context) + public static func read( + _ context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> Value { + switch refMode { + case .none: + if readTypeInfo { + let typeInfo = try Self.readTypeInfo(context) + return try withTypeInfo(typeInfo, context) { + try readPayload(context) + } + } + return try readPayload(context) + case .nullOnly: + let rawFlag = try context.buffer.readInt8() + switch rawFlag { + case RefFlag.null.rawValue: + return defaultValue + case RefFlag.notNullValue.rawValue: + if readTypeInfo { + let typeInfo = try Self.readTypeInfo(context) + return try withTypeInfo(typeInfo, context) { + try readPayload(context) + } + } + return try readPayload(context) + case RefFlag.refValue.rawValue: + if context.trackRef { + let reservedRefID = context.refReader.reserveRefID() + let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + context.refReader.storeRef(value, at: reservedRefID) + return value + } + return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + case RefFlag.ref.rawValue: + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: Value.self) + default: + throw ForyError.refError("invalid ref flag \(rawFlag)") + } + case .tracking: + let rawFlag = try context.buffer.readInt8() + guard let flag = RefFlag(rawValue: rawFlag) else { + throw ForyError.refError("invalid ref flag \(rawFlag)") + } + switch flag { + case .null: + return defaultValue + case .ref: + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: Value.self) + case .refValue: + let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil + let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } + return value + case .notNullValue: + return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + } } - } - return try readPayload(context) - case .nullOnly: - let rawFlag = try context.buffer.readInt8() - switch rawFlag { - case RefFlag.null.rawValue: - return defaultValue - case RefFlag.notNullValue.rawValue: + } + + private static func readPayloadAfterTypeInfo( + _ context: ReadContext, + readTypeInfo: Bool + ) throws -> Value { if readTypeInfo { - let typeInfo = try Self.readTypeInfo(context) - return try withTypeInfo(typeInfo, context) { - try readPayload(context) - } + let typeInfo = try Self.readTypeInfo(context) + return try withTypeInfo(typeInfo, context) { + try readPayload(context) + } } return try readPayload(context) - case RefFlag.refValue.rawValue: - if context.trackRef { - let reservedRefID = context.refReader.reserveRefID() - let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - context.refReader.storeRef(value, at: reservedRefID) - return value - } - return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - case RefFlag.ref.rawValue: - let refID = try context.buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: Value.self) - default: - throw ForyError.refError("invalid ref flag \(rawFlag)") - } - case .tracking: - let rawFlag = try context.buffer.readInt8() - guard let flag = RefFlag(rawValue: rawFlag) else { - throw ForyError.refError("invalid ref flag \(rawFlag)") - } - switch flag { - case .null: - return defaultValue - case .ref: - let refID = try context.buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: Value.self) - case .refValue: - let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil - let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - if let reservedRefID { - context.refReader.storeRef(value, at: reservedRefID) - } - return value - case .notNullValue: - return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - } } - } - - private static func readPayloadAfterTypeInfo( - _ context: ReadContext, - readTypeInfo: Bool - ) throws -> Value { - if readTypeInfo { - let typeInfo = try Self.readTypeInfo(context) - return try withTypeInfo(typeInfo, context) { - try readPayload(context) - } - } - return try readPayload(context) - } } private enum FieldCodecDefault { - static func readCompatibleField( - codec _: Codec.Type, - _ context: ReadContext, - remoteFieldType: TypeMeta.FieldType, - refMode: RefMode - ) throws -> Codec.Value { - try Codec.read( - context, - refMode: refMode, - readTypeInfo: TypeId.needsTypeInfoForField( - TypeId(rawValue: remoteFieldType.typeID) ?? .unknown) - ) - } + static func readCompatibleField( + codec _: Codec.Type, + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Codec.Value { + try Codec.read( + context, + refMode: refMode, + readTypeInfo: TypeId.needsTypeInfoForField( + TypeId(rawValue: remoteFieldType.typeID) ?? .unknown) + ) + } } public enum SerializerCodec: FieldCodec { - public typealias Value = T + public typealias Value = T - public static var typeId: TypeId { T.staticTypeId } - public static var defaultValue: T { T.foryDefault() } - public static var isNullableType: Bool { T.isNullableType } - public static var isRefType: Bool { T.isRefType } + public static var typeId: TypeId { T.staticTypeId } + public static var defaultValue: T { T.foryDefault() } + public static var isNullableType: Bool { T.isNullableType } + public static var isRefType: Bool { T.isRefType } - public static func isNone(_ value: T) -> Bool { - value.foryIsNone - } + public static func isNone(_ value: T) -> Bool { + value.foryIsNone + } - public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - let fieldTypeID = - T.staticTypeId == .structType ? TypeId.compatibleStruct.rawValue : T.staticTypeId.rawValue - return TypeMeta.FieldType(typeID: fieldTypeID, nullable: nullable, trackRef: trackRef) - } + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + let fieldTypeID = + T.staticTypeId == .structType ? TypeId.compatibleStruct.rawValue : T.staticTypeId.rawValue + return TypeMeta.FieldType(typeID: fieldTypeID, nullable: nullable, trackRef: trackRef) + } - public static func writePayload(_ value: T, _ context: WriteContext) throws { - try value.foryWriteData(context, hasGenerics: false) - } + public static func writePayload(_ value: T, _ context: WriteContext) throws { + try value.foryWriteData(context, hasGenerics: false) + } - public static func readPayload(_ context: ReadContext) throws -> T { - try T.foryReadPayload(context, readTypeInfo: false) - } + public static func readPayload(_ context: ReadContext) throws -> T { + try T.foryReadPayload(context, readTypeInfo: false) + } - public static func writeStaticTypeInfo(_ context: WriteContext) throws { - try T.foryWriteStaticTypeInfo(context) - } + public static func writeStaticTypeInfo(_ context: WriteContext) throws { + try T.foryWriteStaticTypeInfo(context) + } - public static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try T.foryReadTypeInfo(context) - } + public static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try T.foryReadTypeInfo(context) + } - public static func withTypeInfo( - _ typeInfo: TypeInfo?, - _ context: ReadContext, - _ body: () throws -> R - ) rethrows -> R { - try context.withTypeInfo(typeInfo, for: T.self, body) - } + public static func withTypeInfo( + _ typeInfo: TypeInfo?, + _ context: ReadContext, + _ body: () throws -> R + ) rethrows -> R { + try context.withTypeInfo(typeInfo, for: T.self, body) + } - public static func write( - _ value: T, - _ context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool - ) throws { - try value.foryWrite(context, refMode: refMode, writeTypeInfo: writeTypeInfo, hasGenerics: false) - } + public static func write( + _ value: T, + _ context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool + ) throws { + try value.foryWrite(context, refMode: refMode, writeTypeInfo: writeTypeInfo, hasGenerics: false) + } - public static func read( - _ context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> T { - try T.foryRead(context, refMode: refMode, readTypeInfo: readTypeInfo) - } + public static func read( + _ context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> T { + try T.foryRead(context, refMode: refMode, readTypeInfo: readTypeInfo) + } } public enum OptionalFieldCodec: FieldCodec { - public typealias Value = WrappedCodec.Value? + public typealias Value = WrappedCodec.Value? - public static var typeId: TypeId { WrappedCodec.typeId } - public static var defaultValue: Value { nil } - public static var isNullableType: Bool { true } - public static var isRefType: Bool { WrappedCodec.isRefType } + public static var typeId: TypeId { WrappedCodec.typeId } + public static var defaultValue: Value { nil } + public static var isNullableType: Bool { true } + public static var isRefType: Bool { WrappedCodec.isRefType } - public static func isNone(_ value: Value) -> Bool { - value == nil - } + public static func isNone(_ value: Value) -> Bool { + value == nil + } - public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - WrappedCodec.fieldType(nullable: nullable, trackRef: trackRef) - } + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + WrappedCodec.fieldType(nullable: nullable, trackRef: trackRef) + } - public static func writePayload(_ value: Value, _ context: WriteContext) throws { - guard let value else { - throw ForyError.invalidData("Option.none cannot write raw payload") + public static func writePayload(_ value: Value, _ context: WriteContext) throws { + guard let value else { + throw ForyError.invalidData("Option.none cannot write raw payload") + } + try WrappedCodec.writePayload(value, context) } - try WrappedCodec.writePayload(value, context) - } - public static func readPayload(_ context: ReadContext) throws -> Value { - try WrappedCodec.readPayload(context) - } + public static func readPayload(_ context: ReadContext) throws -> Value { + try WrappedCodec.readPayload(context) + } - public static func writeStaticTypeInfo(_ context: WriteContext) throws { - try WrappedCodec.writeStaticTypeInfo(context) - } + public static func writeStaticTypeInfo(_ context: WriteContext) throws { + try WrappedCodec.writeStaticTypeInfo(context) + } - public static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try WrappedCodec.readTypeInfo(context) - } + public static func readTypeInfo(_ context: ReadContext) throws -> TypeInfo? { + try WrappedCodec.readTypeInfo(context) + } - public static func withTypeInfo( - _ typeInfo: TypeInfo?, - _ context: ReadContext, - _ body: () throws -> R - ) rethrows -> R { - try WrappedCodec.withTypeInfo(typeInfo, context, body) - } + public static func withTypeInfo( + _ typeInfo: TypeInfo?, + _ context: ReadContext, + _ body: () throws -> R + ) rethrows -> R { + try WrappedCodec.withTypeInfo(typeInfo, context, body) + } - public static func write( - _ value: Value, - _ context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool - ) throws { - switch refMode { - case .none: - guard let value else { - throw ForyError.invalidData("Option.none with RefMode.none") - } - try WrappedCodec.write(value, context, refMode: .none, writeTypeInfo: writeTypeInfo) - case .nullOnly: - guard let value else { - context.buffer.writeInt8(RefFlag.null.rawValue) - return - } - context.buffer.writeInt8(RefFlag.notNullValue.rawValue) - try WrappedCodec.write(value, context, refMode: .none, writeTypeInfo: writeTypeInfo) - case .tracking: - guard let value else { - context.buffer.writeInt8(RefFlag.null.rawValue) - return - } - try WrappedCodec.write(value, context, refMode: .tracking, writeTypeInfo: writeTypeInfo) + public static func write( + _ value: Value, + _ context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool + ) throws { + switch refMode { + case .none: + guard let value else { + throw ForyError.invalidData("Option.none with RefMode.none") + } + try WrappedCodec.write(value, context, refMode: .none, writeTypeInfo: writeTypeInfo) + case .nullOnly: + guard let value else { + context.buffer.writeInt8(RefFlag.null.rawValue) + return + } + context.buffer.writeInt8(RefFlag.notNullValue.rawValue) + try WrappedCodec.write(value, context, refMode: .none, writeTypeInfo: writeTypeInfo) + case .tracking: + guard let value else { + context.buffer.writeInt8(RefFlag.null.rawValue) + return + } + try WrappedCodec.write(value, context, refMode: .tracking, writeTypeInfo: writeTypeInfo) + } } - } - public static func read( - _ context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> Value { - switch refMode { - case .none: - return try WrappedCodec.read(context, refMode: .none, readTypeInfo: readTypeInfo) - case .nullOnly: - let refFlag = try context.buffer.readInt8() - if refFlag == RefFlag.null.rawValue { - return nil - } - return try WrappedCodec.read(context, refMode: .none, readTypeInfo: readTypeInfo) - case .tracking: - let refFlag = try context.buffer.readInt8() - if refFlag == RefFlag.null.rawValue { - return nil - } - context.buffer.moveBack(1) - return try WrappedCodec.read(context, refMode: .tracking, readTypeInfo: readTypeInfo) + public static func read( + _ context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> Value { + switch refMode { + case .none: + return try WrappedCodec.read(context, refMode: .none, readTypeInfo: readTypeInfo) + case .nullOnly: + let refFlag = try context.buffer.readInt8() + if refFlag == RefFlag.null.rawValue { + return nil + } + return try WrappedCodec.read(context, refMode: .none, readTypeInfo: readTypeInfo) + case .tracking: + let refFlag = try context.buffer.readInt8() + if refFlag == RefFlag.null.rawValue { + return nil + } + context.buffer.moveBack(1) + return try WrappedCodec.read(context, refMode: .tracking, readTypeInfo: readTypeInfo) + } } - } } public enum BoolCodec: FieldCodec { - public static let typeId: TypeId = .bool - public static let defaultValue = false - public static func writePayload(_ value: Bool, _ context: WriteContext) { - context.buffer.writeUInt8(value ? 1 : 0) - } - public static func readPayload(_ context: ReadContext) throws -> Bool { - try context.buffer.readUInt8() != 0 - } + public static let typeId: TypeId = .bool + public static let defaultValue = false + public static func writePayload(_ value: Bool, _ context: WriteContext) { + context.buffer.writeUInt8(value ? 1 : 0) + } + public static func readPayload(_ context: ReadContext) throws -> Bool { + try context.buffer.readUInt8() != 0 + } } public enum Int8Codec: FieldCodec { - public static let typeId: TypeId = .int8 - public static let defaultValue = Int8(0) - public static func writePayload(_ value: Int8, _ context: WriteContext) { - context.buffer.writeInt8(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int8 { - try context.buffer.readInt8() - } + public static let typeId: TypeId = .int8 + public static let defaultValue = Int8(0) + public static func writePayload(_ value: Int8, _ context: WriteContext) { + context.buffer.writeInt8(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int8 { + try context.buffer.readInt8() + } } public enum Int16Codec: FieldCodec { - public static let typeId: TypeId = .int16 - public static let defaultValue = Int16(0) - public static func writePayload(_ value: Int16, _ context: WriteContext) { - context.buffer.writeInt16(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int16 { - try context.buffer.readInt16() - } + public static let typeId: TypeId = .int16 + public static let defaultValue = Int16(0) + public static func writePayload(_ value: Int16, _ context: WriteContext) { + context.buffer.writeInt16(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int16 { + try context.buffer.readInt16() + } } public enum Int32VarintCodec: FieldCodec { - public static let typeId: TypeId = .varint32 - public static let defaultValue = Int32(0) - public static func writePayload(_ value: Int32, _ context: WriteContext) { - context.buffer.writeVarInt32(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int32 { - try context.buffer.readVarInt32() - } + public static let typeId: TypeId = .varint32 + public static let defaultValue = Int32(0) + public static func writePayload(_ value: Int32, _ context: WriteContext) { + context.buffer.writeVarInt32(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int32 { + try context.buffer.readVarInt32() + } } public enum Int32FixedCodec: FieldCodec { - public static let typeId: TypeId = .int32 - public static let defaultValue = Int32(0) - public static func writePayload(_ value: Int32, _ context: WriteContext) { - context.buffer.writeInt32(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int32 { - try context.buffer.readInt32() - } + public static let typeId: TypeId = .int32 + public static let defaultValue = Int32(0) + public static func writePayload(_ value: Int32, _ context: WriteContext) { + context.buffer.writeInt32(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int32 { + try context.buffer.readInt32() + } } public enum Int64VarintCodec: FieldCodec { - public static let typeId: TypeId = .varint64 - public static let defaultValue = Int64(0) - public static func writePayload(_ value: Int64, _ context: WriteContext) { - context.buffer.writeVarInt64(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int64 { - try context.buffer.readVarInt64() - } + public static let typeId: TypeId = .varint64 + public static let defaultValue = Int64(0) + public static func writePayload(_ value: Int64, _ context: WriteContext) { + context.buffer.writeVarInt64(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int64 { + try context.buffer.readVarInt64() + } } public enum Int64FixedCodec: FieldCodec { - public static let typeId: TypeId = .int64 - public static let defaultValue = Int64(0) - public static func writePayload(_ value: Int64, _ context: WriteContext) { - context.buffer.writeInt64(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int64 { - try context.buffer.readInt64() - } + public static let typeId: TypeId = .int64 + public static let defaultValue = Int64(0) + public static func writePayload(_ value: Int64, _ context: WriteContext) { + context.buffer.writeInt64(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int64 { + try context.buffer.readInt64() + } } public enum Int64TaggedCodec: FieldCodec { - public static let typeId: TypeId = .taggedInt64 - public static let defaultValue = Int64(0) - public static func writePayload(_ value: Int64, _ context: WriteContext) { - context.buffer.writeTaggedInt64(value) - } - public static func readPayload(_ context: ReadContext) throws -> Int64 { - try context.buffer.readTaggedInt64() - } + public static let typeId: TypeId = .taggedInt64 + public static let defaultValue = Int64(0) + public static func writePayload(_ value: Int64, _ context: WriteContext) { + context.buffer.writeTaggedInt64(value) + } + public static func readPayload(_ context: ReadContext) throws -> Int64 { + try context.buffer.readTaggedInt64() + } } public enum UInt8Codec: FieldCodec { - public static let typeId: TypeId = .uint8 - public static let defaultValue = UInt8(0) - public static func writePayload(_ value: UInt8, _ context: WriteContext) { - context.buffer.writeUInt8(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt8 { - try context.buffer.readUInt8() - } + public static let typeId: TypeId = .uint8 + public static let defaultValue = UInt8(0) + public static func writePayload(_ value: UInt8, _ context: WriteContext) { + context.buffer.writeUInt8(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt8 { + try context.buffer.readUInt8() + } } public enum UInt16Codec: FieldCodec { - public static let typeId: TypeId = .uint16 - public static let defaultValue = UInt16(0) - public static func writePayload(_ value: UInt16, _ context: WriteContext) { - context.buffer.writeUInt16(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt16 { - try context.buffer.readUInt16() - } + public static let typeId: TypeId = .uint16 + public static let defaultValue = UInt16(0) + public static func writePayload(_ value: UInt16, _ context: WriteContext) { + context.buffer.writeUInt16(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt16 { + try context.buffer.readUInt16() + } } public enum UInt32VarintCodec: FieldCodec { - public static let typeId: TypeId = .varUInt32 - public static let defaultValue = UInt32(0) - public static func writePayload(_ value: UInt32, _ context: WriteContext) { - context.buffer.writeVarUInt32(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt32 { - try context.buffer.readVarUInt32() - } + public static let typeId: TypeId = .varUInt32 + public static let defaultValue = UInt32(0) + public static func writePayload(_ value: UInt32, _ context: WriteContext) { + context.buffer.writeVarUInt32(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt32 { + try context.buffer.readVarUInt32() + } } public enum UInt32FixedCodec: FieldCodec { - public static let typeId: TypeId = .uint32 - public static let defaultValue = UInt32(0) - public static func writePayload(_ value: UInt32, _ context: WriteContext) { - context.buffer.writeUInt32(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt32 { - try context.buffer.readUInt32() - } + public static let typeId: TypeId = .uint32 + public static let defaultValue = UInt32(0) + public static func writePayload(_ value: UInt32, _ context: WriteContext) { + context.buffer.writeUInt32(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt32 { + try context.buffer.readUInt32() + } } public enum UInt64VarintCodec: FieldCodec { - public static let typeId: TypeId = .varUInt64 - public static let defaultValue = UInt64(0) - public static func writePayload(_ value: UInt64, _ context: WriteContext) { - context.buffer.writeVarUInt64(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt64 { - try context.buffer.readVarUInt64() - } + public static let typeId: TypeId = .varUInt64 + public static let defaultValue = UInt64(0) + public static func writePayload(_ value: UInt64, _ context: WriteContext) { + context.buffer.writeVarUInt64(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt64 { + try context.buffer.readVarUInt64() + } } public enum UInt64FixedCodec: FieldCodec { - public static let typeId: TypeId = .uint64 - public static let defaultValue = UInt64(0) - public static func writePayload(_ value: UInt64, _ context: WriteContext) { - context.buffer.writeUInt64(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt64 { - try context.buffer.readUInt64() - } + public static let typeId: TypeId = .uint64 + public static let defaultValue = UInt64(0) + public static func writePayload(_ value: UInt64, _ context: WriteContext) { + context.buffer.writeUInt64(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt64 { + try context.buffer.readUInt64() + } } public enum UInt64TaggedCodec: FieldCodec { - public static let typeId: TypeId = .taggedUInt64 - public static let defaultValue = UInt64(0) - public static func writePayload(_ value: UInt64, _ context: WriteContext) { - context.buffer.writeTaggedUInt64(value) - } - public static func readPayload(_ context: ReadContext) throws -> UInt64 { - try context.buffer.readTaggedUInt64() - } + public static let typeId: TypeId = .taggedUInt64 + public static let defaultValue = UInt64(0) + public static func writePayload(_ value: UInt64, _ context: WriteContext) { + context.buffer.writeTaggedUInt64(value) + } + public static func readPayload(_ context: ReadContext) throws -> UInt64 { + try context.buffer.readTaggedUInt64() + } } public enum IntVarintCodec: FieldCodec { - public static let typeId: TypeId = .varint64 - public static let defaultValue = Int(0) - public static func writePayload(_ value: Int, _ context: WriteContext) { - context.buffer.writeVarInt64(Int64(value)) - } - public static func readPayload(_ context: ReadContext) throws -> Int { - Int(try context.buffer.readVarInt64()) - } + public static let typeId: TypeId = .varint64 + public static let defaultValue = Int(0) + public static func writePayload(_ value: Int, _ context: WriteContext) { + context.buffer.writeVarInt64(Int64(value)) + } + public static func readPayload(_ context: ReadContext) throws -> Int { + Int(try context.buffer.readVarInt64()) + } } public enum IntFixedCodec: FieldCodec { - public static let typeId: TypeId = .int64 - public static let defaultValue = Int(0) - public static func writePayload(_ value: Int, _ context: WriteContext) { - context.buffer.writeInt64(Int64(value)) - } - public static func readPayload(_ context: ReadContext) throws -> Int { - Int(try context.buffer.readInt64()) - } + public static let typeId: TypeId = .int64 + public static let defaultValue = Int(0) + public static func writePayload(_ value: Int, _ context: WriteContext) { + context.buffer.writeInt64(Int64(value)) + } + public static func readPayload(_ context: ReadContext) throws -> Int { + Int(try context.buffer.readInt64()) + } } public enum IntTaggedCodec: FieldCodec { - public static let typeId: TypeId = .taggedInt64 - public static let defaultValue = Int(0) - public static func writePayload(_ value: Int, _ context: WriteContext) { - context.buffer.writeTaggedInt64(Int64(value)) - } - public static func readPayload(_ context: ReadContext) throws -> Int { - Int(try context.buffer.readTaggedInt64()) - } + public static let typeId: TypeId = .taggedInt64 + public static let defaultValue = Int(0) + public static func writePayload(_ value: Int, _ context: WriteContext) { + context.buffer.writeTaggedInt64(Int64(value)) + } + public static func readPayload(_ context: ReadContext) throws -> Int { + Int(try context.buffer.readTaggedInt64()) + } } public enum UIntVarintCodec: FieldCodec { - public static let typeId: TypeId = .varUInt64 - public static let defaultValue = UInt(0) - public static func writePayload(_ value: UInt, _ context: WriteContext) { - context.buffer.writeVarUInt64(UInt64(value)) - } - public static func readPayload(_ context: ReadContext) throws -> UInt { - UInt(try context.buffer.readVarUInt64()) - } + public static let typeId: TypeId = .varUInt64 + public static let defaultValue = UInt(0) + public static func writePayload(_ value: UInt, _ context: WriteContext) { + context.buffer.writeVarUInt64(UInt64(value)) + } + public static func readPayload(_ context: ReadContext) throws -> UInt { + UInt(try context.buffer.readVarUInt64()) + } } public enum UIntFixedCodec: FieldCodec { - public static let typeId: TypeId = .uint64 - public static let defaultValue = UInt(0) - public static func writePayload(_ value: UInt, _ context: WriteContext) { - context.buffer.writeUInt64(UInt64(value)) - } - public static func readPayload(_ context: ReadContext) throws -> UInt { - UInt(try context.buffer.readUInt64()) - } + public static let typeId: TypeId = .uint64 + public static let defaultValue = UInt(0) + public static func writePayload(_ value: UInt, _ context: WriteContext) { + context.buffer.writeUInt64(UInt64(value)) + } + public static func readPayload(_ context: ReadContext) throws -> UInt { + UInt(try context.buffer.readUInt64()) + } } public enum UIntTaggedCodec: FieldCodec { - public static let typeId: TypeId = .taggedUInt64 - public static let defaultValue = UInt(0) - public static func writePayload(_ value: UInt, _ context: WriteContext) { - context.buffer.writeTaggedUInt64(UInt64(value)) - } - public static func readPayload(_ context: ReadContext) throws -> UInt { - UInt(try context.buffer.readTaggedUInt64()) - } + public static let typeId: TypeId = .taggedUInt64 + public static let defaultValue = UInt(0) + public static func writePayload(_ value: UInt, _ context: WriteContext) { + context.buffer.writeTaggedUInt64(UInt64(value)) + } + public static func readPayload(_ context: ReadContext) throws -> UInt { + UInt(try context.buffer.readTaggedUInt64()) + } } public enum Float16Codec: FieldCodec { - public static let typeId: TypeId = .float16 - public static let defaultValue = Float16(0) - public static func writePayload(_ value: Float16, _ context: WriteContext) { - context.buffer.writeUInt16(value.bitPattern) - } - public static func readPayload(_ context: ReadContext) throws -> Float16 { - Float16(bitPattern: try context.buffer.readUInt16()) - } + public static let typeId: TypeId = .float16 + public static let defaultValue = Float16(0) + public static func writePayload(_ value: Float16, _ context: WriteContext) { + context.buffer.writeUInt16(value.bitPattern) + } + public static func readPayload(_ context: ReadContext) throws -> Float16 { + Float16(bitPattern: try context.buffer.readUInt16()) + } } public enum BFloat16Codec: FieldCodec { - public static let typeId: TypeId = .bfloat16 - public static let defaultValue = BFloat16() - public static func writePayload(_ value: BFloat16, _ context: WriteContext) { - context.buffer.writeUInt16(value.rawValue) - } - public static func readPayload(_ context: ReadContext) throws -> BFloat16 { - BFloat16(rawValue: try context.buffer.readUInt16()) - } + public static let typeId: TypeId = .bfloat16 + public static let defaultValue = BFloat16() + public static func writePayload(_ value: BFloat16, _ context: WriteContext) { + context.buffer.writeUInt16(value.rawValue) + } + public static func readPayload(_ context: ReadContext) throws -> BFloat16 { + BFloat16(rawValue: try context.buffer.readUInt16()) + } } public enum FloatCodec: FieldCodec { - public static let typeId: TypeId = .float32 - public static let defaultValue = Float(0) - public static func writePayload(_ value: Float, _ context: WriteContext) { - context.buffer.writeFloat32(value) - } - public static func readPayload(_ context: ReadContext) throws -> Float { - try context.buffer.readFloat32() - } + public static let typeId: TypeId = .float32 + public static let defaultValue = Float(0) + public static func writePayload(_ value: Float, _ context: WriteContext) { + context.buffer.writeFloat32(value) + } + public static func readPayload(_ context: ReadContext) throws -> Float { + try context.buffer.readFloat32() + } } public enum DoubleCodec: FieldCodec { - public static let typeId: TypeId = .float64 - public static let defaultValue = Double(0) - public static func writePayload(_ value: Double, _ context: WriteContext) { - context.buffer.writeFloat64(value) - } - public static func readPayload(_ context: ReadContext) throws -> Double { - try context.buffer.readFloat64() - } + public static let typeId: TypeId = .float64 + public static let defaultValue = Double(0) + public static func writePayload(_ value: Double, _ context: WriteContext) { + context.buffer.writeFloat64(value) + } + public static func readPayload(_ context: ReadContext) throws -> Double { + try context.buffer.readFloat64() + } } public typealias StringCodec = SerializerCodec @@ -699,1194 +701,1198 @@ public typealias DecimalCodec = SerializerCodec public typealias DataCodec = SerializerCodec public enum ListFieldCodec: FieldCodec { - public typealias Value = [ElementCodec.Value] - - public static var typeId: TypeId { .list } - public static var defaultValue: Value { [] } - - public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - return TypeMeta.FieldType( - typeID: TypeId.list.rawValue, - nullable: nullable, - trackRef: trackRef, - generics: [ - ElementCodec.fieldType( - nullable: ElementCodec.isNullableType, - trackRef: trackRef && ElementCodec.isRefType) - ] - ) - } + public typealias Value = [ElementCodec.Value] + + public static var typeId: TypeId { .list } + public static var defaultValue: Value { [] } + + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + return TypeMeta.FieldType( + typeID: TypeId.list.rawValue, + nullable: nullable, + trackRef: trackRef, + generics: [ + ElementCodec.fieldType( + nullable: ElementCodec.isNullableType, + trackRef: trackRef && ElementCodec.isRefType) + ] + ) + } - public static func writePayload(_ value: Value, _ context: WriteContext) throws { - try writeCollectionPayload(value, context, elementCodec: ElementCodec.self) - } + public static func writePayload(_ value: Value, _ context: WriteContext) throws { + try writeCollectionPayload(value, context, elementCodec: ElementCodec.self) + } - public static func readPayload(_ context: ReadContext) throws -> Value { - return try readCollectionPayload(context, elementCodec: ElementCodec.self) - } + public static func readPayload(_ context: ReadContext) throws -> Value { + return try readCollectionPayload(context, elementCodec: ElementCodec.self) + } - public static func readCompatibleField( - _ context: ReadContext, - remoteFieldType: TypeMeta.FieldType, - refMode: RefMode - ) throws -> Value { - if isCompatiblePackedArrayTypeID(remoteFieldType.typeID, elementCodec: ElementCodec.self) { - return try readCompatiblePackedArrayField( - context, refMode: refMode, elementCodec: ElementCodec.self) - } - return try FieldCodecDefault.readCompatibleField( - codec: Self.self, - context, - remoteFieldType: remoteFieldType, - refMode: refMode - ) - } + public static func readCompatibleField( + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Value { + if isCompatiblePackedArrayTypeID(remoteFieldType.typeID, elementCodec: ElementCodec.self) { + return try readCompatiblePackedArrayField( + context, refMode: refMode, elementCodec: ElementCodec.self) + } + return try FieldCodecDefault.readCompatibleField( + codec: Self.self, + context, + remoteFieldType: remoteFieldType, + refMode: refMode + ) + } } public enum ArrayFieldCodec: FieldCodec { - public typealias Value = [ElementCodec.Value] + public typealias Value = [ElementCodec.Value] - public static var typeId: TypeId { - guard let typeID = packedArrayTypeID(for: ElementCodec.self) else { - preconditionFailure("ArrayFieldCodec requires a non-null numeric or bool element codec") + public static var typeId: TypeId { + guard let typeID = packedArrayTypeID(for: ElementCodec.self) else { + preconditionFailure("ArrayFieldCodec requires a non-null numeric or bool element codec") + } + return typeID } - return typeID - } - - public static var defaultValue: Value { [] } - public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - TypeMeta.FieldType(typeID: typeId.rawValue, nullable: nullable, trackRef: trackRef) - } + public static var defaultValue: Value { [] } - public static func writePayload(_ value: Value, _ context: WriteContext) throws { - if try writePackedArrayPayload(value, context, elementCodec: ElementCodec.self) { - return + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + TypeMeta.FieldType(typeID: typeId.rawValue, nullable: nullable, trackRef: trackRef) } - throw ForyError.invalidData("unsupported array field element codec \(ElementCodec.self)") - } - public static func readPayload(_ context: ReadContext) throws -> Value { - if let value = try readPackedArrayPayload(context, elementCodec: ElementCodec.self) { - return value + public static func writePayload(_ value: Value, _ context: WriteContext) throws { + if try writePackedArrayPayload(value, context, elementCodec: ElementCodec.self) { + return + } + throw ForyError.invalidData("unsupported array field element codec \(ElementCodec.self)") } - throw ForyError.invalidData("unsupported array field element codec \(ElementCodec.self)") - } - public static func readCompatibleField( - _ context: ReadContext, - remoteFieldType: TypeMeta.FieldType, - refMode: RefMode - ) throws -> Value { - if remoteFieldType.typeID == TypeId.list.rawValue, - let element = remoteFieldType.generics.first, - let localArrayTypeID = packedArrayTypeID(for: ElementCodec.self), - TypeId.listElementTypeID(element.typeID, matchesDenseArrayTypeID: localArrayTypeID.rawValue) - { - return try readListPayloadAsArray( - context, - refMode: refMode, - elementCodec: ElementCodec.self, - remoteElementTypeID: element.typeID - ) - } - return try FieldCodecDefault.readCompatibleField( - codec: Self.self, - context, - remoteFieldType: remoteFieldType, - refMode: refMode - ) - } - - public static func write( - _ value: Value, - _ context: WriteContext, - refMode: RefMode, - writeTypeInfo: Bool - ) throws { - if refMode == .none, !writeTypeInfo { - try writePayload(value, context) - return - } - if refMode != .none { - context.buffer.writeInt8(RefFlag.notNullValue.rawValue) + public static func readPayload(_ context: ReadContext) throws -> Value { + if let value = try readPackedArrayPayload(context, elementCodec: ElementCodec.self) { + return value + } + throw ForyError.invalidData("unsupported array field element codec \(ElementCodec.self)") } - if writeTypeInfo { - try writeStaticTypeInfo(context) + + public static func readCompatibleField( + _ context: ReadContext, + remoteFieldType: TypeMeta.FieldType, + refMode: RefMode + ) throws -> Value { + if remoteFieldType.typeID == TypeId.list.rawValue, + let element = remoteFieldType.generics.first, + let localArrayTypeID = packedArrayTypeID(for: ElementCodec.self), + TypeId.listElementTypeID(element.typeID, matchesDenseArrayTypeID: localArrayTypeID.rawValue) + { + return try readListPayloadAsArray( + context, + refMode: refMode, + elementCodec: ElementCodec.self, + remoteElementTypeID: element.typeID + ) + } + return try FieldCodecDefault.readCompatibleField( + codec: Self.self, + context, + remoteFieldType: remoteFieldType, + refMode: refMode + ) } - try writePayload(value, context) - } - public static func read( - _ context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> Value { - switch refMode { - case .none: - return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - case .nullOnly: - let rawFlag = try context.buffer.readInt8() - switch rawFlag { - case RefFlag.null.rawValue: - return defaultValue - case RefFlag.notNullValue.rawValue: - return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - case RefFlag.refValue.rawValue: - if context.trackRef { - let reservedRefID = context.refReader.reserveRefID() - let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - context.refReader.storeRef(value, at: reservedRefID) - return value + public static func write( + _ value: Value, + _ context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool + ) throws { + if refMode == .none, !writeTypeInfo { + try writePayload(value, context) + return } - return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - case RefFlag.ref.rawValue: - let refID = try context.buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: Value.self) - default: - throw ForyError.refError("invalid ref flag \(rawFlag)") - } - case .tracking: - let rawFlag = try context.buffer.readInt8() - guard let flag = RefFlag(rawValue: rawFlag) else { - throw ForyError.refError("invalid ref flag \(rawFlag)") - } - switch flag { - case .null: - return defaultValue - case .ref: - let refID = try context.buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: Value.self) - case .refValue: - let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil - let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - if let reservedRefID { - context.refReader.storeRef(value, at: reservedRefID) + if refMode != .none { + context.buffer.writeInt8(RefFlag.notNullValue.rawValue) } - return value - case .notNullValue: - return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) - } + if writeTypeInfo { + try writeStaticTypeInfo(context) + } + try writePayload(value, context) } - } - private static func readPayloadAfterTypeInfo( - _ context: ReadContext, - readTypeInfo: Bool - ) throws -> Value { - if readTypeInfo { - let typeInfo = try Self.readTypeInfo(context) - return try withTypeInfo(typeInfo, context) { - try readPayload(context) - } - } - return try readPayload(context) - } + public static func read( + _ context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> Value { + switch refMode { + case .none: + return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + case .nullOnly: + let rawFlag = try context.buffer.readInt8() + switch rawFlag { + case RefFlag.null.rawValue: + return defaultValue + case RefFlag.notNullValue.rawValue: + return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + case RefFlag.refValue.rawValue: + if context.trackRef { + let reservedRefID = context.refReader.reserveRefID() + let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + context.refReader.storeRef(value, at: reservedRefID) + return value + } + return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + case RefFlag.ref.rawValue: + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: Value.self) + default: + throw ForyError.refError("invalid ref flag \(rawFlag)") + } + case .tracking: + let rawFlag = try context.buffer.readInt8() + guard let flag = RefFlag(rawValue: rawFlag) else { + throw ForyError.refError("invalid ref flag \(rawFlag)") + } + switch flag { + case .null: + return defaultValue + case .ref: + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: Value.self) + case .refValue: + let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil + let value = try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } + return value + case .notNullValue: + return try readPayloadAfterTypeInfo(context, readTypeInfo: readTypeInfo) + } + } + } + + private static func readPayloadAfterTypeInfo( + _ context: ReadContext, + readTypeInfo: Bool + ) throws -> Value { + if readTypeInfo { + let typeInfo = try Self.readTypeInfo(context) + return try withTypeInfo(typeInfo, context) { + try readPayload(context) + } + } + return try readPayload(context) + } } public enum SetFieldCodec: FieldCodec where ElementCodec.Value: Hashable { - public typealias Value = Set - - public static var typeId: TypeId { .set } - public static var defaultValue: Value { [] } - - public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - TypeMeta.FieldType( - typeID: TypeId.set.rawValue, - nullable: nullable, - trackRef: trackRef, - generics: [ - ElementCodec.fieldType( - nullable: ElementCodec.isNullableType, - trackRef: trackRef && ElementCodec.isRefType) - ] - ) - } + public typealias Value = Set + + public static var typeId: TypeId { .set } + public static var defaultValue: Value { [] } + + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + TypeMeta.FieldType( + typeID: TypeId.set.rawValue, + nullable: nullable, + trackRef: trackRef, + generics: [ + ElementCodec.fieldType( + nullable: ElementCodec.isNullableType, + trackRef: trackRef && ElementCodec.isRefType) + ] + ) + } - public static func writePayload(_ value: Value, _ context: WriteContext) throws { - try writeCollectionPayload(Array(value), context, elementCodec: ElementCodec.self) - } + public static func writePayload(_ value: Value, _ context: WriteContext) throws { + try writeCollectionPayload(Array(value), context, elementCodec: ElementCodec.self) + } - public static func readPayload(_ context: ReadContext) throws -> Value { - let values = try readCollectionPayload(context, elementCodec: ElementCodec.self) - try reserveFieldArrayStorage(context, ElementCodec.self, count: values.count) - return Set(values) - } + public static func readPayload(_ context: ReadContext) throws -> Value { + let values = try readCollectionPayload(context, elementCodec: ElementCodec.self) + try reserveFieldArrayStorage(context, ElementCodec.self, count: values.count) + return Set(values) + } } public enum MapFieldCodec: FieldCodec where KeyCodec.Value: Hashable { - public typealias Value = [KeyCodec.Value: ValueCodec.Value] - - private struct MapEntryWriteOptions { - var trackKeyRef: Bool - var trackValueRef: Bool - var keyDeclared: Bool - var valueDeclared: Bool - var keyDynamicType: Bool - var valueDynamicType: Bool - var keyIsNil: Bool - var valueIsNil: Bool - } - - public static var typeId: TypeId { .map } - public static var defaultValue: Value { [:] } - - public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { - TypeMeta.FieldType( - typeID: TypeId.map.rawValue, - nullable: nullable, - trackRef: trackRef, - generics: [ - KeyCodec.fieldType( - nullable: KeyCodec.isNullableType, - trackRef: trackRef && KeyCodec.isRefType), - ValueCodec.fieldType( - nullable: ValueCodec.isNullableType, - trackRef: trackRef && ValueCodec.isRefType), - ] - ) - } - - public static func writePayload(_ value: Value, _ context: WriteContext) throws { - context.buffer.writeVarUInt32(UInt32(value.count)) - if value.isEmpty { - return - } - - let trackKeyRef = context.trackRef && KeyCodec.isRefType - let trackValueRef = context.trackRef && ValueCodec.isRefType - let keyDeclared = !TypeId.needsTypeInfoForField(KeyCodec.typeId) - let valueDeclared = !TypeId.needsTypeInfoForField(ValueCodec.typeId) - let keyDynamicType = KeyCodec.typeId == .unknown - let valueDynamicType = ValueCodec.typeId == .unknown - let commonOptions = MapEntryWriteOptions( - trackKeyRef: trackKeyRef, - trackValueRef: trackValueRef, - keyDeclared: keyDeclared, - valueDeclared: valueDeclared, - keyDynamicType: keyDynamicType, - valueDynamicType: valueDynamicType, - keyIsNil: false, - valueIsNil: false - ) + public typealias Value = [KeyCodec.Value: ValueCodec.Value] + + private struct MapEntryWriteOptions { + var trackKeyRef: Bool + var trackValueRef: Bool + var keyDeclared: Bool + var valueDeclared: Bool + var keyDynamicType: Bool + var valueDynamicType: Bool + var keyIsNil: Bool + var valueIsNil: Bool + } - var iterator = value.makeIterator() - var pendingPair = iterator.next() - while let pair = pendingPair { - let keyIsNil = KeyCodec.isNone(pair.key) - let valueIsNil = ValueCodec.isNone(pair.value) - - if keyDynamicType || valueDynamicType || keyIsNil || valueIsNil { - var options = commonOptions - options.keyIsNil = keyIsNil - options.valueIsNil = valueIsNil - try writeMapEntry( - pair, - context, - options: options + public static var typeId: TypeId { .map } + public static var defaultValue: Value { [:] } + + public static func fieldType(nullable: Bool, trackRef: Bool) -> TypeMeta.FieldType { + TypeMeta.FieldType( + typeID: TypeId.map.rawValue, + nullable: nullable, + trackRef: trackRef, + generics: [ + KeyCodec.fieldType( + nullable: KeyCodec.isNullableType, + trackRef: trackRef && KeyCodec.isRefType), + ValueCodec.fieldType( + nullable: ValueCodec.isNullableType, + trackRef: trackRef && ValueCodec.isRefType) + ] ) - pendingPair = iterator.next() - continue - } - - var header: UInt8 = 0 - if trackKeyRef { header |= MapHeader.trackingKeyRef } - if trackValueRef { header |= MapHeader.trackingValueRef } - if keyDeclared { header |= MapHeader.declaredKeyType } - if valueDeclared { header |= MapHeader.declaredValueType } - - context.buffer.writeUInt8(header) - let chunkSizeOffset = context.buffer.count - context.buffer.writeUInt8(0) - - if !keyDeclared { - try KeyCodec.writeStaticTypeInfo(context) - } - if !valueDeclared { - try ValueCodec.writeStaticTypeInfo(context) - } - - var chunkSize: UInt8 = 0 - while chunkSize < UInt8.max, let current = pendingPair { - if KeyCodec.isNone(current.key) || ValueCodec.isNone(current.value) { - break + } + + public static func writePayload(_ value: Value, _ context: WriteContext) throws { + context.buffer.writeVarUInt32(UInt32(value.count)) + if value.isEmpty { + return } - try writeMapPayload( - current, - context, - trackKeyRef: trackKeyRef, - trackValueRef: trackValueRef - ) - chunkSize &+= 1 - pendingPair = iterator.next() - } - context.buffer.setByte(at: chunkSizeOffset, to: chunkSize) - } - } - - public static func readPayload(_ context: ReadContext) throws -> Value { - let totalLength = Int(try context.buffer.readVarUInt32()) - try context.ensureCollectionLength(totalLength, label: "map") - if totalLength == 0 { - try reserveFieldMapStorage( - context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) - return [:] - } - - try reserveFieldMapStorage( - context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) - try context.ensureRemainingBytes(totalLength, label: "map") - var map: Value = [:] - map.reserveCapacity(totalLength) - var readCount = 0 - while readCount < totalLength { - let header = try context.buffer.readUInt8() - // IMPORTANT: map readers must obey the sender-written key/value ref - // bits in this header. Local Swift field metadata must not - // override that decision while reading. Shared xlang tests - // intentionally deserialize one ref policy and then serialize - // another local payload. DO NOT REMOVE this comment. - let trackKeyRef = (header & MapHeader.trackingKeyRef) != 0 - let keyNull = (header & MapHeader.keyNull) != 0 - let keyDeclared = (header & MapHeader.declaredKeyType) != 0 - - let trackValueRef = (header & MapHeader.trackingValueRef) != 0 - let valueNull = (header & MapHeader.valueNull) != 0 - let valueDeclared = (header & MapHeader.declaredValueType) != 0 - - if keyNull && valueNull { - map[KeyCodec.defaultValue] = ValueCodec.defaultValue - readCount += 1 - continue - } - - if keyNull { - let value = try readMapValue( - context, - declared: valueDeclared, - trackRef: trackValueRef - ) - map[KeyCodec.defaultValue] = value - readCount += 1 - continue - } - - if valueNull { - let key = try readMapKey( - context, - declared: keyDeclared, - trackRef: trackKeyRef + + let trackKeyRef = context.trackRef && KeyCodec.isRefType + let trackValueRef = context.trackRef && ValueCodec.isRefType + let keyDeclared = !TypeId.needsTypeInfoForField(KeyCodec.typeId) + let valueDeclared = !TypeId.needsTypeInfoForField(ValueCodec.typeId) + let keyDynamicType = KeyCodec.typeId == .unknown + let valueDynamicType = ValueCodec.typeId == .unknown + let commonOptions = MapEntryWriteOptions( + trackKeyRef: trackKeyRef, + trackValueRef: trackValueRef, + keyDeclared: keyDeclared, + valueDeclared: valueDeclared, + keyDynamicType: keyDynamicType, + valueDynamicType: valueDynamicType, + keyIsNil: false, + valueIsNil: false ) - map[key] = ValueCodec.defaultValue - readCount += 1 - continue - } - - let chunkSize = Int(try context.buffer.readUInt8()) - if chunkSize > (totalLength - readCount) { - throw ForyError.invalidData("map chunk size exceeds remaining entries") - } - let keyTypeInfo = keyDeclared ? nil : try KeyCodec.readTypeInfo(context) - let valueTypeInfo = valueDeclared ? nil : try ValueCodec.readTypeInfo(context) - for _ in 0.. Value { + let totalLength = Int(try context.buffer.readVarUInt32()) + try context.ensureCollectionLength(totalLength, label: "map") + if totalLength == 0 { + try reserveFieldMapStorage( + context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) + return [:] + } + + try reserveFieldMapStorage( + context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) + try context.ensureRemainingBytes(totalLength, label: "map") + var map: Value = [:] + map.reserveCapacity(totalLength) + var readCount = 0 + while readCount < totalLength { + let header = try context.buffer.readUInt8() + // IMPORTANT: map readers must obey the sender-written key/value ref + // bits in this header. Local Swift field metadata must not + // override that decision while reading. Shared xlang tests + // intentionally deserialize one ref policy and then serialize + // another local payload. DO NOT REMOVE this comment. + let trackKeyRef = (header & MapHeader.trackingKeyRef) != 0 + let keyNull = (header & MapHeader.keyNull) != 0 + let keyDeclared = (header & MapHeader.declaredKeyType) != 0 + + let trackValueRef = (header & MapHeader.trackingValueRef) != 0 + let valueNull = (header & MapHeader.valueNull) != 0 + let valueDeclared = (header & MapHeader.declaredValueType) != 0 + + if keyNull && valueNull { + map[KeyCodec.defaultValue] = ValueCodec.defaultValue + readCount += 1 + continue + } + + if keyNull { + let value = try readMapValue( + context, + declared: valueDeclared, + trackRef: trackValueRef + ) + map[KeyCodec.defaultValue] = value + readCount += 1 + continue + } + + if valueNull { + let key = try readMapKey( + context, + declared: keyDeclared, + trackRef: trackKeyRef + ) + map[key] = ValueCodec.defaultValue + readCount += 1 + continue + } + + let chunkSize = Int(try context.buffer.readUInt8()) + if chunkSize > (totalLength - readCount) { + throw ForyError.invalidData("map chunk size exceeds remaining entries") + } + let keyTypeInfo = keyDeclared ? nil : try KeyCodec.readTypeInfo(context) + let valueTypeInfo = valueDeclared ? nil : try ValueCodec.readTypeInfo(context) + for _ in 0...Element, + _ context: WriteContext, + options: MapEntryWriteOptions + ) throws { + var header: UInt8 = 0 + if options.trackKeyRef { header |= MapHeader.trackingKeyRef } + if options.trackValueRef { header |= MapHeader.trackingValueRef } + if options.keyIsNil { + header |= MapHeader.keyNull + } else if !options.keyDynamicType && options.keyDeclared { + header |= MapHeader.declaredKeyType + } + if options.valueIsNil { + header |= MapHeader.valueNull + } else if !options.valueDynamicType && options.valueDeclared { + header |= MapHeader.declaredValueType + } + context.buffer.writeUInt8(header) + + if !options.keyIsNil { + if !options.keyDeclared { + try KeyCodec.writeStaticTypeInfo(context) + } + try KeyCodec.write( + pair.key, + context, + refMode: options.trackKeyRef ? .tracking : .none, + writeTypeInfo: false + ) + } + if !options.valueIsNil { + if !options.valueDeclared { + try ValueCodec.writeStaticTypeInfo(context) + } + try ValueCodec.write( + pair.value, + context, + refMode: options.trackValueRef ? .tracking : .none, + writeTypeInfo: false + ) + } + } + + private static func writeMapPayload( + _ pair: Dictionary.Element, + _ context: WriteContext, + trackKeyRef: Bool, + trackValueRef: Bool + ) throws { + try KeyCodec.write( + pair.key, context, refMode: trackKeyRef ? .tracking : .none, - readTypeInfo: false - ) - } - let value = try ValueCodec.withTypeInfo(valueTypeInfo, context) { - try ValueCodec.read( + writeTypeInfo: false + ) + try ValueCodec.write( + pair.value, context, refMode: trackValueRef ? .tracking : .none, - readTypeInfo: false - ) - } - map[key] = value - } - readCount += chunkSize + writeTypeInfo: false + ) } - return map - } - private static func writeMapEntry( - _ pair: Dictionary.Element, - _ context: WriteContext, - options: MapEntryWriteOptions - ) throws { - var header: UInt8 = 0 - if options.trackKeyRef { header |= MapHeader.trackingKeyRef } - if options.trackValueRef { header |= MapHeader.trackingValueRef } - if options.keyIsNil { - header |= MapHeader.keyNull - } else if !options.keyDynamicType && options.keyDeclared { - header |= MapHeader.declaredKeyType - } - if options.valueIsNil { - header |= MapHeader.valueNull - } else if !options.valueDynamicType && options.valueDeclared { - header |= MapHeader.declaredValueType - } - context.buffer.writeUInt8(header) - - if !options.keyIsNil { - if !options.keyDeclared { - try KeyCodec.writeStaticTypeInfo(context) - } - try KeyCodec.write( - pair.key, - context, - refMode: options.trackKeyRef ? .tracking : .none, - writeTypeInfo: false - ) - } - if !options.valueIsNil { - if !options.valueDeclared { - try ValueCodec.writeStaticTypeInfo(context) - } - try ValueCodec.write( - pair.value, - context, - refMode: options.trackValueRef ? .tracking : .none, - writeTypeInfo: false - ) - } - } - - private static func writeMapPayload( - _ pair: Dictionary.Element, - _ context: WriteContext, - trackKeyRef: Bool, - trackValueRef: Bool - ) throws { - try KeyCodec.write( - pair.key, - context, - refMode: trackKeyRef ? .tracking : .none, - writeTypeInfo: false - ) - try ValueCodec.write( - pair.value, - context, - refMode: trackValueRef ? .tracking : .none, - writeTypeInfo: false - ) - } - - private static func readMapKey( - _ context: ReadContext, - declared: Bool, - trackRef: Bool - ) throws -> KeyCodec.Value { - let typeInfo = declared ? nil : try KeyCodec.readTypeInfo(context) - return try KeyCodec.withTypeInfo(typeInfo, context) { - try KeyCodec.read(context, refMode: trackRef ? .tracking : .none, readTypeInfo: false) + private static func readMapKey( + _ context: ReadContext, + declared: Bool, + trackRef: Bool + ) throws -> KeyCodec.Value { + let typeInfo = declared ? nil : try KeyCodec.readTypeInfo(context) + return try KeyCodec.withTypeInfo(typeInfo, context) { + try KeyCodec.read(context, refMode: trackRef ? .tracking : .none, readTypeInfo: false) + } } - } - private static func readMapValue( - _ context: ReadContext, - declared: Bool, - trackRef: Bool - ) throws -> ValueCodec.Value { - let typeInfo = declared ? nil : try ValueCodec.readTypeInfo(context) - return try ValueCodec.withTypeInfo(typeInfo, context) { - try ValueCodec.read(context, refMode: trackRef ? .tracking : .none, readTypeInfo: false) - } - } + private static func readMapValue( + _ context: ReadContext, + declared: Bool, + trackRef: Bool + ) throws -> ValueCodec.Value { + let typeInfo = declared ? nil : try ValueCodec.readTypeInfo(context) + return try ValueCodec.withTypeInfo(typeInfo, context) { + try ValueCodec.read(context, refMode: trackRef ? .tracking : .none, readTypeInfo: false) + } + } } @inline(__always) private func uncheckedPackedArrayCast(_ array: [From], to _: To.Type) -> [To] { - assert(From.self == To.self) - return unsafeBitCast(array, to: [To].self) + assert(From.self == To.self) + return unsafeBitCast(array, to: [To].self) } @inline(__always) private func uncheckedScalarCast(_ value: From, to _: To.Type) -> To { - assert(From.self == To.self) - return unsafeBitCast(value, to: To.self) + assert(From.self == To.self) + return unsafeBitCast(value, to: To.self) } private func packedArrayTypeID(for _: ElementCodec.Type) -> TypeId? { - if ElementCodec.isNullableType { + if ElementCodec.isNullableType { + return nil + } + if ElementCodec.self == BoolCodec.self { + return .boolArray + } + if ElementCodec.self == Int8Codec.self { + return .int8Array + } + if ElementCodec.self == Int16Codec.self { + return .int16Array + } + if ElementCodec.self == Int32FixedCodec.self { + return .int32Array + } + if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == IntFixedCodec.self { + return .int64Array + } + if ElementCodec.self == UInt8Codec.self { + return .uint8Array + } + if ElementCodec.self == UInt16Codec.self { + return .uint16Array + } + if ElementCodec.self == UInt32FixedCodec.self { + return .uint32Array + } + if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UIntFixedCodec.self { + return .uint64Array + } + if ElementCodec.self == Float16Codec.self { + return .float16Array + } + if ElementCodec.self == BFloat16Codec.self { + return .bfloat16Array + } + if ElementCodec.self == FloatCodec.self { + return .float32Array + } + if ElementCodec.self == DoubleCodec.self { + return .float64Array + } return nil - } - if ElementCodec.self == BoolCodec.self { - return .boolArray - } - if ElementCodec.self == Int8Codec.self { - return .int8Array - } - if ElementCodec.self == Int16Codec.self { - return .int16Array - } - if ElementCodec.self == Int32FixedCodec.self { - return .int32Array - } - if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == IntFixedCodec.self { - return .int64Array - } - if ElementCodec.self == UInt8Codec.self { - return .uint8Array - } - if ElementCodec.self == UInt16Codec.self { - return .uint16Array - } - if ElementCodec.self == UInt32FixedCodec.self { - return .uint32Array - } - if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UIntFixedCodec.self { - return .uint64Array - } - if ElementCodec.self == Float16Codec.self { - return .float16Array - } - if ElementCodec.self == BFloat16Codec.self { - return .bfloat16Array - } - if ElementCodec.self == FloatCodec.self { - return .float32Array - } - if ElementCodec.self == DoubleCodec.self { - return .float64Array - } - return nil } private func isCompatiblePackedArrayTypeID( - _ typeID: UInt32, - elementCodec _: ElementCodec.Type + _ typeID: UInt32, + elementCodec _: ElementCodec.Type ) -> Bool { - TypeId.listElementTypeID(ElementCodec.typeId.rawValue, matchesDenseArrayTypeID: typeID) + TypeId.listElementTypeID(ElementCodec.typeId.rawValue, matchesDenseArrayTypeID: typeID) } private func writePackedArrayPayload( - _ value: [ElementCodec.Value], - _ context: WriteContext, - elementCodec _: ElementCodec.Type + _ value: [ElementCodec.Value], + _ context: WriteContext, + elementCodec _: ElementCodec.Type ) throws -> Bool { - if ElementCodec.self == BoolCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Bool.self), context: context) - return true - } - if ElementCodec.self == Int8Codec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int8.self), context: context) - return true - } - if ElementCodec.self == Int16Codec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int16.self), context: context) - return true - } - if ElementCodec.self == Int32FixedCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int32.self), context: context) - return true - } - if ElementCodec.self == Int64FixedCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int64.self), context: context) - return true - } - if ElementCodec.self == IntFixedCodec.self { - writeIntArrayPayload(uncheckedPackedArrayCast(value, to: Int.self), context) - return true - } - if ElementCodec.self == UInt8Codec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt8.self), context: context) - return true - } - if ElementCodec.self == UInt16Codec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt16.self), context: context) - return true - } - if ElementCodec.self == UInt32FixedCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt32.self), context: context) - return true - } - if ElementCodec.self == UInt64FixedCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt64.self), context: context) - return true - } - if ElementCodec.self == UIntFixedCodec.self { - writeUIntArrayPayload(uncheckedPackedArrayCast(value, to: UInt.self), context) - return true - } - if ElementCodec.self == Float16Codec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Float16.self), context: context) - return true - } - if ElementCodec.self == BFloat16Codec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: BFloat16.self), context: context) - return true - } - if ElementCodec.self == FloatCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Float.self), context: context) - return true - } - if ElementCodec.self == DoubleCodec.self { - writePrimitiveArray(uncheckedPackedArrayCast(value, to: Double.self), context: context) - return true - } - return false + if ElementCodec.self == BoolCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Bool.self), context: context) + return true + } + if ElementCodec.self == Int8Codec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int8.self), context: context) + return true + } + if ElementCodec.self == Int16Codec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int16.self), context: context) + return true + } + if ElementCodec.self == Int32FixedCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int32.self), context: context) + return true + } + if ElementCodec.self == Int64FixedCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Int64.self), context: context) + return true + } + if ElementCodec.self == IntFixedCodec.self { + writeIntArrayPayload(uncheckedPackedArrayCast(value, to: Int.self), context) + return true + } + if ElementCodec.self == UInt8Codec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt8.self), context: context) + return true + } + if ElementCodec.self == UInt16Codec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt16.self), context: context) + return true + } + if ElementCodec.self == UInt32FixedCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt32.self), context: context) + return true + } + if ElementCodec.self == UInt64FixedCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: UInt64.self), context: context) + return true + } + if ElementCodec.self == UIntFixedCodec.self { + writeUIntArrayPayload(uncheckedPackedArrayCast(value, to: UInt.self), context) + return true + } + if ElementCodec.self == Float16Codec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Float16.self), context: context) + return true + } + if ElementCodec.self == BFloat16Codec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: BFloat16.self), context: context) + return true + } + if ElementCodec.self == FloatCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Float.self), context: context) + return true + } + if ElementCodec.self == DoubleCodec.self { + writePrimitiveArray(uncheckedPackedArrayCast(value, to: Double.self), context: context) + return true + } + return false } private func readPackedArrayPayload( - _ context: ReadContext, - elementCodec _: ElementCodec.Type + _ context: ReadContext, + elementCodec _: ElementCodec.Type ) throws -> [ElementCodec.Value]? { - if ElementCodec.self == BoolCodec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context) as [Bool], to: ElementCodec.Value.self) - } - if ElementCodec.self == Int8Codec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context) as [Int8], to: ElementCodec.Value.self) - } - if ElementCodec.self == Int16Codec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context) as [Int16], to: ElementCodec.Value.self) - } - if ElementCodec.self == Int32FixedCodec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context) as [Int32], to: ElementCodec.Value.self) - } - if ElementCodec.self == Int64FixedCodec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context) as [Int64], to: ElementCodec.Value.self) - } - if ElementCodec.self == IntFixedCodec.self { - return uncheckedPackedArrayCast(try readIntArrayPayload(context), to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt8Codec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context) as [UInt8], to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt16Codec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context) as [UInt16], to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt32FixedCodec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context) as [UInt32], to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt64FixedCodec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context) as [UInt64], to: ElementCodec.Value.self) - } - if ElementCodec.self == UIntFixedCodec.self { - return uncheckedPackedArrayCast(try readUIntArrayPayload(context), to: ElementCodec.Value.self) - } - if ElementCodec.self == Float16Codec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context) as [Float16], to: ElementCodec.Value.self) - } - if ElementCodec.self == BFloat16Codec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context) as [BFloat16], to: ElementCodec.Value.self) - } - if ElementCodec.self == FloatCodec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context) as [Float], to: ElementCodec.Value.self) - } - if ElementCodec.self == DoubleCodec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context) as [Double], to: ElementCodec.Value.self) - } - return nil + if ElementCodec.self == BoolCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Bool], to: ElementCodec.Value.self) + } + if ElementCodec.self == Int8Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Int8], to: ElementCodec.Value.self) + } + if ElementCodec.self == Int16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Int16], to: ElementCodec.Value.self) + } + if ElementCodec.self == Int32FixedCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Int32], to: ElementCodec.Value.self) + } + if ElementCodec.self == Int64FixedCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Int64], to: ElementCodec.Value.self) + } + if ElementCodec.self == IntFixedCodec.self { + return uncheckedPackedArrayCast(try readIntArrayPayload(context), to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt8Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [UInt8], to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [UInt16], to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt32FixedCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [UInt32], to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt64FixedCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [UInt64], to: ElementCodec.Value.self) + } + if ElementCodec.self == UIntFixedCodec.self { + return uncheckedPackedArrayCast(try readUIntArrayPayload(context), to: ElementCodec.Value.self) + } + if ElementCodec.self == Float16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Float16], to: ElementCodec.Value.self) + } + if ElementCodec.self == BFloat16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [BFloat16], to: ElementCodec.Value.self) + } + if ElementCodec.self == FloatCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Float], to: ElementCodec.Value.self) + } + if ElementCodec.self == DoubleCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context) as [Double], to: ElementCodec.Value.self) + } + return nil } private func writeIntArrayPayload(_ value: [Int], _ context: WriteContext) { - context.buffer.writeVarUInt32(UInt32(value.count * 8)) - for item in value { - context.buffer.writeInt64(Int64(item)) - } + context.buffer.writeVarUInt32(UInt32(value.count * 8)) + for item in value { + context.buffer.writeInt64(Int64(item)) + } } private func writeUIntArrayPayload(_ value: [UInt], _ context: WriteContext) { - context.buffer.writeVarUInt32(UInt32(value.count * 8)) - for item in value { - context.buffer.writeUInt64(UInt64(item)) - } + context.buffer.writeVarUInt32(UInt32(value.count * 8)) + for item in value { + context.buffer.writeUInt64(UInt64(item)) + } } -private func readIntArrayPayload(_ context: ReadContext, reserveGraphStorage: Bool = false) throws - -> [Int] +private func readIntArrayPayload( + _ context: ReadContext, reserveGraphStorage: Bool = false +) throws + -> [Int] { - let count = try readPackedArrayElementCount(context, width: 8, label: "int64_array") - if reserveGraphStorage { - try reserveSerializerArrayMemory(context, Int.self, count: count) - } - var values: [Int] = [] - values.reserveCapacity(count) - for _ in 0.. [UInt] +private func readUIntArrayPayload( + _ context: ReadContext, reserveGraphStorage: Bool = false +) throws + -> [UInt] { - let count = try readPackedArrayElementCount(context, width: 8, label: "uint64_array") - if reserveGraphStorage { - try reserveSerializerArrayMemory(context, UInt.self, count: count) - } - var values: [UInt] = [] - values.reserveCapacity(count) - for _ in 0..( - _ context: ReadContext, - refMode: RefMode, - elementCodec _: ElementCodec.Type + _ context: ReadContext, + refMode: RefMode, + elementCodec _: ElementCodec.Type ) throws -> [ElementCodec.Value] { - switch refMode { - case .none: - return try readCompatiblePackedArrayPayload(context, elementCodec: ElementCodec.self) - case .nullOnly, .tracking: - let rawFlag = try context.buffer.readInt8() - guard rawFlag != RefFlag.null.rawValue else { - return [] - } - if rawFlag == RefFlag.ref.rawValue { - let refID = try context.buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: [ElementCodec.Value].self) - } - let reservedRefID = - (rawFlag == RefFlag.refValue.rawValue && context.trackRef) - ? context.refReader.reserveRefID() - : nil - let value = try readCompatiblePackedArrayPayload(context, elementCodec: ElementCodec.self) - if let reservedRefID { - context.refReader.storeRef(value, at: reservedRefID) - } - return value - } + switch refMode { + case .none: + return try readCompatiblePackedArrayPayload(context, elementCodec: ElementCodec.self) + case .nullOnly, .tracking: + let rawFlag = try context.buffer.readInt8() + guard rawFlag != RefFlag.null.rawValue else { + return [] + } + if rawFlag == RefFlag.ref.rawValue { + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: [ElementCodec.Value].self) + } + let reservedRefID = + (rawFlag == RefFlag.refValue.rawValue && context.trackRef) + ? context.refReader.reserveRefID() + : nil + let value = try readCompatiblePackedArrayPayload(context, elementCodec: ElementCodec.self) + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } + return value + } } private func readCompatiblePackedArrayPayload( - _ context: ReadContext, - elementCodec _: ElementCodec.Type + _ context: ReadContext, + elementCodec _: ElementCodec.Type ) throws -> [ElementCodec.Value] { - if ElementCodec.self == BoolCodec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context, reserveGraphStorage: true) as [Bool], - to: ElementCodec.Value.self) - } - if ElementCodec.self == Int8Codec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context, reserveGraphStorage: true) as [Int8], - to: ElementCodec.Value.self) - } - if ElementCodec.self == Int16Codec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context, reserveGraphStorage: true) as [Int16], - to: ElementCodec.Value.self) - } - if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context, reserveGraphStorage: true) as [Int32], - to: ElementCodec.Value.self) - } - if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self - || ElementCodec.self == Int64TaggedCodec.self - { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context, reserveGraphStorage: true) as [Int64], - to: ElementCodec.Value.self) - } - if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self - || ElementCodec.self == IntTaggedCodec.self - { - return uncheckedPackedArrayCast( - try readIntArrayPayload(context, reserveGraphStorage: true), to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt8Codec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt8], - to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt16Codec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt16], - to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt32], - to: ElementCodec.Value.self) - } - if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self - || ElementCodec.self == UInt64TaggedCodec.self - { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt64], - to: ElementCodec.Value.self) - } - if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self - || ElementCodec.self == UIntTaggedCodec.self - { - return uncheckedPackedArrayCast( - try readUIntArrayPayload(context, reserveGraphStorage: true), to: ElementCodec.Value.self) - } - if ElementCodec.self == Float16Codec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context, reserveGraphStorage: true) as [Float16], - to: ElementCodec.Value.self) - } - if ElementCodec.self == BFloat16Codec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context, reserveGraphStorage: true) as [BFloat16], - to: ElementCodec.Value.self) - } - if ElementCodec.self == FloatCodec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context, reserveGraphStorage: true) as [Float], - to: ElementCodec.Value.self) - } - if ElementCodec.self == DoubleCodec.self { - return uncheckedPackedArrayCast( - try readPrimitiveArray(context, reserveGraphStorage: true) as [Double], - to: ElementCodec.Value.self) - } - throw ForyError.invalidData( - "unsupported compatible array-to-list field element codec \(ElementCodec.self)") + if ElementCodec.self == BoolCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Bool], + to: ElementCodec.Value.self) + } + if ElementCodec.self == Int8Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Int8], + to: ElementCodec.Value.self) + } + if ElementCodec.self == Int16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Int16], + to: ElementCodec.Value.self) + } + if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Int32], + to: ElementCodec.Value.self) + } + if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self + || ElementCodec.self == Int64TaggedCodec.self + { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Int64], + to: ElementCodec.Value.self) + } + if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self + || ElementCodec.self == IntTaggedCodec.self + { + return uncheckedPackedArrayCast( + try readIntArrayPayload(context, reserveGraphStorage: true), to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt8Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt8], + to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt16], + to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt32], + to: ElementCodec.Value.self) + } + if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self + || ElementCodec.self == UInt64TaggedCodec.self + { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [UInt64], + to: ElementCodec.Value.self) + } + if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self + || ElementCodec.self == UIntTaggedCodec.self + { + return uncheckedPackedArrayCast( + try readUIntArrayPayload(context, reserveGraphStorage: true), to: ElementCodec.Value.self) + } + if ElementCodec.self == Float16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Float16], + to: ElementCodec.Value.self) + } + if ElementCodec.self == BFloat16Codec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [BFloat16], + to: ElementCodec.Value.self) + } + if ElementCodec.self == FloatCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Float], + to: ElementCodec.Value.self) + } + if ElementCodec.self == DoubleCodec.self { + return uncheckedPackedArrayCast( + try readPrimitiveArray(context, reserveGraphStorage: true) as [Double], + to: ElementCodec.Value.self) + } + throw ForyError.invalidData( + "unsupported compatible array-to-list field element codec \(ElementCodec.self)") } private func readCompatibleElementPayload( - _ context: ReadContext, - elementCodec _: ElementCodec.Type, - remoteElementTypeID: UInt32? + _ context: ReadContext, + elementCodec _: ElementCodec.Type, + remoteElementTypeID: UInt32? ) throws -> ElementCodec.Value { - guard let remoteElementTypeID, - remoteElementTypeID != ElementCodec.typeId.rawValue, - let remoteTypeID = TypeId(rawValue: remoteElementTypeID) - else { - return try ElementCodec.readPayload(context) - } - - if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { - switch remoteTypeID { - case .int32: - return uncheckedScalarCast( - try context.buffer.readInt32() as Int32, to: ElementCodec.Value.self) - case .varint32: - return uncheckedScalarCast( - try context.buffer.readVarInt32() as Int32, to: ElementCodec.Value.self) - default: - break - } - } - if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self - || ElementCodec.self == Int64TaggedCodec.self - { - switch remoteTypeID { - case .int64: - return uncheckedScalarCast( - try context.buffer.readInt64() as Int64, to: ElementCodec.Value.self) - case .varint64: - return uncheckedScalarCast( - try context.buffer.readVarInt64() as Int64, to: ElementCodec.Value.self) - case .taggedInt64: - return uncheckedScalarCast( - try context.buffer.readTaggedInt64() as Int64, to: ElementCodec.Value.self) - default: - break - } - } - if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self - || ElementCodec.self == IntTaggedCodec.self - { - switch remoteTypeID { - case .int64: - return uncheckedScalarCast(Int(try context.buffer.readInt64()), to: ElementCodec.Value.self) - case .varint64: - return uncheckedScalarCast( - Int(try context.buffer.readVarInt64()), to: ElementCodec.Value.self) - case .taggedInt64: - return uncheckedScalarCast( - Int(try context.buffer.readTaggedInt64()), to: ElementCodec.Value.self) - default: - break - } - } - if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { - switch remoteTypeID { - case .uint32: - return uncheckedScalarCast( - try context.buffer.readUInt32() as UInt32, to: ElementCodec.Value.self) - case .varUInt32: - return uncheckedScalarCast( - try context.buffer.readVarUInt32() as UInt32, to: ElementCodec.Value.self) - default: - break - } - } - if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self - || ElementCodec.self == UInt64TaggedCodec.self - { - switch remoteTypeID { - case .uint64: - return uncheckedScalarCast( - try context.buffer.readUInt64() as UInt64, to: ElementCodec.Value.self) - case .varUInt64: - return uncheckedScalarCast( - try context.buffer.readVarUInt64() as UInt64, to: ElementCodec.Value.self) - case .taggedUInt64: - return uncheckedScalarCast( - try context.buffer.readTaggedUInt64() as UInt64, to: ElementCodec.Value.self) - default: - break - } - } - if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self - || ElementCodec.self == UIntTaggedCodec.self - { - switch remoteTypeID { - case .uint64: - return uncheckedScalarCast(UInt(try context.buffer.readUInt64()), to: ElementCodec.Value.self) - case .varUInt64: - return uncheckedScalarCast( - UInt(try context.buffer.readVarUInt64()), to: ElementCodec.Value.self) - case .taggedUInt64: - return uncheckedScalarCast( - UInt(try context.buffer.readTaggedUInt64()), to: ElementCodec.Value.self) - default: - break - } - } - throw ForyError.typeMismatch(expected: ElementCodec.typeId.rawValue, actual: remoteElementTypeID) + guard let remoteElementTypeID, + remoteElementTypeID != ElementCodec.typeId.rawValue, + let remoteTypeID = TypeId(rawValue: remoteElementTypeID) + else { + return try ElementCodec.readPayload(context) + } + + if ElementCodec.self == Int32FixedCodec.self || ElementCodec.self == Int32VarintCodec.self { + switch remoteTypeID { + case .int32: + return uncheckedScalarCast( + try context.buffer.readInt32() as Int32, to: ElementCodec.Value.self) + case .varint32: + return uncheckedScalarCast( + try context.buffer.readVarInt32() as Int32, to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == Int64FixedCodec.self || ElementCodec.self == Int64VarintCodec.self + || ElementCodec.self == Int64TaggedCodec.self + { + switch remoteTypeID { + case .int64: + return uncheckedScalarCast( + try context.buffer.readInt64() as Int64, to: ElementCodec.Value.self) + case .varint64: + return uncheckedScalarCast( + try context.buffer.readVarInt64() as Int64, to: ElementCodec.Value.self) + case .taggedInt64: + return uncheckedScalarCast( + try context.buffer.readTaggedInt64() as Int64, to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == IntFixedCodec.self || ElementCodec.self == IntVarintCodec.self + || ElementCodec.self == IntTaggedCodec.self + { + switch remoteTypeID { + case .int64: + return uncheckedScalarCast(Int(try context.buffer.readInt64()), to: ElementCodec.Value.self) + case .varint64: + return uncheckedScalarCast( + Int(try context.buffer.readVarInt64()), to: ElementCodec.Value.self) + case .taggedInt64: + return uncheckedScalarCast( + Int(try context.buffer.readTaggedInt64()), to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == UInt32FixedCodec.self || ElementCodec.self == UInt32VarintCodec.self { + switch remoteTypeID { + case .uint32: + return uncheckedScalarCast( + try context.buffer.readUInt32() as UInt32, to: ElementCodec.Value.self) + case .varUInt32: + return uncheckedScalarCast( + try context.buffer.readVarUInt32() as UInt32, to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == UInt64FixedCodec.self || ElementCodec.self == UInt64VarintCodec.self + || ElementCodec.self == UInt64TaggedCodec.self + { + switch remoteTypeID { + case .uint64: + return uncheckedScalarCast( + try context.buffer.readUInt64() as UInt64, to: ElementCodec.Value.self) + case .varUInt64: + return uncheckedScalarCast( + try context.buffer.readVarUInt64() as UInt64, to: ElementCodec.Value.self) + case .taggedUInt64: + return uncheckedScalarCast( + try context.buffer.readTaggedUInt64() as UInt64, to: ElementCodec.Value.self) + default: + break + } + } + if ElementCodec.self == UIntFixedCodec.self || ElementCodec.self == UIntVarintCodec.self + || ElementCodec.self == UIntTaggedCodec.self + { + switch remoteTypeID { + case .uint64: + return uncheckedScalarCast(UInt(try context.buffer.readUInt64()), to: ElementCodec.Value.self) + case .varUInt64: + return uncheckedScalarCast( + UInt(try context.buffer.readVarUInt64()), to: ElementCodec.Value.self) + case .taggedUInt64: + return uncheckedScalarCast( + UInt(try context.buffer.readTaggedUInt64()), to: ElementCodec.Value.self) + default: + break + } + } + throw ForyError.typeMismatch(expected: ElementCodec.typeId.rawValue, actual: remoteElementTypeID) } private func readPackedArrayElementCount( - _ context: ReadContext, - width: Int, - label: String + _ context: ReadContext, + width: Int, + label: String ) throws -> Int { - let byteSize = Int(try context.buffer.readVarUInt32()) - try context.ensureRemainingBytes(byteSize, label: "primitive_array_bytes") - if byteSize % width != 0 { - throw ForyError.invalidData("\(label) byte size mismatch") - } - let count = byteSize / width - try context.ensureCollectionLength(count, label: label) - return count + let byteSize = Int(try context.buffer.readVarUInt32()) + try context.ensureRemainingBytes(byteSize, label: "primitive_array_bytes") + if byteSize % width != 0 { + throw ForyError.invalidData("\(label) byte size mismatch") + } + let count = byteSize / width + try context.ensureCollectionLength(count, label: label) + return count } private func writeCollectionPayload( - _ value: [ElementCodec.Value], - _ context: WriteContext, - elementCodec _: ElementCodec.Type + _ value: [ElementCodec.Value], + _ context: WriteContext, + elementCodec _: ElementCodec.Type ) throws { - let buffer = context.buffer - buffer.writeVarUInt32(UInt32(value.count)) - if value.isEmpty { - return - } - - let hasNull = ElementCodec.isNullableType && value.contains(where: ElementCodec.isNone) - let trackRef = context.trackRef && ElementCodec.isRefType - let declaredElementType = !TypeId.needsTypeInfoForField(ElementCodec.typeId) - let dynamicElementType = ElementCodec.typeId == .unknown - - var header: UInt8 = dynamicElementType ? 0 : CollectionHeader.sameType - if trackRef { - header |= CollectionHeader.trackingRef - } - if hasNull { - header |= CollectionHeader.hasNull - } - if declaredElementType { - header |= CollectionHeader.declaredElementType - } - - buffer.writeUInt8(header) - if !dynamicElementType && !declaredElementType { - try ElementCodec.writeStaticTypeInfo(context) - } - - if dynamicElementType { - let refMode = RefMode.from(nullable: hasNull, trackRef: trackRef) - for element in value { - try ElementCodec.write(element, context, refMode: refMode, writeTypeInfo: true) - } - return - } - - if trackRef { - for element in value { - try ElementCodec.write(element, context, refMode: .tracking, writeTypeInfo: false) - } - } else if hasNull { - for element in value { - if ElementCodec.isNone(element) { - buffer.writeInt8(RefFlag.null.rawValue) - } else { - buffer.writeInt8(RefFlag.notNullValue.rawValue) - try ElementCodec.writePayload(element, context) - } - } - } else { - for element in value { - try ElementCodec.writePayload(element, context) - } - } + let buffer = context.buffer + buffer.writeVarUInt32(UInt32(value.count)) + if value.isEmpty { + return + } + + let hasNull = ElementCodec.isNullableType && value.contains(where: ElementCodec.isNone) + let trackRef = context.trackRef && ElementCodec.isRefType + let declaredElementType = !TypeId.needsTypeInfoForField(ElementCodec.typeId) + let dynamicElementType = ElementCodec.typeId == .unknown + + var header: UInt8 = dynamicElementType ? 0 : CollectionHeader.sameType + if trackRef { + header |= CollectionHeader.trackingRef + } + if hasNull { + header |= CollectionHeader.hasNull + } + if declaredElementType { + header |= CollectionHeader.declaredElementType + } + + buffer.writeUInt8(header) + if !dynamicElementType && !declaredElementType { + try ElementCodec.writeStaticTypeInfo(context) + } + + if dynamicElementType { + let refMode = RefMode.from(nullable: hasNull, trackRef: trackRef) + for element in value { + try ElementCodec.write(element, context, refMode: refMode, writeTypeInfo: true) + } + return + } + + if trackRef { + for element in value { + try ElementCodec.write(element, context, refMode: .tracking, writeTypeInfo: false) + } + } else if hasNull { + for element in value { + if ElementCodec.isNone(element) { + buffer.writeInt8(RefFlag.null.rawValue) + } else { + buffer.writeInt8(RefFlag.notNullValue.rawValue) + try ElementCodec.writePayload(element, context) + } + } + } else { + for element in value { + try ElementCodec.writePayload(element, context) + } + } } private func readCollectionPayload( - _ context: ReadContext, - elementCodec _: ElementCodec.Type + _ context: ReadContext, + elementCodec _: ElementCodec.Type ) throws -> [ElementCodec.Value] { - let buffer = context.buffer - let length = Int(try buffer.readVarUInt32()) - try context.ensureCollectionLength(length, label: "array") - if length == 0 { + let buffer = context.buffer + let length = Int(try buffer.readVarUInt32()) + try context.ensureCollectionLength(length, label: "array") + if length == 0 { + try reserveFieldArrayStorage(context, ElementCodec.self, count: length) + return [] + } + + let header = try buffer.readUInt8() + // IMPORTANT: collection readers must obey the ref/null bits written on the + // wire, not the local Swift element metadata that may imply a different + // ref policy. Shared xlang tests intentionally deserialize one ref policy + // and then serialize another local payload. DO NOT REMOVE this comment. + let trackRef = (header & CollectionHeader.trackingRef) != 0 + let hasNull = (header & CollectionHeader.hasNull) != 0 + let declared = (header & CollectionHeader.declaredElementType) != 0 + let sameType = (header & CollectionHeader.sameType) != 0 + + var result: [ElementCodec.Value] = [] try reserveFieldArrayStorage(context, ElementCodec.self, count: length) - return [] - } - - let header = try buffer.readUInt8() - // IMPORTANT: collection readers must obey the ref/null bits written on the - // wire, not the local Swift element metadata that may imply a different - // ref policy. Shared xlang tests intentionally deserialize one ref policy - // and then serialize another local payload. DO NOT REMOVE this comment. - let trackRef = (header & CollectionHeader.trackingRef) != 0 - let hasNull = (header & CollectionHeader.hasNull) != 0 - let declared = (header & CollectionHeader.declaredElementType) != 0 - let sameType = (header & CollectionHeader.sameType) != 0 - - var result: [ElementCodec.Value] = [] - try reserveFieldArrayStorage(context, ElementCodec.self, count: length) - try context.ensureRemainingBytes(length, label: "array") - result.reserveCapacity(length) - - if !sameType { - let refMode = RefMode.from(nullable: hasNull, trackRef: trackRef) - for _ in 0..( - _ context: ReadContext, - refMode: RefMode, - elementCodec _: ElementCodec.Type, - remoteElementTypeID: UInt32 + _ context: ReadContext, + refMode: RefMode, + elementCodec _: ElementCodec.Type, + remoteElementTypeID: UInt32 ) throws -> [ElementCodec.Value] { - switch refMode { - case .none: - return try readListPayloadAsArrayPayload( - context, - elementCodec: ElementCodec.self, - remoteElementTypeID: remoteElementTypeID - ) - case .nullOnly, .tracking: - let rawFlag = try context.buffer.readInt8() - guard rawFlag != RefFlag.null.rawValue else { - return [] - } - if rawFlag == RefFlag.ref.rawValue { - let refID = try context.buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: [ElementCodec.Value].self) - } - let reservedRefID = - (rawFlag == RefFlag.refValue.rawValue && context.trackRef) - ? context.refReader.reserveRefID() - : nil - let value = try readListPayloadAsArrayPayload( - context, - elementCodec: ElementCodec.self, - remoteElementTypeID: remoteElementTypeID - ) - if let reservedRefID { - context.refReader.storeRef(value, at: reservedRefID) + switch refMode { + case .none: + return try readListPayloadAsArrayPayload( + context, + elementCodec: ElementCodec.self, + remoteElementTypeID: remoteElementTypeID + ) + case .nullOnly, .tracking: + let rawFlag = try context.buffer.readInt8() + guard rawFlag != RefFlag.null.rawValue else { + return [] + } + if rawFlag == RefFlag.ref.rawValue { + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: [ElementCodec.Value].self) + } + let reservedRefID = + (rawFlag == RefFlag.refValue.rawValue && context.trackRef) + ? context.refReader.reserveRefID() + : nil + let value = try readListPayloadAsArrayPayload( + context, + elementCodec: ElementCodec.self, + remoteElementTypeID: remoteElementTypeID + ) + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } + return value } - return value - } } private func readListPayloadAsArrayPayload( - _ context: ReadContext, - elementCodec _: ElementCodec.Type, - remoteElementTypeID: UInt32 + _ context: ReadContext, + elementCodec _: ElementCodec.Type, + remoteElementTypeID: UInt32 ) throws -> [ElementCodec.Value] { - let buffer = context.buffer - let length = Int(try buffer.readVarUInt32()) - try context.ensureCollectionLength(length, label: "array") - if length == 0 { + let buffer = context.buffer + let length = Int(try buffer.readVarUInt32()) + try context.ensureCollectionLength(length, label: "array") + if length == 0 { + try reserveFieldArrayStorage(context, ElementCodec.self, count: length) + return [] + } + + let header = try buffer.readUInt8() + let trackRef = (header & CollectionHeader.trackingRef) != 0 + let hasNull = (header & CollectionHeader.hasNull) != 0 + if hasNull { + throw ForyError.invalidData("compatible list-to-array field cannot read nullable elements") + } + let declared = (header & CollectionHeader.declaredElementType) != 0 + let sameType = (header & CollectionHeader.sameType) != 0 + + if !sameType { + throw ForyError.invalidData("compatible list-to-array field requires same-type elements") + } + + if trackRef { + throw ForyError.invalidData("compatible list-to-array field cannot read ref-tracked elements") + } + let elementTypeInfo: TypeInfo? + if declared { + elementTypeInfo = nil + } else { + throw ForyError.invalidData("compatible list-to-array field requires declared elements") + } + try context.ensureRemainingBytes(length, label: "array") + var result: [ElementCodec.Value] = [] try reserveFieldArrayStorage(context, ElementCodec.self, count: length) - return [] - } - - let header = try buffer.readUInt8() - let trackRef = (header & CollectionHeader.trackingRef) != 0 - let hasNull = (header & CollectionHeader.hasNull) != 0 - if hasNull { - throw ForyError.invalidData("compatible list-to-array field cannot read nullable elements") - } - let declared = (header & CollectionHeader.declaredElementType) != 0 - let sameType = (header & CollectionHeader.sameType) != 0 - - if !sameType { - throw ForyError.invalidData("compatible list-to-array field requires same-type elements") - } - - if trackRef { - throw ForyError.invalidData("compatible list-to-array field cannot read ref-tracked elements") - } - let elementTypeInfo: TypeInfo? - if declared { - elementTypeInfo = nil - } else { - throw ForyError.invalidData("compatible list-to-array field requires declared elements") - } - try context.ensureRemainingBytes(length, label: "array") - var result: [ElementCodec.Value] = [] - try reserveFieldArrayStorage(context, ElementCodec.self, count: length) - result.reserveCapacity(length) - return try ElementCodec.withTypeInfo(elementTypeInfo, context) { - for _ in 0.. Bool { - guard let resolved = TypeId(rawValue: typeID) else { - return true - } - return TypeId.needsTypeInfoForField(resolved) - } - - private func readSkippedFieldValue( - fieldType: TypeMeta.FieldType, - typeInfo: TypeInfo? = nil, - readTypeInfo: Bool - ) throws -> Any? { - let refMode = RefMode.from(nullable: fieldType.nullable, trackRef: fieldType.trackRef) - return try readSkippedValue( - fieldType: fieldType, - typeInfo: typeInfo, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - } - - private func readSkippedValue( - fieldType: TypeMeta.FieldType, - typeInfo: TypeInfo?, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> Any? { - switch refMode { - case .none: - return try readSkippedFieldPayload( - fieldType: fieldType, - typeInfo: typeInfo, - readTypeInfo: readTypeInfo - ) - case .nullOnly: - let flag = try buffer.readInt8() - if flag == RefFlag.null.rawValue { - return nil - } - guard flag == RefFlag.notNullValue.rawValue else { - throw ForyError.invalidData("unexpected nullOnly flag \(flag)") - } - return try readSkippedFieldPayload( - fieldType: fieldType, - typeInfo: typeInfo, - readTypeInfo: readTypeInfo - ) - case .tracking: - let rawFlag = try buffer.readInt8() - guard let flag = RefFlag(rawValue: rawFlag) else { - throw ForyError.invalidData("unexpected tracking flag \(rawFlag)") - } - - switch flag { - case .null: - return nil - case .ref: - let refID = try buffer.readVarUInt32() - return try refReader.readRefValue(refID) - case .refValue: - let refID = refReader.reserveRefID() - let value = try readSkippedFieldPayload( - fieldType: fieldType, - typeInfo: typeInfo, - readTypeInfo: readTypeInfo + public func skipFieldValue(_ fieldType: TypeMeta.FieldType) throws { + _ = try readSkippedFieldValue( + fieldType: fieldType, + readTypeInfo: needsTypeInfoForSkippedField(fieldType.typeID) ) - refReader.storeRef(value, at: refID) - return value - case .notNullValue: - return try readSkippedFieldPayload( - fieldType: fieldType, - typeInfo: typeInfo, - readTypeInfo: readTypeInfo - ) - } } - } - private func readSkippedFieldPayload( - fieldType: TypeMeta.FieldType, - typeInfo: TypeInfo?, - readTypeInfo: Bool - ) throws -> Any { - if let typeInfo { - return try readAnyValue(typeInfo: typeInfo) - } - if readTypeInfo { - let typeInfo = try self.readTypeInfo() - return try readAnyValue(typeInfo: typeInfo) + private func needsTypeInfoForSkippedField(_ typeID: UInt32) -> Bool { + guard let resolved = TypeId(rawValue: typeID) else { + return true + } + return TypeId.needsTypeInfoForField(resolved) } - guard let resolvedTypeID = TypeId(rawValue: fieldType.typeID) else { - throw ForyError.invalidData("unknown compatible field type id \(fieldType.typeID)") + private func readSkippedFieldValue( + fieldType: TypeMeta.FieldType, + typeInfo: TypeInfo? = nil, + readTypeInfo: Bool + ) throws -> Any? { + let refMode = RefMode.from(nullable: fieldType.nullable, trackRef: fieldType.trackRef) + return try readSkippedValue( + fieldType: fieldType, + typeInfo: typeInfo, + refMode: refMode, + readTypeInfo: readTypeInfo + ) } - switch resolvedTypeID { - case .none: - return ForyAnyNullValue() - case .bool: - return try Bool.foryRead(self, refMode: .none, readTypeInfo: false) - case .int8: - return try Int8.foryRead(self, refMode: .none, readTypeInfo: false) - case .int16: - return try Int16.foryRead(self, refMode: .none, readTypeInfo: false) - case .int32: - return try buffer.readInt32() - case .varint32: - return try Int32.foryRead(self, refMode: .none, readTypeInfo: false) - case .int64: - return try buffer.readInt64() - case .varint64: - return try Int64.foryRead(self, refMode: .none, readTypeInfo: false) - case .taggedInt64: - return try buffer.readTaggedInt64() - case .uint8: - return try UInt8.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint16: - return try UInt16.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint32: - return try buffer.readUInt32() - case .varUInt32: - return try UInt32.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint64: - return try buffer.readUInt64() - case .varUInt64: - return try UInt64.foryRead(self, refMode: .none, readTypeInfo: false) - case .taggedUInt64: - return try buffer.readTaggedUInt64() - case .float16: - return try Float16.foryRead(self, refMode: .none, readTypeInfo: false) - case .bfloat16: - return try BFloat16.foryRead(self, refMode: .none, readTypeInfo: false) - case .float32: - return try Float.foryRead(self, refMode: .none, readTypeInfo: false) - case .float64: - return try Double.foryRead(self, refMode: .none, readTypeInfo: false) - case .string: - return try String.foryRead(self, refMode: .none, readTypeInfo: false) - case .duration: - return try Duration.foryRead(self, refMode: .none, readTypeInfo: false) - case .timestamp: - return try Date.foryRead(self, refMode: .none, readTypeInfo: false) - case .date: - return try LocalDate.foryRead(self, refMode: .none, readTypeInfo: false) - case .decimal: - return try Decimal.foryRead(self, refMode: .none, readTypeInfo: false) - case .binary, .uint8Array: - return try Data.foryRead(self, refMode: .none, readTypeInfo: false) - case .boolArray: - return try [Bool].foryRead(self, refMode: .none, readTypeInfo: false) - case .int8Array: - return try [Int8].foryRead(self, refMode: .none, readTypeInfo: false) - case .int16Array: - return try [Int16].foryRead(self, refMode: .none, readTypeInfo: false) - case .int32Array: - return try [Int32].foryRead(self, refMode: .none, readTypeInfo: false) - case .int64Array: - return try [Int64].foryRead(self, refMode: .none, readTypeInfo: false) - case .uint16Array: - return try [UInt16].foryRead(self, refMode: .none, readTypeInfo: false) - case .uint32Array: - return try [UInt32].foryRead(self, refMode: .none, readTypeInfo: false) - case .uint64Array: - return try [UInt64].foryRead(self, refMode: .none, readTypeInfo: false) - case .float16Array: - return try [Float16].foryRead(self, refMode: .none, readTypeInfo: false) - case .bfloat16Array: - return try [BFloat16].foryRead(self, refMode: .none, readTypeInfo: false) - case .float32Array: - return try [Float].foryRead(self, refMode: .none, readTypeInfo: false) - case .float64Array: - return try [Double].foryRead(self, refMode: .none, readTypeInfo: false) - case .array, .list: - return try readSkippedCollection(fieldType: fieldType) - case .set: - return try readSkippedSet(fieldType: fieldType) - case .map: - return try readSkippedMap(fieldType: fieldType) - case .union, .typedUnion, .namedUnion: - return try readSkippedUnion() - case .enumType, .namedEnum: - return try buffer.readVarUInt32() - default: - throw ForyError.invalidData("unsupported compatible field type id \(fieldType.typeID)") - } - } + private func readSkippedValue( + fieldType: TypeMeta.FieldType, + typeInfo: TypeInfo?, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> Any? { + switch refMode { + case .none: + return try readSkippedFieldPayload( + fieldType: fieldType, + typeInfo: typeInfo, + readTypeInfo: readTypeInfo + ) + case .nullOnly: + let flag = try buffer.readInt8() + if flag == RefFlag.null.rawValue { + return nil + } + guard flag == RefFlag.notNullValue.rawValue else { + throw ForyError.invalidData("unexpected nullOnly flag \(flag)") + } + return try readSkippedFieldPayload( + fieldType: fieldType, + typeInfo: typeInfo, + readTypeInfo: readTypeInfo + ) + case .tracking: + let rawFlag = try buffer.readInt8() + guard let flag = RefFlag(rawValue: rawFlag) else { + throw ForyError.invalidData("unexpected tracking flag \(rawFlag)") + } - private func readSkippedCollection( - fieldType: TypeMeta.FieldType - ) throws -> [Any] { - let elementFieldType = - fieldType.generics.first - ?? TypeMeta.FieldType(typeID: TypeId.unknown.rawValue, nullable: true) - let length = Int(try buffer.readVarUInt32()) - try ensureCollectionLength(length, label: "compatible_collection") - if length == 0 { - return [] + switch flag { + case .null: + return nil + case .ref: + let refID = try buffer.readVarUInt32() + return try refReader.readRefValue(refID) + case .refValue: + let refID = refReader.reserveRefID() + let value = try readSkippedFieldPayload( + fieldType: fieldType, + typeInfo: typeInfo, + readTypeInfo: readTypeInfo + ) + refReader.storeRef(value, at: refID) + return value + case .notNullValue: + return try readSkippedFieldPayload( + fieldType: fieldType, + typeInfo: typeInfo, + readTypeInfo: readTypeInfo + ) + } + } } - let header = try buffer.readUInt8() - let trackRef = (header & 0b0000_0001) != 0 - let hasNull = (header & 0b0000_0010) != 0 - let declared = (header & 0b0000_0100) != 0 - let sameType = (header & 0b0000_1000) != 0 + private func readSkippedFieldPayload( + fieldType: TypeMeta.FieldType, + typeInfo: TypeInfo?, + readTypeInfo: Bool + ) throws -> Any { + if let typeInfo { + return try readAnyValue(typeInfo: typeInfo) + } + if readTypeInfo { + let typeInfo = try self.readTypeInfo() + return try readAnyValue(typeInfo: typeInfo) + } - var typeInfo: TypeInfo? - if sameType, !declared { - typeInfo = try self.readTypeInfo() + guard let resolvedTypeID = TypeId(rawValue: fieldType.typeID) else { + throw ForyError.invalidData("unknown compatible field type id \(fieldType.typeID)") + } + + switch resolvedTypeID { + case .none: + return ForyAnyNullValue() + case .bool: + return try Bool.foryRead(self, refMode: .none, readTypeInfo: false) + case .int8: + return try Int8.foryRead(self, refMode: .none, readTypeInfo: false) + case .int16: + return try Int16.foryRead(self, refMode: .none, readTypeInfo: false) + case .int32: + return try buffer.readInt32() + case .varint32: + return try Int32.foryRead(self, refMode: .none, readTypeInfo: false) + case .int64: + return try buffer.readInt64() + case .varint64: + return try Int64.foryRead(self, refMode: .none, readTypeInfo: false) + case .taggedInt64: + return try buffer.readTaggedInt64() + case .uint8: + return try UInt8.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint16: + return try UInt16.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint32: + return try buffer.readUInt32() + case .varUInt32: + return try UInt32.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint64: + return try buffer.readUInt64() + case .varUInt64: + return try UInt64.foryRead(self, refMode: .none, readTypeInfo: false) + case .taggedUInt64: + return try buffer.readTaggedUInt64() + case .float16: + return try Float16.foryRead(self, refMode: .none, readTypeInfo: false) + case .bfloat16: + return try BFloat16.foryRead(self, refMode: .none, readTypeInfo: false) + case .float32: + return try Float.foryRead(self, refMode: .none, readTypeInfo: false) + case .float64: + return try Double.foryRead(self, refMode: .none, readTypeInfo: false) + case .string: + return try String.foryRead(self, refMode: .none, readTypeInfo: false) + case .duration: + return try Duration.foryRead(self, refMode: .none, readTypeInfo: false) + case .timestamp: + return try Date.foryRead(self, refMode: .none, readTypeInfo: false) + case .date: + return try LocalDate.foryRead(self, refMode: .none, readTypeInfo: false) + case .decimal: + return try Decimal.foryRead(self, refMode: .none, readTypeInfo: false) + case .binary, .uint8Array: + return try Data.foryRead(self, refMode: .none, readTypeInfo: false) + case .boolArray: + return try [Bool].foryRead(self, refMode: .none, readTypeInfo: false) + case .int8Array: + return try [Int8].foryRead(self, refMode: .none, readTypeInfo: false) + case .int16Array: + return try [Int16].foryRead(self, refMode: .none, readTypeInfo: false) + case .int32Array: + return try [Int32].foryRead(self, refMode: .none, readTypeInfo: false) + case .int64Array: + return try [Int64].foryRead(self, refMode: .none, readTypeInfo: false) + case .uint16Array: + return try [UInt16].foryRead(self, refMode: .none, readTypeInfo: false) + case .uint32Array: + return try [UInt32].foryRead(self, refMode: .none, readTypeInfo: false) + case .uint64Array: + return try [UInt64].foryRead(self, refMode: .none, readTypeInfo: false) + case .float16Array: + return try [Float16].foryRead(self, refMode: .none, readTypeInfo: false) + case .bfloat16Array: + return try [BFloat16].foryRead(self, refMode: .none, readTypeInfo: false) + case .float32Array: + return try [Float].foryRead(self, refMode: .none, readTypeInfo: false) + case .float64Array: + return try [Double].foryRead(self, refMode: .none, readTypeInfo: false) + case .array, .list: + return try readSkippedCollection(fieldType: fieldType) + case .set: + return try readSkippedSet(fieldType: fieldType) + case .map: + return try readSkippedMap(fieldType: fieldType) + case .union, .typedUnion, .namedUnion: + return try readSkippedUnion() + case .enumType, .namedEnum: + return try buffer.readVarUInt32() + default: + throw ForyError.invalidData("unsupported compatible field type id \(fieldType.typeID)") + } } - for _ in 0.. [Any] { + let elementFieldType = + fieldType.generics.first + ?? TypeMeta.FieldType(typeID: TypeId.unknown.rawValue, nullable: true) + let length = Int(try buffer.readVarUInt32()) + try ensureCollectionLength(length, label: "compatible_collection") + if length == 0 { + return [] } - continue - } - if trackRef { - _ = try readSkippedValue( - fieldType: elementFieldType, - typeInfo: nil, - refMode: .tracking, - readTypeInfo: true - ) - } else if hasNull { - let refFlag = try buffer.readInt8() - if refFlag == RefFlag.null.rawValue { - continue + let header = try buffer.readUInt8() + let trackRef = (header & 0b0000_0001) != 0 + let hasNull = (header & 0b0000_0010) != 0 + let declared = (header & 0b0000_0100) != 0 + let sameType = (header & 0b0000_1000) != 0 + + var typeInfo: TypeInfo? + if sameType, !declared { + typeInfo = try self.readTypeInfo() } - if refFlag != RefFlag.notNullValue.rawValue { - throw ForyError.invalidData("invalid collection nullability flag \(refFlag)") + + for _ in 0.. Set { + _ = try readSkippedCollection(fieldType: fieldType) + return [] + } - private func readSkippedSet( - fieldType: TypeMeta.FieldType - ) throws -> Set { - _ = try readSkippedCollection(fieldType: fieldType) - return [] - } + private func readSkippedMap( + fieldType: TypeMeta.FieldType + ) throws -> [AnyHashable: Any] { + let keyType = + fieldType.generics.first + ?? TypeMeta.FieldType(typeID: TypeId.unknown.rawValue, nullable: true) + let valueType = + fieldType.generics.dropFirst().first + ?? TypeMeta.FieldType(typeID: TypeId.unknown.rawValue, nullable: true) - private func readSkippedMap( - fieldType: TypeMeta.FieldType - ) throws -> [AnyHashable: Any] { - let keyType = - fieldType.generics.first - ?? TypeMeta.FieldType(typeID: TypeId.unknown.rawValue, nullable: true) - let valueType = - fieldType.generics.dropFirst().first - ?? TypeMeta.FieldType(typeID: TypeId.unknown.rawValue, nullable: true) + let totalLength = Int(try buffer.readVarUInt32()) + try ensureCollectionLength(totalLength, label: "compatible_map") + if totalLength == 0 { + return [:] + } - let totalLength = Int(try buffer.readVarUInt32()) - try ensureCollectionLength(totalLength, label: "compatible_map") - if totalLength == 0 { - return [:] - } + var readCount = 0 + while readCount < totalLength { + let header = try buffer.readUInt8() + let trackKeyRef = (header & 0b0000_0001) != 0 + let keyNull = (header & 0b0000_0010) != 0 + let keyDeclared = (header & 0b0000_0100) != 0 - var readCount = 0 - while readCount < totalLength { - let header = try buffer.readUInt8() - let trackKeyRef = (header & 0b0000_0001) != 0 - let keyNull = (header & 0b0000_0010) != 0 - let keyDeclared = (header & 0b0000_0100) != 0 + let trackValueRef = (header & 0b0000_1000) != 0 + let valueNull = (header & 0b0001_0000) != 0 + let valueDeclared = (header & 0b0010_0000) != 0 - let trackValueRef = (header & 0b0000_1000) != 0 - let valueNull = (header & 0b0001_0000) != 0 - let valueDeclared = (header & 0b0010_0000) != 0 + if keyNull && valueNull { + readCount += 1 + continue + } - if keyNull && valueNull { - readCount += 1 - continue - } + if keyNull { + let valueTypeInfo = valueDeclared ? nil : try self.readTypeInfo() + _ = try readSkippedValue( + fieldType: valueType, + typeInfo: valueTypeInfo, + refMode: trackValueRef ? .tracking : .none, + readTypeInfo: false + ) + readCount += 1 + continue + } - if keyNull { - let valueTypeInfo = valueDeclared ? nil : try self.readTypeInfo() - _ = try readSkippedValue( - fieldType: valueType, - typeInfo: valueTypeInfo, - refMode: trackValueRef ? .tracking : .none, - readTypeInfo: false - ) - readCount += 1 - continue - } + if valueNull { + let keyTypeInfo = keyDeclared ? nil : try self.readTypeInfo() + _ = try readSkippedValue( + fieldType: keyType, + typeInfo: keyTypeInfo, + refMode: trackKeyRef ? .tracking : .none, + readTypeInfo: false + ) + readCount += 1 + continue + } - if valueNull { - let keyTypeInfo = keyDeclared ? nil : try self.readTypeInfo() - _ = try readSkippedValue( - fieldType: keyType, - typeInfo: keyTypeInfo, - refMode: trackKeyRef ? .tracking : .none, - readTypeInfo: false - ) - readCount += 1 - continue - } + let chunkSize = Int(try buffer.readUInt8()) + if chunkSize <= 0 { + throw ForyError.invalidData("invalid map chunk size \(chunkSize)") + } + if chunkSize > (totalLength - readCount) { + throw ForyError.invalidData("map chunk size exceeds remaining entries") + } - let chunkSize = Int(try buffer.readUInt8()) - if chunkSize <= 0 { - throw ForyError.invalidData("invalid map chunk size \(chunkSize)") - } - if chunkSize > (totalLength - readCount) { - throw ForyError.invalidData("map chunk size exceeds remaining entries") - } + let keyTypeInfo = keyDeclared ? nil : try self.readTypeInfo() + let valueTypeInfo = valueDeclared ? nil : try self.readTypeInfo() - let keyTypeInfo = keyDeclared ? nil : try self.readTypeInfo() - let valueTypeInfo = valueDeclared ? nil : try self.readTypeInfo() + for _ in 0.. Any { - _ = try buffer.readVarUInt32() - return try readAny(context: self, refMode: .tracking, readTypeInfo: true) - ?? ForyAnyNullValue() - } + private func readSkippedUnion() throws -> Any { + _ = try buffer.readVarUInt32() + return try readAny(context: self, refMode: .tracking, readTypeInfo: true) + ?? ForyAnyNullValue() + } } diff --git a/swift/Sources/Fory/Fory.swift b/swift/Sources/Fory/Fory.swift index 1da04359e0..769c70fb96 100644 --- a/swift/Sources/Fory/Fory.swift +++ b/swift/Sources/Fory/Fory.swift @@ -18,45 +18,45 @@ import Foundation public struct Config { - public let trackRef: Bool - public let compatible: Bool - public let checkClassVersion: Bool - public let maxDepth: Int - public let maxGraphMemoryBytes: Int64 - public let maxTypeFields: Int - public let maxTypeMetaBytes: Int - public let maxSchemaVersionsPerType: Int - public let maxAverageSchemaVersionsPerType: Int - - public init( - trackRef: Bool = false, - compatible: Bool? = nil, - checkClassVersion: Bool? = nil, - maxDepth: Int = 5, - maxGraphMemoryBytes: Int64 = 128 * 1024 * 1024, - maxTypeFields: Int = 512, - maxTypeMetaBytes: Int = 4096, - maxSchemaVersionsPerType: Int = 10, - maxAverageSchemaVersionsPerType: Int = 3 - ) { - precondition(maxTypeFields > 0, "maxTypeFields must be positive") - precondition(maxTypeMetaBytes > 0, "maxTypeMetaBytes must be positive") - precondition(maxSchemaVersionsPerType > 0, "maxSchemaVersionsPerType must be positive") - precondition( - maxAverageSchemaVersionsPerType > 0, - "maxAverageSchemaVersionsPerType must be positive") - let effectiveCompatible = compatible ?? true - let effectiveCheckClassVersion = checkClassVersion ?? !effectiveCompatible - self.trackRef = trackRef - self.compatible = effectiveCompatible - self.checkClassVersion = effectiveCheckClassVersion - self.maxDepth = maxDepth - self.maxGraphMemoryBytes = maxGraphMemoryBytes - self.maxTypeFields = maxTypeFields - self.maxTypeMetaBytes = maxTypeMetaBytes - self.maxSchemaVersionsPerType = maxSchemaVersionsPerType - self.maxAverageSchemaVersionsPerType = maxAverageSchemaVersionsPerType - } + public let trackRef: Bool + public let compatible: Bool + public let checkClassVersion: Bool + public let maxDepth: Int + public let maxGraphMemoryBytes: Int64 + public let maxTypeFields: Int + public let maxTypeMetaBytes: Int + public let maxSchemaVersionsPerType: Int + public let maxAverageSchemaVersionsPerType: Int + + public init( + trackRef: Bool = false, + compatible: Bool? = nil, + checkClassVersion: Bool? = nil, + maxDepth: Int = 5, + maxGraphMemoryBytes: Int64 = 128 * 1024 * 1024, + maxTypeFields: Int = 512, + maxTypeMetaBytes: Int = 4096, + maxSchemaVersionsPerType: Int = 10, + maxAverageSchemaVersionsPerType: Int = 3 + ) { + precondition(maxTypeFields > 0, "maxTypeFields must be positive") + precondition(maxTypeMetaBytes > 0, "maxTypeMetaBytes must be positive") + precondition(maxSchemaVersionsPerType > 0, "maxSchemaVersionsPerType must be positive") + precondition( + maxAverageSchemaVersionsPerType > 0, + "maxAverageSchemaVersionsPerType must be positive") + let effectiveCompatible = compatible ?? true + let effectiveCheckClassVersion = checkClassVersion ?? !effectiveCompatible + self.trackRef = trackRef + self.compatible = effectiveCompatible + self.checkClassVersion = effectiveCheckClassVersion + self.maxDepth = maxDepth + self.maxGraphMemoryBytes = maxGraphMemoryBytes + self.maxTypeFields = maxTypeFields + self.maxTypeMetaBytes = maxTypeMetaBytes + self.maxSchemaVersionsPerType = maxSchemaVersionsPerType + self.maxAverageSchemaVersionsPerType = maxAverageSchemaVersionsPerType + } } /// Single-threaded Fory runtime. @@ -65,503 +65,516 @@ public struct Config { /// reusable read/write context pair and must not be used concurrently from /// multiple threads. public final class Fory { - let typeResolver: TypeResolver - private let writeContext: WriteContext - private let readContext: ReadContext - public let config: Config - - public convenience init( - ref: Bool = false, - compatible: Bool? = nil, - checkClassVersion: Bool? = nil, - maxDepth: Int = 5, - maxGraphMemoryBytes: Int64 = 128 * 1024 * 1024, - maxTypeFields: Int = 512, - maxTypeMetaBytes: Int = 4096, - maxSchemaVersionsPerType: Int = 10, - maxAverageSchemaVersionsPerType: Int = 3 - ) { - self.init( - config: Config( - trackRef: ref, - compatible: compatible, - checkClassVersion: checkClassVersion, - maxDepth: maxDepth, - maxGraphMemoryBytes: maxGraphMemoryBytes, - maxTypeFields: maxTypeFields, - maxTypeMetaBytes: maxTypeMetaBytes, - maxSchemaVersionsPerType: maxSchemaVersionsPerType, - maxAverageSchemaVersionsPerType: maxAverageSchemaVersionsPerType - )) - } - - public init(config: Config) { - self.typeResolver = TypeResolver(trackRef: config.trackRef) - self.writeContext = WriteContext( - buffer: ByteBuffer(), - typeResolver: typeResolver, - trackRef: config.trackRef, - compatible: config.compatible, - checkClassVersion: config.checkClassVersion, - maxDepth: config.maxDepth, - metaStringWriteState: MetaStringWriteState() - ) - self.readContext = ReadContext( - buffer: ByteBuffer(), - typeResolver: typeResolver, - config: config + let typeResolver: TypeResolver + private let writeContext: WriteContext + private let readContext: ReadContext + public let config: Config + + public convenience init( + ref: Bool = false, + compatible: Bool? = nil, + checkClassVersion: Bool? = nil, + maxDepth: Int = 5, + maxGraphMemoryBytes: Int64 = 128 * 1024 * 1024, + maxTypeFields: Int = 512, + maxTypeMetaBytes: Int = 4096, + maxSchemaVersionsPerType: Int = 10, + maxAverageSchemaVersionsPerType: Int = 3 + ) { + self.init( + config: Config( + trackRef: ref, + compatible: compatible, + checkClassVersion: checkClassVersion, + maxDepth: maxDepth, + maxGraphMemoryBytes: maxGraphMemoryBytes, + maxTypeFields: maxTypeFields, + maxTypeMetaBytes: maxTypeMetaBytes, + maxSchemaVersionsPerType: maxSchemaVersionsPerType, + maxAverageSchemaVersionsPerType: maxAverageSchemaVersionsPerType + )) + } + + public init(config: Config) { + self.typeResolver = TypeResolver(trackRef: config.trackRef) + self.writeContext = WriteContext( + buffer: ByteBuffer(), + typeResolver: typeResolver, + trackRef: config.trackRef, + compatible: config.compatible, + checkClassVersion: config.checkClassVersion, + maxDepth: config.maxDepth, + metaStringWriteState: MetaStringWriteState() + ) + self.readContext = ReadContext( + buffer: ByteBuffer(), + typeResolver: typeResolver, + config: config + ) + self.config = config + } + + public func register(_ type: T.Type, id: UInt32) { + typeResolver.register(type, id: id) + } + + /// Registers a user type by name. The last `.` separates namespace from the final type name. + public func register(_ type: T.Type, name: String) throws { + try typeResolver.register(type, name: name) + } + + public func serialize(_ value: T) throws -> Data { + try serializeRoot { context in + try writeRootTypedValue(value, context: context) + } + } + + public func deserialize(_ data: Data, as _: T.Type = T.self) throws -> T { + try deserializeRoot( + data: data + ) { context in + try readRootTypedValue(context: context) + } + } + + public func serialize(_ value: T, to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try writeRootTypedValue(value, context: context) + } + } + + public func deserialize(from buffer: ByteBuffer, as _: T.Type = T.self) throws -> T { + try deserializeRoot( + from: buffer + ) { context in + try readRootTypedValue(context: context) + } + } + + @_disfavoredOverload + public func serialize(_ value: Any) throws -> Data { + try serializeRoot { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: Any.Type = Any.self) throws -> Any { + try deserializeRoot( + data: data + ) { context in + try castAnyDynamicValue( + readAny(context: context, refMode: refMode, readTypeInfo: true), + to: Any.self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: AnyObject) throws -> Data { + try serializeRoot { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: AnyObject.Type = AnyObject.self) throws -> AnyObject { + try deserializeRoot( + data: data + ) { context in + try castAnyDynamicValue( + readAny(context: context, refMode: refMode, readTypeInfo: true), + to: AnyObject.self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: any Serializer) throws -> Data { + try serializeRoot { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize( + _ data: Data, as _: (any Serializer).Type = (any Serializer).self + ) throws + -> any Serializer + { + try deserializeRoot( + data: data + ) { context in + try castAnyDynamicValue( + readAny(context: context, refMode: refMode, readTypeInfo: true), + to: (any Serializer).self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: [Any]) throws -> Data { + try serializeRoot { context in + try context.writeListOfAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(_ data: Data, as _: [Any].Type = [Any].self) throws -> [Any] { + try deserializeRoot( + data: data + ) { context in + try readListOfAny(context: context, refMode: refMode, readTypeInfo: true) ?? [] + } + } + + @_disfavoredOverload + public func serialize(_ value: [String: Any]) throws -> Data { + try serializeRoot { context in + try context.writeMapStringToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize( + _ data: Data, as _: [String: Any].Type = [String: Any].self + ) throws + -> [String: Any] + { + try deserializeRoot( + data: data + ) { context in + try readMapStringToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @_disfavoredOverload + public func serialize(_ value: [Int32: Any]) throws -> Data { + try serializeRoot { context in + try context.writeMapInt32ToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize( + _ data: Data, as _: [Int32: Any].Type = [Int32: Any].self + ) throws + -> [Int32: Any] + { + try deserializeRoot( + data: data + ) { context in + try readMapInt32ToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @_disfavoredOverload + public func serialize(_ value: [AnyHashable: Any]) throws -> Data { + try serializeRoot { context in + try context.writeMapAnyHashableToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize( + _ data: Data, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self ) - self.config = config - } - - public func register(_ type: T.Type, id: UInt32) { - typeResolver.register(type, id: id) - } - - /// Registers a user type by name. The last `.` separates namespace from the final type name. - public func register(_ type: T.Type, name: String) throws { - try typeResolver.register(type, name: name) - } - - public func serialize(_ value: T) throws -> Data { - try serializeRoot { context in - try writeRootTypedValue(value, context: context) - } - } - - public func deserialize(_ data: Data, as _: T.Type = T.self) throws -> T { - try deserializeRoot( - data: data - ) { context in - try readRootTypedValue(context: context) - } - } - - public func serialize(_ value: T, to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try writeRootTypedValue(value, context: context) - } - } - - public func deserialize(from buffer: ByteBuffer, as _: T.Type = T.self) throws -> T - { - try deserializeRoot( - from: buffer - ) { context in - try readRootTypedValue(context: context) - } - } - - @_disfavoredOverload - public func serialize(_ value: Any) throws -> Data { - try serializeRoot { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: Any.Type = Any.self) throws -> Any { - try deserializeRoot( - data: data - ) { context in - try castAnyDynamicValue( - readAny(context: context, refMode: refMode, readTypeInfo: true), - to: Any.self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: AnyObject) throws -> Data { - try serializeRoot { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: AnyObject.Type = AnyObject.self) throws -> AnyObject { - try deserializeRoot( - data: data - ) { context in - try castAnyDynamicValue( - readAny(context: context, refMode: refMode, readTypeInfo: true), - to: AnyObject.self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: any Serializer) throws -> Data { - try serializeRoot { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: (any Serializer).Type = (any Serializer).self) throws - -> any Serializer - { - try deserializeRoot( - data: data - ) { context in - try castAnyDynamicValue( - readAny(context: context, refMode: refMode, readTypeInfo: true), - to: (any Serializer).self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: [Any]) throws -> Data { - try serializeRoot { context in - try context.writeListOfAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: [Any].Type = [Any].self) throws -> [Any] { - try deserializeRoot( - data: data - ) { context in - try readListOfAny(context: context, refMode: refMode, readTypeInfo: true) ?? [] - } - } - - @_disfavoredOverload - public func serialize(_ value: [String: Any]) throws -> Data { - try serializeRoot { context in - try context.writeMapStringToAny( - value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: [String: Any].Type = [String: Any].self) throws - -> [String: Any] - { - try deserializeRoot( - data: data - ) { context in - try readMapStringToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @_disfavoredOverload - public func serialize(_ value: [Int32: Any]) throws -> Data { - try serializeRoot { context in - try context.writeMapInt32ToAny( - value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: [Int32: Any].Type = [Int32: Any].self) throws - -> [Int32: Any] - { - try deserializeRoot( - data: data - ) { context in - try readMapInt32ToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @_disfavoredOverload - public func serialize(_ value: [AnyHashable: Any]) throws -> Data { - try serializeRoot { context in - try context.writeMapAnyHashableToAny( - value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(_ data: Data, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self) - throws -> [AnyHashable: Any] - { - try deserializeRoot( - data: data - ) { context in - try readMapAnyHashableToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @_disfavoredOverload - public func serialize(_ value: [Any], to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeListOfAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func serialize(_ value: Any, to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(from buffer: ByteBuffer, as _: Any.Type = Any.self) throws -> Any { - try deserializeRoot( - from: buffer - ) { context in - try castAnyDynamicValue( - readAny(context: context, refMode: refMode, readTypeInfo: true), - to: Any.self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: AnyObject, to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(from buffer: ByteBuffer, as _: AnyObject.Type = AnyObject.self) throws - -> AnyObject - { - try deserializeRoot( - from: buffer - ) { context in - try castAnyDynamicValue( - readAny(context: context, refMode: refMode, readTypeInfo: true), - to: AnyObject.self - ) - } - } - - @_disfavoredOverload - public func serialize(_ value: any Serializer, to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize( - from buffer: ByteBuffer, - as _: (any Serializer).Type = (any Serializer).self - ) throws -> any Serializer { - try deserializeRoot( - from: buffer - ) { context in - try castAnyDynamicValue( - readAny(context: context, refMode: refMode, readTypeInfo: true), - to: (any Serializer).self - ) - } - } - - @_disfavoredOverload - public func deserialize(from buffer: ByteBuffer, as _: [Any].Type = [Any].self) throws -> [Any] { - try deserializeRoot( - from: buffer - ) { context in - try readListOfAny(context: context, refMode: refMode, readTypeInfo: true) ?? [] - } - } - - @_disfavoredOverload - public func serialize(_ value: [String: Any], to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeMapStringToAny( - value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(from buffer: ByteBuffer, as _: [String: Any].Type = [String: Any].self) - throws -> [String: Any] - { - try deserializeRoot( - from: buffer - ) { context in - try readMapStringToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @_disfavoredOverload - public func serialize(_ value: [Int32: Any], to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeMapInt32ToAny( - value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func serialize(_ value: [AnyHashable: Any], to buffer: inout Data) throws { - try appendSerializedRoot(to: &buffer) { context in - try context.writeMapAnyHashableToAny( - value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) - } - } - - @_disfavoredOverload - public func deserialize(from buffer: ByteBuffer, as _: [Int32: Any].Type = [Int32: Any].self) - throws -> [Int32: Any] - { - try deserializeRoot( - from: buffer - ) { context in - try readMapInt32ToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @_disfavoredOverload - public func deserialize( - from buffer: ByteBuffer, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self - ) throws -> [AnyHashable: Any] { - try deserializeRoot( - from: buffer - ) { context in - try readMapAnyHashableToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] - } - } - - @inlinable - @inline(__always) - func writeHead(buffer: ByteBuffer) { - buffer.writeUInt8(ForyHeaderFlag.isXlang) - } - - @inlinable - @inline(__always) - func readHead(buffer: ByteBuffer) throws { - let bitmap = try buffer.readUInt8() - let expected = ForyHeaderFlag.isXlang - if bitmap != expected { - try readHeadSlow(bitmap: bitmap, expected: expected) - } - } - - @usableFromInline - @inline(never) - func readHeadSlow(bitmap: UInt8, expected: UInt8) throws { - if (bitmap & ~ForyHeaderFlag.knownMask) != 0 || (bitmap & ForyHeaderFlag.isOutOfBand) != 0 { - throw ForyError.invalidData("unsupported root header bitmap 0x\(String(bitmap, radix: 16))") - } - if (bitmap & ForyHeaderFlag.isXlang) != (expected & ForyHeaderFlag.isXlang) { - throw ForyError.invalidData("xlang bitmap mismatch") - } - } - - @inline(__always) - private var refMode: RefMode { - config.trackRef ? .tracking : .nullOnly - } - - private func writeRootTypedValue( - _ value: T, - context: WriteContext - ) throws { - try value.foryWrite( - context, - refMode: refMode, - writeTypeInfo: true, - hasGenerics: false + throws -> [AnyHashable: Any] + { + try deserializeRoot( + data: data + ) { context in + try readMapAnyHashableToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @_disfavoredOverload + public func serialize(_ value: [Any], to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeListOfAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func serialize(_ value: Any, to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize(from buffer: ByteBuffer, as _: Any.Type = Any.self) throws -> Any { + try deserializeRoot( + from: buffer + ) { context in + try castAnyDynamicValue( + readAny(context: context, refMode: refMode, readTypeInfo: true), + to: Any.self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: AnyObject, to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize( + from buffer: ByteBuffer, as _: AnyObject.Type = AnyObject.self + ) throws + -> AnyObject + { + try deserializeRoot( + from: buffer + ) { context in + try castAnyDynamicValue( + readAny(context: context, refMode: refMode, readTypeInfo: true), + to: AnyObject.self + ) + } + } + + @_disfavoredOverload + public func serialize(_ value: any Serializer, to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeAny(value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize( + from buffer: ByteBuffer, + as _: (any Serializer).Type = (any Serializer).self + ) throws -> any Serializer { + try deserializeRoot( + from: buffer + ) { context in + try castAnyDynamicValue( + readAny(context: context, refMode: refMode, readTypeInfo: true), + to: (any Serializer).self + ) + } + } + + @_disfavoredOverload + public func deserialize(from buffer: ByteBuffer, as _: [Any].Type = [Any].self) throws -> [Any] { + try deserializeRoot( + from: buffer + ) { context in + try readListOfAny(context: context, refMode: refMode, readTypeInfo: true) ?? [] + } + } + + @_disfavoredOverload + public func serialize(_ value: [String: Any], to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeMapStringToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize( + from buffer: ByteBuffer, as _: [String: Any].Type = [String: Any].self ) - } - - @inline(__always) - private func readRootTypedValue( - context: ReadContext - ) throws -> T { - try reserveRootGraphOwner(T.self, context: context) - return try T.foryRead( - context, - refMode: refMode, - readTypeInfo: true + throws -> [String: Any] + { + try deserializeRoot( + from: buffer + ) { context in + try readMapStringToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @_disfavoredOverload + public func serialize(_ value: [Int32: Any], to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeMapInt32ToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func serialize(_ value: [AnyHashable: Any], to buffer: inout Data) throws { + try appendSerializedRoot(to: &buffer) { context in + try context.writeMapAnyHashableToAny( + value, refMode: refMode, writeTypeInfo: true, hasGenerics: false) + } + } + + @_disfavoredOverload + public func deserialize( + from buffer: ByteBuffer, as _: [Int32: Any].Type = [Int32: Any].self ) - } - - @inline(__always) - private func reserveRootGraphOwner( - _: T.Type, - context: ReadContext - ) throws { - switch T.staticTypeId { - case .list, .set, .map: - try context.reserveGraphMemory(max(1, MemoryLayout.stride)) - default: - break - } - } - - @inline(__always) - func withReusableReadContext( - data: Data, - _ body: (ReadContext) throws -> R - ) throws -> R { - readContext.buffer.replace(with: data) - try readContext.initGraphMemoryBudget() - defer { - readContext.reset() - } - return try body(readContext) - } - - @inline(__always) - private func serializeRoot( - _ body: (WriteContext) throws -> Void - ) throws -> Data { - try typeResolver.finishRegistration() - let context = writeContext - context.buffer.clear() - defer { - context.reset() - } - writeHead(buffer: context.buffer) - try body(context) - return context.buffer.copyToData() - } - - @inline(__always) - private func appendSerializedRoot( - to output: inout Data, - _ body: (WriteContext) throws -> Void - ) throws { - try typeResolver.finishRegistration() - let context = writeContext - context.buffer.clear() - defer { - context.reset() - } - writeHead(buffer: context.buffer) - try body(context) - output.append(contentsOf: context.buffer.storage.prefix(context.buffer.count)) - } - - @inline(__always) - private func deserializeRoot( - data: Data, - _ body: (ReadContext) throws -> R - ) throws -> R { - try typeResolver.finishRegistration() - return try withReusableReadContext(data: data) { context in - try readHead(buffer: context.buffer) - let value = try body(context) - if context.buffer.remaining != 0 { - throw ForyError.invalidData( - "unexpected trailing bytes at root: \(context.buffer.remaining)") - } - return value - } - } - - @inline(__always) - private func deserializeRoot( - from buffer: ByteBuffer, - _ body: (ReadContext) throws -> R - ) throws -> R { - try typeResolver.finishRegistration() - readContext.buffer.swapState(with: buffer) - try readContext.initGraphMemoryBudget() - defer { - readContext.buffer.swapState(with: buffer) - readContext.reset() - } - try readHead(buffer: readContext.buffer) - return try body(readContext) - } + throws -> [Int32: Any] + { + try deserializeRoot( + from: buffer + ) { context in + try readMapInt32ToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @_disfavoredOverload + public func deserialize( + from buffer: ByteBuffer, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self + ) throws -> [AnyHashable: Any] { + try deserializeRoot( + from: buffer + ) { context in + try readMapAnyHashableToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] + } + } + + @inlinable + @inline(__always) + func writeHead(buffer: ByteBuffer) { + buffer.writeUInt8(ForyHeaderFlag.isXlang) + } + + @inlinable + @inline(__always) + func readHead(buffer: ByteBuffer) throws { + let bitmap = try buffer.readUInt8() + let expected = ForyHeaderFlag.isXlang + if bitmap != expected { + try readHeadSlow(bitmap: bitmap, expected: expected) + } + } + + @usableFromInline + @inline(never) + func readHeadSlow(bitmap: UInt8, expected: UInt8) throws { + if (bitmap & ~ForyHeaderFlag.knownMask) != 0 || (bitmap & ForyHeaderFlag.isOutOfBand) != 0 { + throw ForyError.invalidData("unsupported root header bitmap 0x\(String(bitmap, radix: 16))") + } + if (bitmap & ForyHeaderFlag.isXlang) != (expected & ForyHeaderFlag.isXlang) { + throw ForyError.invalidData("xlang bitmap mismatch") + } + } + + @inline(__always) + private var refMode: RefMode { + config.trackRef ? .tracking : .nullOnly + } + + private func writeRootTypedValue( + _ value: T, + context: WriteContext + ) throws { + try value.foryWrite( + context, + refMode: refMode, + writeTypeInfo: true, + hasGenerics: false + ) + } + + @inline(__always) + private func readRootTypedValue( + context: ReadContext + ) throws -> T { + try reserveRootGraphOwner(T.self, context: context) + return try T.foryRead( + context, + refMode: refMode, + readTypeInfo: true + ) + } + + @inline(__always) + private func reserveRootGraphOwner( + _: T.Type, + context: ReadContext + ) throws { + switch T.staticTypeId { + case .list, .set, .map: + try context.reserveGraphMemory(max(1, MemoryLayout.stride)) + default: + break + } + } + + @inline(__always) + func withReusableReadContext( + data: Data, + _ body: (ReadContext) throws -> R + ) throws -> R { + readContext.buffer.replace(with: data) + try readContext.initGraphMemoryBudget() + defer { + readContext.reset() + } + return try body(readContext) + } + + @inline(__always) + private func serializeRoot( + _ body: (WriteContext) throws -> Void + ) throws -> Data { + try typeResolver.finishRegistration() + let context = writeContext + context.buffer.clear() + defer { + context.reset() + } + writeHead(buffer: context.buffer) + try body(context) + return context.buffer.copyToData() + } + + @inline(__always) + private func appendSerializedRoot( + to output: inout Data, + _ body: (WriteContext) throws -> Void + ) throws { + try typeResolver.finishRegistration() + let context = writeContext + context.buffer.clear() + defer { + context.reset() + } + writeHead(buffer: context.buffer) + try body(context) + output.append(contentsOf: context.buffer.storage.prefix(context.buffer.count)) + } + + @inline(__always) + private func deserializeRoot( + data: Data, + _ body: (ReadContext) throws -> R + ) throws -> R { + try typeResolver.finishRegistration() + return try withReusableReadContext(data: data) { context in + try readHead(buffer: context.buffer) + let value = try body(context) + if context.buffer.remaining != 0 { + throw ForyError.invalidData( + "unexpected trailing bytes at root: \(context.buffer.remaining)") + } + return value + } + } + + @inline(__always) + private func deserializeRoot( + from buffer: ByteBuffer, + _ body: (ReadContext) throws -> R + ) throws -> R { + try typeResolver.finishRegistration() + readContext.buffer.swapState(with: buffer) + try readContext.initGraphMemoryBudget() + defer { + readContext.buffer.swapState(with: buffer) + readContext.reset() + } + try readHead(buffer: readContext.buffer) + return try body(readContext) + } } diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index 09f99fa749..a8680932f3 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -20,732 +20,732 @@ import Foundation private let typeMetaSizeMask = 0xFF public final class ReadContext { - public let buffer: ByteBuffer - let typeResolver: TypeResolver - public let trackRef: Bool - public let compatible: Bool - public let checkClassVersion: Bool - public let maxDepth: Int - public let refReader: RefReader - private let compatibleTypeDefTypeInfos = ReusableArray(defaultValue: nil, reserve: 2) - private let metaStrings = ReusableArray(defaultValue: nil, reserve: 16) - private var dynamicAnyDepth = 0 - - private var typeInfoStack = UInt64Map(initialCapacity: 8) - private var typeInfoScopeStack: [(typeKey: UInt64, previousTypeInfo: TypeInfo?)] = [] - private var lastTypeInfo = TypeInfo.uncached - private let config: Config - private let maxGraphMemoryBytes: Int - private var remainingGraphMemoryBytes = Int.max - - init( - buffer: ByteBuffer, - typeResolver: TypeResolver, - config: Config - ) { - self.buffer = buffer - self.typeResolver = typeResolver - self.trackRef = config.trackRef - self.compatible = config.compatible - self.checkClassVersion = config.checkClassVersion - self.maxDepth = config.maxDepth - self.config = config - self.maxGraphMemoryBytes = Int(config.maxGraphMemoryBytes) - self.refReader = RefReader() - } - - @inline(__always) - func initGraphMemoryBudget() throws { - remainingGraphMemoryBytes = maxGraphMemoryBytes > 0 ? maxGraphMemoryBytes : Int.max - } - - @inline(__always) - public func reserveGraphMemory(_ bytes: Int) throws { - if bytes < 0 { - try throwGraphMemoryOverflow() - } - if maxGraphMemoryBytes <= 0 { - return - } - if bytes > remainingGraphMemoryBytes { - try throwGraphMemoryExceeded(bytes: bytes) - } - remainingGraphMemoryBytes -= bytes - } - - @inline(never) - private func throwGraphMemoryOverflow() throws -> Never { - throw ForyError.invalidData("graph memory estimate overflows") - } - - @inline(never) - private func throwGraphMemoryExceeded(bytes: Int) throws -> Never { - let message = - "estimated graph memory request \(bytes) bytes exceeds maxGraphMemoryBytes " - + "remaining budget \(remainingGraphMemoryBytes) bytes" - throw ForyError.invalidData(message) - } - - @inline(__always) - func enterDynamicAnyDepth() throws { - if maxDepth < 0 { - throw ForyError.invalidData("configured maxDepth \(maxDepth) is negative") - } - let nextDepth = dynamicAnyDepth + 1 - if nextDepth > maxDepth { - throw ForyError.invalidData( - "dynamic Any nesting depth \(nextDepth) exceeds configured maxDepth \(maxDepth)" - ) + public let buffer: ByteBuffer + let typeResolver: TypeResolver + public let trackRef: Bool + public let compatible: Bool + public let checkClassVersion: Bool + public let maxDepth: Int + public let refReader: RefReader + private let compatibleTypeDefTypeInfos = ReusableArray(defaultValue: nil, reserve: 2) + private let metaStrings = ReusableArray(defaultValue: nil, reserve: 16) + private var dynamicAnyDepth = 0 + + private var typeInfoStack = UInt64Map(initialCapacity: 8) + private var typeInfoScopeStack: [(typeKey: UInt64, previousTypeInfo: TypeInfo?)] = [] + private var lastTypeInfo = TypeInfo.uncached + private let config: Config + private let maxGraphMemoryBytes: Int + private var remainingGraphMemoryBytes = Int.max + + init( + buffer: ByteBuffer, + typeResolver: TypeResolver, + config: Config + ) { + self.buffer = buffer + self.typeResolver = typeResolver + self.trackRef = config.trackRef + self.compatible = config.compatible + self.checkClassVersion = config.checkClassVersion + self.maxDepth = config.maxDepth + self.config = config + self.maxGraphMemoryBytes = Int(config.maxGraphMemoryBytes) + self.refReader = RefReader() + } + + @inline(__always) + func initGraphMemoryBudget() throws { + remainingGraphMemoryBytes = maxGraphMemoryBytes > 0 ? maxGraphMemoryBytes : Int.max + } + + @inline(__always) + public func reserveGraphMemory(_ bytes: Int) throws { + if bytes < 0 { + try throwGraphMemoryOverflow() + } + if maxGraphMemoryBytes <= 0 { + return + } + if bytes > remainingGraphMemoryBytes { + try throwGraphMemoryExceeded(bytes: bytes) + } + remainingGraphMemoryBytes -= bytes } - dynamicAnyDepth = nextDepth - } - @inline(__always) - func leaveDynamicAnyDepth() { - if dynamicAnyDepth > 0 { - dynamicAnyDepth -= 1 + @inline(never) + private func throwGraphMemoryOverflow() throws -> Never { + throw ForyError.invalidData("graph memory estimate overflows") } - } - @inline(__always) - func ensureCollectionLength(_ length: Int, label: String) throws { - if length < 0 { - throw ForyError.invalidData("\(label) length is negative") + @inline(never) + private func throwGraphMemoryExceeded(bytes: Int) throws -> Never { + let message = + "estimated graph memory request \(bytes) bytes exceeds maxGraphMemoryBytes " + + "remaining budget \(remainingGraphMemoryBytes) bytes" + throw ForyError.invalidData(message) } - } - @inline(__always) - func ensureRemainingBytes(_ byteCount: Int, label: String) throws { - if byteCount < 0 { - throw ForyError.invalidData("\(label) size is negative") - } - let remainingBytes = buffer.remaining - if byteCount > remainingBytes { - throw ForyError.invalidData( - "\(label) requires \(byteCount) bytes but only \(remainingBytes) remain in buffer" - ) + @inline(__always) + func enterDynamicAnyDepth() throws { + if maxDepth < 0 { + throw ForyError.invalidData("configured maxDepth \(maxDepth) is negative") + } + let nextDepth = dynamicAnyDepth + 1 + if nextDepth > maxDepth { + throw ForyError.invalidData( + "dynamic Any nesting depth \(nextDepth) exceeds configured maxDepth \(maxDepth)" + ) + } + dynamicAnyDepth = nextDepth } - } - @inline(__always) - func typeInfo(for type: T.Type) throws -> TypeInfo { - let typeID = ObjectIdentifier(type) - if lastTypeInfo.swiftTypeID == typeID { - return lastTypeInfo - } - let info = try typeResolver.requireTypeInfo(for: type) - lastTypeInfo = info - return info - } - - @inline(__always) - func readStaticTypeInfo(_ typeID: TypeId) throws -> TypeInfo? { - let rawTypeID = UInt32(try buffer.readUInt8()) - guard let actualTypeID = TypeId(rawValue: rawTypeID) else { - throw ForyError.invalidData("unknown type id \(rawTypeID)") - } - if actualTypeID != typeID { - throw ForyError.typeMismatch(expected: typeID.rawValue, actual: rawTypeID) + @inline(__always) + func leaveDynamicAnyDepth() { + if dynamicAnyDepth > 0 { + dynamicAnyDepth -= 1 + } } - return nil - } - func readTypeInfo() throws -> TypeInfo { - let rawTypeID = UInt32(try buffer.readUInt8()) - guard let wireTypeID = TypeId(rawValue: rawTypeID) else { - throw ForyError.invalidData("unknown dynamic type id \(rawTypeID)") + @inline(__always) + func ensureCollectionLength(_ length: Int, label: String) throws { + if length < 0 { + throw ForyError.invalidData("\(label) length is negative") + } } - switch wireTypeID { - case .compatibleStruct, .namedCompatibleStruct: - return try readCompatibleTypeInfo() - case .namedEnum, .namedStruct, .namedExt, .namedUnion: - if compatible { - return try readCompatibleTypeInfo() - } - let namespace = try readMetaString( - context: self, - decoder: .namespace, - encodings: namespaceMetaStringEncodings - ) - let typeName = try readMetaString( - context: self, - decoder: .typeName, - encodings: typeNameMetaStringEncodings - ) - return try typeResolver.requireTypeInfo(namespace: namespace.value, typeName: typeName.value) - case .structType, .enumType, .ext, .typedUnion, .union: - let userTypeID = try buffer.readVarUInt32() - return try typeResolver.requireTypeInfo(userTypeID: userTypeID) - default: - return typeResolver.builtinTypeInfo(for: wireTypeID) + @inline(__always) + func ensureRemainingBytes(_ byteCount: Int, label: String) throws { + if byteCount < 0 { + throw ForyError.invalidData("\(label) size is negative") + } + let remainingBytes = buffer.remaining + if byteCount > remainingBytes { + throw ForyError.invalidData( + "\(label) requires \(byteCount) bytes but only \(remainingBytes) remain in buffer" + ) + } } - } - func readTypeInfo(for type: T.Type) throws -> TypeInfo? { - let rawTypeID = UInt32(try buffer.readUInt8()) - guard let typeID = TypeId(rawValue: rawTypeID) else { - throw ForyError.invalidData("unknown type id \(rawTypeID)") + @inline(__always) + func typeInfo(for type: T.Type) throws -> TypeInfo { + let typeID = ObjectIdentifier(type) + if lastTypeInfo.swiftTypeID == typeID { + return lastTypeInfo + } + let info = try typeResolver.requireTypeInfo(for: type) + lastTypeInfo = info + return info } - guard T.staticTypeId.isUserTypeKind else { - if typeID != T.staticTypeId { - throw ForyError.typeMismatch(expected: T.staticTypeId.rawValue, actual: rawTypeID) - } - return nil + @inline(__always) + func readStaticTypeInfo(_ typeID: TypeId) throws -> TypeInfo? { + let rawTypeID = UInt32(try buffer.readUInt8()) + guard let actualTypeID = TypeId(rawValue: rawTypeID) else { + throw ForyError.invalidData("unknown type id \(rawTypeID)") + } + if actualTypeID != typeID { + throw ForyError.typeMismatch(expected: typeID.rawValue, actual: rawTypeID) + } + return nil } - let localTypeInfo = try typeInfo(for: type) - let expectedWireTypeID = localTypeInfo.wireTypeID(compatible: compatible) - if !isAllowedRegisteredWireTypeID( - typeID, - declaredTypeID: localTypeInfo.typeID, - registerByName: localTypeInfo.registerByName, - compatible: compatible, - evolving: localTypeInfo.evolving - ) { - throw ForyError.typeMismatch(expected: expectedWireTypeID.rawValue, actual: rawTypeID) - } + func readTypeInfo() throws -> TypeInfo { + let rawTypeID = UInt32(try buffer.readUInt8()) + guard let wireTypeID = TypeId(rawValue: rawTypeID) else { + throw ForyError.invalidData("unknown dynamic type id \(rawTypeID)") + } - switch typeID { - case .compatibleStruct, .namedCompatibleStruct: - return try readCompatibleTypeInfoIfNeeded( - for: localTypeInfo, - wireTypeID: typeID - ) - case .namedEnum, .namedStruct, .namedExt, .namedUnion: - if compatible { - _ = try readCompatibleTypeInfoIfNeeded( - for: localTypeInfo, - wireTypeID: typeID - ) - } else { - let namespace = try readMetaString( - context: self, - decoder: .namespace, - encodings: namespaceMetaStringEncodings - ) - let typeName = try readMetaString( - context: self, - decoder: .typeName, - encodings: typeNameMetaStringEncodings - ) - guard localTypeInfo.registerByName else { - throw ForyError.invalidData( - "received name-registered type info for id-registered local type") + switch wireTypeID { + case .compatibleStruct, .namedCompatibleStruct: + return try readCompatibleTypeInfo() + case .namedEnum, .namedStruct, .namedExt, .namedUnion: + if compatible { + return try readCompatibleTypeInfo() + } + let namespace = try readMetaString( + context: self, + decoder: .namespace, + encodings: namespaceMetaStringEncodings + ) + let typeName = try readMetaString( + context: self, + decoder: .typeName, + encodings: typeNameMetaStringEncodings + ) + return try typeResolver.requireTypeInfo(namespace: namespace.value, typeName: typeName.value) + case .structType, .enumType, .ext, .typedUnion, .union: + let userTypeID = try buffer.readVarUInt32() + return try typeResolver.requireTypeInfo(userTypeID: userTypeID) + default: + return typeResolver.builtinTypeInfo(for: wireTypeID) } - if namespace.value != localTypeInfo.namespace.value - || typeName.value != localTypeInfo.typeName.value - { - let expectedTypeName = "\(localTypeInfo.namespace.value)::\(localTypeInfo.typeName.value)" - let actualTypeName = "\(namespace.value)::\(typeName.value)" - throw ForyError.invalidData( - "type name mismatch: expected \(expectedTypeName), got \(actualTypeName)" - ) - } - } - default: - if !localTypeInfo.registerByName && registeredWireTypeNeedsUserTypeID(typeID) { - guard let localUserTypeID = localTypeInfo.userTypeID else { - throw ForyError.invalidData("missing user type id for id-registered type") - } - let remoteUserTypeID = try buffer.readVarUInt32() - if remoteUserTypeID != localUserTypeID { - throw ForyError.typeMismatch(expected: localUserTypeID, actual: remoteUserTypeID) - } - } } - return nil - } - - @inline(__always) - private func readCompatibleTypeInfoIfNeeded( - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws -> TypeInfo? { - let buffer = self.buffer - let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos - if !checkClassVersion, - compatibleTypeDefTypeInfos.isEmpty, - !localTypeInfo.typeDefHasUserTypeFields, - let localTypeDefHeader = localTypeInfo.typeDefHeader - { - let indexMarker = try buffer.readVarUInt32() - if indexMarker == 0 { - let headerStart = buffer.getCursor() - let header = try buffer.readUInt64() - var bodySize = Int(header & UInt64(typeMetaSizeMask)) - if bodySize == typeMetaSizeMask { - bodySize += Int(try buffer.readVarUInt32()) - } - if header == localTypeDefHeader { - // The declared local type owns this exact metadata header, so this is a - // local-schema hit rather than a remote cache publish. Keep it allocation-free: - // skip the body, add the local type to the per-read table, and do not parse/hash. - try buffer.skip(bodySize) - compatibleTypeDefTypeInfos.push(localTypeInfo) - return nil + + func readTypeInfo(for type: T.Type) throws -> TypeInfo? { + let rawTypeID = UInt32(try buffer.readUInt8()) + guard let typeID = TypeId(rawValue: rawTypeID) else { + throw ForyError.invalidData("unknown type id \(rawTypeID)") } - if let cached = typeResolver.getTypeInfo(forHeader: header) { - // Header-cache hits intentionally skip without rehashing. Entries reach this cache only - // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add - // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. - try buffer.skip(bodySize) - compatibleTypeDefTypeInfos.push(cached) - return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) + + guard T.staticTypeId.isUserTypeKind else { + if typeID != T.staticTypeId { + throw ForyError.typeMismatch(expected: T.staticTypeId.rawValue, actual: rawTypeID) + } + return nil } - let cachedTypeInfo = try readTypeInfoBody( - start: headerStart, - header: header, - for: localTypeInfo, - wireTypeID: wireTypeID) - compatibleTypeDefTypeInfos.push(cachedTypeInfo) - if cachedTypeInfo === localTypeInfo { - return nil + + let localTypeInfo = try typeInfo(for: type) + let expectedWireTypeID = localTypeInfo.wireTypeID(compatible: compatible) + if !isAllowedRegisteredWireTypeID( + typeID, + declaredTypeID: localTypeInfo.typeID, + registerByName: localTypeInfo.registerByName, + compatible: compatible, + evolving: localTypeInfo.evolving + ) { + throw ForyError.typeMismatch(expected: expectedWireTypeID.rawValue, actual: rawTypeID) } - return try validateCompatibleTypeInfo( - cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) - } - return try readCompatibleTypeInfo( - afterMarker: indexMarker, - for: localTypeInfo, - wireTypeID: wireTypeID) - } - return try readCompatibleTypeInfo( - for: localTypeInfo, - wireTypeID: wireTypeID - ) - } - - private func readCompatibleTypeInfo() throws -> TypeInfo { - let indexMarker = try buffer.readVarUInt32() - return try readCompatibleTypeInfo(afterMarker: indexMarker) - } - - private func readCompatibleTypeInfo(afterMarker indexMarker: UInt32) throws -> TypeInfo { - let buffer = self.buffer - let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos - let isRef = (indexMarker & 1) == 1 - let index = Int(indexMarker >> 1) - if isRef { - guard let typeInfo = compatibleTypeDefTypeInfos.get(index) else { - throw ForyError.invalidData("unknown compatible type definition ref index \(index)") - } - return typeInfo - } - let typeMetaStart = buffer.getCursor() - let header = try buffer.readUInt64() - var bodySize = Int(header & UInt64(typeMetaSizeMask)) - if bodySize == typeMetaSizeMask { - bodySize += Int(try buffer.readVarUInt32()) - } - if let cached = typeResolver.getTypeInfo(forHeader: header) { - // Header-cache hits intentionally skip without rehashing. Entries reach this cache only - // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add - // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. - try buffer.skip(bodySize) - compatibleTypeDefTypeInfos.push(cached) - return cached + switch typeID { + case .compatibleStruct, .namedCompatibleStruct: + return try readCompatibleTypeInfoIfNeeded( + for: localTypeInfo, + wireTypeID: typeID + ) + case .namedEnum, .namedStruct, .namedExt, .namedUnion: + if compatible { + _ = try readCompatibleTypeInfoIfNeeded( + for: localTypeInfo, + wireTypeID: typeID + ) + } else { + let namespace = try readMetaString( + context: self, + decoder: .namespace, + encodings: namespaceMetaStringEncodings + ) + let typeName = try readMetaString( + context: self, + decoder: .typeName, + encodings: typeNameMetaStringEncodings + ) + guard localTypeInfo.registerByName else { + throw ForyError.invalidData( + "received name-registered type info for id-registered local type") + } + if namespace.value != localTypeInfo.namespace.value + || typeName.value != localTypeInfo.typeName.value + { + let expectedTypeName = "\(localTypeInfo.namespace.value)::\(localTypeInfo.typeName.value)" + let actualTypeName = "\(namespace.value)::\(typeName.value)" + throw ForyError.invalidData( + "type name mismatch: expected \(expectedTypeName), got \(actualTypeName)" + ) + } + } + default: + if !localTypeInfo.registerByName && registeredWireTypeNeedsUserTypeID(typeID) { + guard let localUserTypeID = localTypeInfo.userTypeID else { + throw ForyError.invalidData("missing user type id for id-registered type") + } + let remoteUserTypeID = try buffer.readVarUInt32() + if remoteUserTypeID != localUserTypeID { + throw ForyError.typeMismatch(expected: localUserTypeID, actual: remoteUserTypeID) + } + } + } + return nil + } + + @inline(__always) + private func readCompatibleTypeInfoIfNeeded( + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws -> TypeInfo? { + let buffer = self.buffer + let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos + if !checkClassVersion, + compatibleTypeDefTypeInfos.isEmpty, + !localTypeInfo.typeDefHasUserTypeFields, + let localTypeDefHeader = localTypeInfo.typeDefHeader + { + let indexMarker = try buffer.readVarUInt32() + if indexMarker == 0 { + let headerStart = buffer.getCursor() + let header = try buffer.readUInt64() + var bodySize = Int(header & UInt64(typeMetaSizeMask)) + if bodySize == typeMetaSizeMask { + bodySize += Int(try buffer.readVarUInt32()) + } + if header == localTypeDefHeader { + // The declared local type owns this exact metadata header, so this is a + // local-schema hit rather than a remote cache publish. Keep it allocation-free: + // skip the body, add the local type to the per-read table, and do not parse/hash. + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(localTypeInfo) + return nil + } + if let cached = typeResolver.getTypeInfo(forHeader: header) { + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add + // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(cached) + return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) + } + let cachedTypeInfo = try readTypeInfoBody( + start: headerStart, + header: header, + for: localTypeInfo, + wireTypeID: wireTypeID) + compatibleTypeDefTypeInfos.push(cachedTypeInfo) + if cachedTypeInfo === localTypeInfo { + return nil + } + return try validateCompatibleTypeInfo( + cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) + } + return try readCompatibleTypeInfo( + afterMarker: indexMarker, + for: localTypeInfo, + wireTypeID: wireTypeID) + } + return try readCompatibleTypeInfo( + for: localTypeInfo, + wireTypeID: wireTypeID + ) } - let cachedTypeInfo = try readTypeInfoBody(start: typeMetaStart, header: header) - compatibleTypeDefTypeInfos.push(cachedTypeInfo) - return cachedTypeInfo - } - - @inline(never) - private func readCompatibleTypeInfo( - afterMarker indexMarker: UInt32, - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws -> TypeInfo { - let buffer = self.buffer - let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos - let isRef = (indexMarker & 1) == 1 - let index = Int(indexMarker >> 1) - if isRef { - guard let typeInfo = compatibleTypeDefTypeInfos.get(index) else { - throw ForyError.invalidData("unknown compatible type definition ref index \(index)") - } - return try validateCompatibleTypeInfo(typeInfo, for: localTypeInfo, wireTypeID: wireTypeID) + private func readCompatibleTypeInfo() throws -> TypeInfo { + let indexMarker = try buffer.readVarUInt32() + return try readCompatibleTypeInfo(afterMarker: indexMarker) } - let typeMetaStart = buffer.getCursor() - let header = try buffer.readUInt64() - var bodySize = Int(header & UInt64(typeMetaSizeMask)) - if bodySize == typeMetaSizeMask { - bodySize += Int(try buffer.readVarUInt32()) - } - if let cached = typeResolver.getTypeInfo(forHeader: header) { - // Header-cache hits intentionally skip without rehashing. Entries reach this cache only - // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add - // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. - try buffer.skip(bodySize) - compatibleTypeDefTypeInfos.push(cached) - return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) - } + private func readCompatibleTypeInfo(afterMarker indexMarker: UInt32) throws -> TypeInfo { + let buffer = self.buffer + let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos + let isRef = (indexMarker & 1) == 1 + let index = Int(indexMarker >> 1) + if isRef { + guard let typeInfo = compatibleTypeDefTypeInfos.get(index) else { + throw ForyError.invalidData("unknown compatible type definition ref index \(index)") + } + return typeInfo + } - let cachedTypeInfo = try readTypeInfoBody( - start: typeMetaStart, - header: header, - for: localTypeInfo, - wireTypeID: wireTypeID) - compatibleTypeDefTypeInfos.push(cachedTypeInfo) - return try validateCompatibleTypeInfo( - cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) - } - - @inline(__always) - private func readCompatibleTypeInfo( - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws -> TypeInfo { - let buffer = self.buffer - let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos - if compatibleTypeDefTypeInfos.isEmpty, - let localTypeDefHeader = localTypeInfo.typeDefHeader - { - let indexMarker = try buffer.readVarUInt32() - if indexMarker != 0 { - return try readCompatibleTypeInfo( - afterMarker: indexMarker, - for: localTypeInfo, - wireTypeID: wireTypeID) - } else { - let headerStart = buffer.getCursor() + let typeMetaStart = buffer.getCursor() let header = try buffer.readUInt64() var bodySize = Int(header & UInt64(typeMetaSizeMask)) if bodySize == typeMetaSizeMask { - bodySize += Int(try buffer.readVarUInt32()) + bodySize += Int(try buffer.readVarUInt32()) + } + if let cached = typeResolver.getTypeInfo(forHeader: header) { + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add + // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(cached) + return cached } - if header == localTypeDefHeader { - // The declared local type owns this exact metadata header, so this is a - // local-schema hit rather than a remote cache publish. Keep it allocation-free: - // skip the body, add the local type to the per-read table, and do not parse/hash. - try buffer.skip(bodySize) - compatibleTypeDefTypeInfos.push(localTypeInfo) - return localTypeInfo + let cachedTypeInfo = try readTypeInfoBody(start: typeMetaStart, header: header) + compatibleTypeDefTypeInfos.push(cachedTypeInfo) + return cachedTypeInfo + } + + @inline(never) + private func readCompatibleTypeInfo( + afterMarker indexMarker: UInt32, + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws -> TypeInfo { + let buffer = self.buffer + let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos + let isRef = (indexMarker & 1) == 1 + let index = Int(indexMarker >> 1) + if isRef { + guard let typeInfo = compatibleTypeDefTypeInfos.get(index) else { + throw ForyError.invalidData("unknown compatible type definition ref index \(index)") + } + return try validateCompatibleTypeInfo(typeInfo, for: localTypeInfo, wireTypeID: wireTypeID) } + let typeMetaStart = buffer.getCursor() + let header = try buffer.readUInt64() + var bodySize = Int(header & UInt64(typeMetaSizeMask)) + if bodySize == typeMetaSizeMask { + bodySize += Int(try buffer.readVarUInt32()) + } if let cached = typeResolver.getTypeInfo(forHeader: header) { - // Header-cache hits intentionally skip without rehashing. Entries reach this cache only - // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add - // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. - try buffer.skip(bodySize) - compatibleTypeDefTypeInfos.push(cached) - return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) - } else { - let remoteTypeInfo = try readTypeInfoBody( - start: headerStart, + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add + // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(cached) + return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) + } + + let cachedTypeInfo = try readTypeInfoBody( + start: typeMetaStart, header: header, for: localTypeInfo, wireTypeID: wireTypeID) - compatibleTypeDefTypeInfos.push(remoteTypeInfo) - return try validateCompatibleTypeInfo( - remoteTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) + compatibleTypeDefTypeInfos.push(cachedTypeInfo) + return try validateCompatibleTypeInfo( + cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) + } + + @inline(__always) + private func readCompatibleTypeInfo( + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws -> TypeInfo { + let buffer = self.buffer + let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos + if compatibleTypeDefTypeInfos.isEmpty, + let localTypeDefHeader = localTypeInfo.typeDefHeader + { + let indexMarker = try buffer.readVarUInt32() + if indexMarker != 0 { + return try readCompatibleTypeInfo( + afterMarker: indexMarker, + for: localTypeInfo, + wireTypeID: wireTypeID) + } else { + let headerStart = buffer.getCursor() + let header = try buffer.readUInt64() + var bodySize = Int(header & UInt64(typeMetaSizeMask)) + if bodySize == typeMetaSizeMask { + bodySize += Int(try buffer.readVarUInt32()) + } + + if header == localTypeDefHeader { + // The declared local type owns this exact metadata header, so this is a + // local-schema hit rather than a remote cache publish. Keep it allocation-free: + // skip the body, add the local type to the per-read table, and do not parse/hash. + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(localTypeInfo) + return localTypeInfo + } + + if let cached = typeResolver.getTypeInfo(forHeader: header) { + // Header-cache hits intentionally skip without rehashing. Entries reach this cache only + // after a successful TypeDef parse and 52-bit metadata-hash validation. Do not add + // body/hash/schema-limit/exact-local checks here; the miss path owns them before publish. + try buffer.skip(bodySize) + compatibleTypeDefTypeInfos.push(cached) + return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) + } else { + let remoteTypeInfo = try readTypeInfoBody( + start: headerStart, + header: header, + for: localTypeInfo, + wireTypeID: wireTypeID) + compatibleTypeDefTypeInfos.push(remoteTypeInfo) + return try validateCompatibleTypeInfo( + remoteTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) + } + } } - } - } - let indexMarker = try buffer.readVarUInt32() - return try readCompatibleTypeInfo( - afterMarker: indexMarker, - for: localTypeInfo, - wireTypeID: wireTypeID) - } - - @inline(never) - private func readTypeInfoBody(start: Int, header: UInt64) throws -> TypeInfo { - buffer.setCursor(start) - let decoded = try TypeMeta.decode( - buffer, - maxTypeFields: config.maxTypeFields, - maxTypeMetaBytes: config.maxTypeMetaBytes) - let typeMetaEnd = buffer.getCursor() - let localTypeInfo = try typeResolver.requireTypeInfo(for: decoded) - return try typeResolver.cacheTypeInfo( - decoded, - forHeader: header, - localTypeInfo: localTypeInfo, - exactLocal: try matchesLocalTypeDefBytes( - localTypeInfo: localTypeInfo, - typeMeta: decoded, - start: start, - end: typeMetaEnd), - config: config - ) - } - - @inline(never) - private func readTypeInfoBody( - start: Int, - header: UInt64, - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws -> TypeInfo { - buffer.setCursor(start) - let decoded = try TypeMeta.decode( - buffer, - maxTypeFields: config.maxTypeFields, - maxTypeMetaBytes: config.maxTypeMetaBytes) - let typeMetaEnd = buffer.getCursor() - try validateCompatibleTypeMeta(decoded, for: localTypeInfo, wireTypeID: wireTypeID) - // The typed path is owned by the declared local type. After identity validation, the - // decoded metadata must describe this same TypeInfo; do not resolve another owner here. - return try typeResolver.cacheTypeInfo( - decoded, - forHeader: header, - localTypeInfo: localTypeInfo, - exactLocal: try matchesLocalTypeDefBytes( - localTypeInfo: localTypeInfo, - typeMeta: decoded, - start: start, - end: typeMetaEnd), - config: config - ) - } - - @inline(never) - private func matchesLocalTypeDefBytes( - localTypeInfo: TypeInfo, - typeMeta: TypeMeta, - start: Int, - end: Int - ) throws -> Bool { - guard typeMeta.typeID != nil else { - return false - } - guard let localTypeDefBytes = localTypeInfo.typeDefBytes, - end - start == localTypeDefBytes.count - else { - return false - } - return buffer.matchesBytes(start: start, bytes: localTypeDefBytes) - } - - private func validateCompatibleTypeInfo( - _ remoteTypeInfo: TypeInfo, - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws -> TypeInfo { - guard let remoteTypeMeta = remoteTypeInfo.compatibleTypeMeta else { - throw ForyError.invalidData("compatible type metadata is required") - } - try validateCompatibleTypeMeta(remoteTypeMeta, for: localTypeInfo, wireTypeID: wireTypeID) - return remoteTypeInfo - } - - @inline(__always) - private func validateCompatibleTypeMeta( - _ remoteTypeMeta: TypeMeta, - for localTypeInfo: TypeInfo, - wireTypeID: TypeId - ) throws { - if let localTypeMeta = localTypeInfo.typeMeta, - remoteTypeMeta === localTypeMeta - { - return + let indexMarker = try buffer.readVarUInt32() + return try readCompatibleTypeInfo( + afterMarker: indexMarker, + for: localTypeInfo, + wireTypeID: wireTypeID) } - if remoteTypeMeta.registerByName { - guard localTypeInfo.registerByName else { - throw ForyError.invalidData( - "received name-registered compatible metadata for id-registered local type") - } - if remoteTypeMeta.namespace.value != localTypeInfo.namespace.value { - throw ForyError.invalidData( - "namespace mismatch: expected \(localTypeInfo.namespace.value), got \(remoteTypeMeta.namespace.value)" + + @inline(never) + private func readTypeInfoBody(start: Int, header: UInt64) throws -> TypeInfo { + buffer.setCursor(start) + let decoded = try TypeMeta.decode( + buffer, + maxTypeFields: config.maxTypeFields, + maxTypeMetaBytes: config.maxTypeMetaBytes) + let typeMetaEnd = buffer.getCursor() + let localTypeInfo = try typeResolver.requireTypeInfo(for: decoded) + return try typeResolver.cacheTypeInfo( + decoded, + forHeader: header, + localTypeInfo: localTypeInfo, + exactLocal: try matchesLocalTypeDefBytes( + localTypeInfo: localTypeInfo, + typeMeta: decoded, + start: start, + end: typeMetaEnd), + config: config ) - } - if remoteTypeMeta.typeName.value != localTypeInfo.typeName.value { - throw ForyError.invalidData( - "type name mismatch: expected \(localTypeInfo.typeName.value), got \(remoteTypeMeta.typeName.value)" + } + + @inline(never) + private func readTypeInfoBody( + start: Int, + header: UInt64, + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws -> TypeInfo { + buffer.setCursor(start) + let decoded = try TypeMeta.decode( + buffer, + maxTypeFields: config.maxTypeFields, + maxTypeMetaBytes: config.maxTypeMetaBytes) + let typeMetaEnd = buffer.getCursor() + try validateCompatibleTypeMeta(decoded, for: localTypeInfo, wireTypeID: wireTypeID) + // The typed path is owned by the declared local type. After identity validation, the + // decoded metadata must describe this same TypeInfo; do not resolve another owner here. + return try typeResolver.cacheTypeInfo( + decoded, + forHeader: header, + localTypeInfo: localTypeInfo, + exactLocal: try matchesLocalTypeDefBytes( + localTypeInfo: localTypeInfo, + typeMeta: decoded, + start: start, + end: typeMetaEnd), + config: config ) - } - } else { - guard !localTypeInfo.registerByName else { - throw ForyError.invalidData( - "received id-registered compatible metadata for name-registered local type") - } - guard let remoteUserTypeID = remoteTypeMeta.userTypeID else { - throw ForyError.invalidData("missing user type id in compatible type metadata") - } - guard let localUserTypeID = localTypeInfo.userTypeID else { - throw ForyError.invalidData("missing local user type id metadata for id-registered type") - } - if remoteUserTypeID != localUserTypeID { - throw ForyError.typeMismatch(expected: localUserTypeID, actual: remoteUserTypeID) - } } - if let remoteTypeID = remoteTypeMeta.typeID, - let remoteWireTypeID = TypeId(rawValue: remoteTypeID), - !isAllowedRegisteredWireTypeID( - remoteWireTypeID, - declaredTypeID: localTypeInfo.typeID, - registerByName: localTypeInfo.registerByName, - compatible: compatible, - evolving: localTypeInfo.evolving - ) - { - throw ForyError.typeMismatch(expected: wireTypeID.rawValue, actual: remoteTypeID) + @inline(never) + private func matchesLocalTypeDefBytes( + localTypeInfo: TypeInfo, + typeMeta: TypeMeta, + start: Int, + end: Int + ) throws -> Bool { + guard typeMeta.typeID != nil else { + return false + } + guard let localTypeDefBytes = localTypeInfo.typeDefBytes, + end - start == localTypeDefBytes.count + else { + return false + } + return buffer.matchesBytes(start: start, bytes: localTypeDefBytes) + } + + private func validateCompatibleTypeInfo( + _ remoteTypeInfo: TypeInfo, + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws -> TypeInfo { + guard let remoteTypeMeta = remoteTypeInfo.compatibleTypeMeta else { + throw ForyError.invalidData("compatible type metadata is required") + } + try validateCompatibleTypeMeta(remoteTypeMeta, for: localTypeInfo, wireTypeID: wireTypeID) + return remoteTypeInfo + } + + @inline(__always) + private func validateCompatibleTypeMeta( + _ remoteTypeMeta: TypeMeta, + for localTypeInfo: TypeInfo, + wireTypeID: TypeId + ) throws { + if let localTypeMeta = localTypeInfo.typeMeta, + remoteTypeMeta === localTypeMeta + { + return + } + if remoteTypeMeta.registerByName { + guard localTypeInfo.registerByName else { + throw ForyError.invalidData( + "received name-registered compatible metadata for id-registered local type") + } + if remoteTypeMeta.namespace.value != localTypeInfo.namespace.value { + throw ForyError.invalidData( + "namespace mismatch: expected \(localTypeInfo.namespace.value), got \(remoteTypeMeta.namespace.value)" + ) + } + if remoteTypeMeta.typeName.value != localTypeInfo.typeName.value { + throw ForyError.invalidData( + "type name mismatch: expected \(localTypeInfo.typeName.value), got \(remoteTypeMeta.typeName.value)" + ) + } + } else { + guard !localTypeInfo.registerByName else { + throw ForyError.invalidData( + "received id-registered compatible metadata for name-registered local type") + } + guard let remoteUserTypeID = remoteTypeMeta.userTypeID else { + throw ForyError.invalidData("missing user type id in compatible type metadata") + } + guard let localUserTypeID = localTypeInfo.userTypeID else { + throw ForyError.invalidData("missing local user type id metadata for id-registered type") + } + if remoteUserTypeID != localUserTypeID { + throw ForyError.typeMismatch(expected: localUserTypeID, actual: remoteUserTypeID) + } + } + + if let remoteTypeID = remoteTypeMeta.typeID, + let remoteWireTypeID = TypeId(rawValue: remoteTypeID), + !isAllowedRegisteredWireTypeID( + remoteWireTypeID, + declaredTypeID: localTypeInfo.typeID, + registerByName: localTypeInfo.registerByName, + compatible: compatible, + evolving: localTypeInfo.evolving + ) + { + throw ForyError.typeMismatch(expected: wireTypeID.rawValue, actual: remoteTypeID) + } } - } - - func readAnyValue(typeInfo: TypeInfo) throws -> Any { - try enterDynamicAnyDepth() - defer { leaveDynamicAnyDepth() } - - let value: Any - switch typeInfo.typeID { - case .bool: - value = try Bool.foryRead(self, refMode: .none, readTypeInfo: false) - case .int8: - value = try Int8.foryRead(self, refMode: .none, readTypeInfo: false) - case .int16: - value = try Int16.foryRead(self, refMode: .none, readTypeInfo: false) - case .int32: - value = try buffer.readInt32() - case .varint32: - value = try Int32.foryRead(self, refMode: .none, readTypeInfo: false) - case .int64: - value = try buffer.readInt64() - case .varint64: - value = try Int64.foryRead(self, refMode: .none, readTypeInfo: false) - case .taggedInt64: - value = try buffer.readTaggedInt64() - case .uint8: - value = try UInt8.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint16: - value = try UInt16.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint32: - value = try buffer.readUInt32() - case .varUInt32: - value = try UInt32.foryRead(self, refMode: .none, readTypeInfo: false) - case .uint64: - value = try buffer.readUInt64() - case .varUInt64: - value = try UInt64.foryRead(self, refMode: .none, readTypeInfo: false) - case .taggedUInt64: - value = try buffer.readTaggedUInt64() - case .float16: - value = try Float16.foryRead(self, refMode: .none, readTypeInfo: false) - case .bfloat16: - value = try BFloat16.foryRead(self, refMode: .none, readTypeInfo: false) - case .float32: - value = try Float.foryRead(self, refMode: .none, readTypeInfo: false) - case .float64: - value = try Double.foryRead(self, refMode: .none, readTypeInfo: false) - case .string: - value = try String.foryRead(self, refMode: .none, readTypeInfo: false) - case .duration: - value = try Duration.foryRead(self, refMode: .none, readTypeInfo: false) - case .timestamp: - value = try Date.foryRead(self, refMode: .none, readTypeInfo: false) - case .date: - value = try LocalDate.foryRead(self, refMode: .none, readTypeInfo: false) - case .decimal: - value = try Decimal.foryRead(self, refMode: .none, readTypeInfo: false) - case .binary: - value = try Data.foryRead(self, refMode: .none, readTypeInfo: false) - case .boolArray: - value = try readPrimitiveArray(self) as [Bool] - case .int8Array: - value = try readPrimitiveArray(self) as [Int8] - case .int16Array: - value = try readPrimitiveArray(self) as [Int16] - case .int32Array: - value = try readPrimitiveArray(self) as [Int32] - case .int64Array: - value = try readPrimitiveArray(self) as [Int64] - case .uint8Array: - value = try readPrimitiveArray(self) as [UInt8] - case .uint16Array: - value = try readPrimitiveArray(self) as [UInt16] - case .uint32Array: - value = try readPrimitiveArray(self) as [UInt32] - case .uint64Array: - value = try readPrimitiveArray(self) as [UInt64] - case .float16Array: - value = try readPrimitiveArray(self) as [Float16] - case .bfloat16Array: - value = try readPrimitiveArray(self) as [BFloat16] - case .float32Array: - value = try readPrimitiveArray(self) as [Float] - case .float64Array: - value = try readPrimitiveArray(self) as [Double] - case .array, .list: - value = try readListOfAny(context: self, refMode: .none) ?? [] - case .set: - value = try Set.foryRead(self, refMode: .none, readTypeInfo: false) - case .map: - value = try readDynamicAnyMapValue(context: self) - case .none: - value = ForyAnyNullValue() - default: - if typeInfo.typeID.isUserTypeKind { - value = try typeInfo.read(self) - } else { - throw ForyError.invalidData("unsupported dynamic type id \(typeInfo.typeID)") - } + + func readAnyValue(typeInfo: TypeInfo) throws -> Any { + try enterDynamicAnyDepth() + defer { leaveDynamicAnyDepth() } + + let value: Any + switch typeInfo.typeID { + case .bool: + value = try Bool.foryRead(self, refMode: .none, readTypeInfo: false) + case .int8: + value = try Int8.foryRead(self, refMode: .none, readTypeInfo: false) + case .int16: + value = try Int16.foryRead(self, refMode: .none, readTypeInfo: false) + case .int32: + value = try buffer.readInt32() + case .varint32: + value = try Int32.foryRead(self, refMode: .none, readTypeInfo: false) + case .int64: + value = try buffer.readInt64() + case .varint64: + value = try Int64.foryRead(self, refMode: .none, readTypeInfo: false) + case .taggedInt64: + value = try buffer.readTaggedInt64() + case .uint8: + value = try UInt8.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint16: + value = try UInt16.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint32: + value = try buffer.readUInt32() + case .varUInt32: + value = try UInt32.foryRead(self, refMode: .none, readTypeInfo: false) + case .uint64: + value = try buffer.readUInt64() + case .varUInt64: + value = try UInt64.foryRead(self, refMode: .none, readTypeInfo: false) + case .taggedUInt64: + value = try buffer.readTaggedUInt64() + case .float16: + value = try Float16.foryRead(self, refMode: .none, readTypeInfo: false) + case .bfloat16: + value = try BFloat16.foryRead(self, refMode: .none, readTypeInfo: false) + case .float32: + value = try Float.foryRead(self, refMode: .none, readTypeInfo: false) + case .float64: + value = try Double.foryRead(self, refMode: .none, readTypeInfo: false) + case .string: + value = try String.foryRead(self, refMode: .none, readTypeInfo: false) + case .duration: + value = try Duration.foryRead(self, refMode: .none, readTypeInfo: false) + case .timestamp: + value = try Date.foryRead(self, refMode: .none, readTypeInfo: false) + case .date: + value = try LocalDate.foryRead(self, refMode: .none, readTypeInfo: false) + case .decimal: + value = try Decimal.foryRead(self, refMode: .none, readTypeInfo: false) + case .binary: + value = try Data.foryRead(self, refMode: .none, readTypeInfo: false) + case .boolArray: + value = try readPrimitiveArray(self) as [Bool] + case .int8Array: + value = try readPrimitiveArray(self) as [Int8] + case .int16Array: + value = try readPrimitiveArray(self) as [Int16] + case .int32Array: + value = try readPrimitiveArray(self) as [Int32] + case .int64Array: + value = try readPrimitiveArray(self) as [Int64] + case .uint8Array: + value = try readPrimitiveArray(self) as [UInt8] + case .uint16Array: + value = try readPrimitiveArray(self) as [UInt16] + case .uint32Array: + value = try readPrimitiveArray(self) as [UInt32] + case .uint64Array: + value = try readPrimitiveArray(self) as [UInt64] + case .float16Array: + value = try readPrimitiveArray(self) as [Float16] + case .bfloat16Array: + value = try readPrimitiveArray(self) as [BFloat16] + case .float32Array: + value = try readPrimitiveArray(self) as [Float] + case .float64Array: + value = try readPrimitiveArray(self) as [Double] + case .array, .list: + value = try readListOfAny(context: self, refMode: .none) ?? [] + case .set: + value = try Set.foryRead(self, refMode: .none, readTypeInfo: false) + case .map: + value = try readDynamicAnyMapValue(context: self) + case .none: + value = ForyAnyNullValue() + default: + if typeInfo.typeID.isUserTypeKind { + value = try typeInfo.read(self) + } else { + throw ForyError.invalidData("unsupported dynamic type id \(typeInfo.typeID)") + } + } + return value } - return value - } - - @inline(__always) - func getTypeInfo(for type: T.Type) -> TypeInfo? { - typeInfoStack.value(for: UInt64(UInt(bitPattern: ObjectIdentifier(type)))) - } - - func withTypeInfo( - _ typeInfo: TypeInfo?, - for type: T.Type, - _ body: () throws -> R - ) rethrows -> R { - guard let typeInfo else { - return try body() + + @inline(__always) + func getTypeInfo(for type: T.Type) -> TypeInfo? { + typeInfoStack.value(for: UInt64(UInt(bitPattern: ObjectIdentifier(type)))) } - let typeKey = UInt64(UInt(bitPattern: ObjectIdentifier(type))) - let previousTypeInfo = typeInfoStack.value(for: typeKey) - typeInfoScopeStack.append((typeKey: typeKey, previousTypeInfo: previousTypeInfo)) - typeInfoStack.set(typeInfo, for: typeKey) - defer { - if let scope = typeInfoScopeStack.popLast() { - if let previousTypeInfo = scope.previousTypeInfo { - typeInfoStack.set(previousTypeInfo, for: scope.typeKey) - } else { - _ = typeInfoStack.removeValue(for: scope.typeKey) + func withTypeInfo( + _ typeInfo: TypeInfo?, + for type: T.Type, + _ body: () throws -> R + ) rethrows -> R { + guard let typeInfo else { + return try body() + } + + let typeKey = UInt64(UInt(bitPattern: ObjectIdentifier(type))) + let previousTypeInfo = typeInfoStack.value(for: typeKey) + typeInfoScopeStack.append((typeKey: typeKey, previousTypeInfo: previousTypeInfo)) + typeInfoStack.set(typeInfo, for: typeKey) + defer { + if let scope = typeInfoScopeStack.popLast() { + if let previousTypeInfo = scope.previousTypeInfo { + typeInfoStack.set(previousTypeInfo, for: scope.typeKey) + } else { + _ = typeInfoStack.removeValue(for: scope.typeKey) + } + } else { + assertionFailure("type info scope stack underflow") + } } - } else { - assertionFailure("type info scope stack underflow") - } + return try body() } - return try body() - } - - @inline(__always) - func getReadMetaString(at index: Int) -> MetaString? { - metaStrings.get(index) - } - - @inline(__always) - func appendReadMetaString(_ value: MetaString) { - metaStrings.push(value) - } - - func reset() { - if dynamicAnyDepth != 0 { - dynamicAnyDepth = 0 + + @inline(__always) + func getReadMetaString(at index: Int) -> MetaString? { + metaStrings.get(index) } - refReader.reset() - if !typeInfoStack.isEmpty { - typeInfoStack.clear() + + @inline(__always) + func appendReadMetaString(_ value: MetaString) { + metaStrings.push(value) } - if !typeInfoScopeStack.isEmpty { - typeInfoScopeStack.removeAll(keepingCapacity: true) + + func reset() { + if dynamicAnyDepth != 0 { + dynamicAnyDepth = 0 + } + refReader.reset() + if !typeInfoStack.isEmpty { + typeInfoStack.clear() + } + if !typeInfoScopeStack.isEmpty { + typeInfoScopeStack.removeAll(keepingCapacity: true) + } + compatibleTypeDefTypeInfos.reset() + metaStrings.reset() } - compatibleTypeDefTypeInfos.reset() - metaStrings.reset() - } } diff --git a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift index ff21a1401f..307d3ba918 100644 --- a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift +++ b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift @@ -16,690 +16,690 @@ // under the License. func buildReadDataDecl( - isClass: Bool, - fields: [ParsedField], - sortedFields: [ParsedField], - accessPrefix: String + isClass: Bool, + fields: [ParsedField], + sortedFields: [ParsedField], + accessPrefix: String ) -> String { - if isClass { - return buildClassReadDataDecl(sortedFields: sortedFields, accessPrefix: accessPrefix) - } - if fields.isEmpty { - return buildEmptyStructReadDataDecl(accessPrefix: accessPrefix) - } - return buildStructReadDataDecl( - fields: fields, sortedFields: sortedFields, accessPrefix: accessPrefix) + if isClass { + return buildClassReadDataDecl(sortedFields: sortedFields, accessPrefix: accessPrefix) + } + if fields.isEmpty { + return buildEmptyStructReadDataDecl(accessPrefix: accessPrefix) + } + return buildStructReadDataDecl( + fields: fields, sortedFields: sortedFields, accessPrefix: accessPrefix) } func buildReadCompatibleDataDecl( - isClass: Bool, - fields: [ParsedField], - sortedFields: [ParsedField], - accessPrefix: String + isClass: Bool, + fields: [ParsedField], + sortedFields: [ParsedField], + accessPrefix: String ) -> String { - if isClass { - return buildClassReadCompatibleDataDecl(sortedFields: sortedFields, accessPrefix: accessPrefix) - } - if fields.isEmpty { - return buildEmptyStructReadCompatibleDataDecl(accessPrefix: accessPrefix) - } - return buildStructReadCompatibleDataDecl( - fields: fields, sortedFields: sortedFields, accessPrefix: accessPrefix) + if isClass { + return buildClassReadCompatibleDataDecl(sortedFields: sortedFields, accessPrefix: accessPrefix) + } + if fields.isEmpty { + return buildEmptyStructReadCompatibleDataDecl(accessPrefix: accessPrefix) + } + return buildStructReadCompatibleDataDecl( + fields: fields, sortedFields: sortedFields, accessPrefix: accessPrefix) } private func graphFieldBytesExpr(_ field: ParsedField) -> String { - if field.primitiveSize > 0 { - return "\(field.primitiveSize)" - } - return "(\(field.typeText).isRefType ? 4 : max(1, MemoryLayout<\(field.typeText)>.stride))" + if field.primitiveSize > 0 { + return "\(field.primitiveSize)" + } + return "(\(field.typeText).isRefType ? 4 : max(1, MemoryLayout<\(field.typeText)>.stride))" } private func classGraphOwnerBytesExpr(_ fields: [ParsedField]) -> String { - if fields.isEmpty { - return "1" - } - return "max(1, 1 + " + fields.map(graphFieldBytesExpr).joined(separator: " + ") + ")" + if fields.isEmpty { + return "1" + } + return "max(1, 1 + " + fields.map(graphFieldBytesExpr).joined(separator: " + ") + ")" } private func reserveClassGraphOwnerLine(fields: [ParsedField], indent: String) -> String { - "\(indent)try context.reserveGraphMemory(\(classGraphOwnerBytesExpr(fields)))" + "\(indent)try context.reserveGraphMemory(\(classGraphOwnerBytesExpr(fields)))" } private func reserveValueGraphOwnerLine(indent: String) -> String { - "\(indent)try context.reserveGraphMemory(max(1, MemoryLayout.stride))" + "\(indent)try context.reserveGraphMemory(max(1, MemoryLayout.stride))" } func buildClassReadWrapperDecl(accessPrefix: String) -> String { - """ - @inline(__always) - \(accessPrefix)static func foryRead( - _ context: ReadContext, - refMode: RefMode, - readTypeInfo: Bool - ) throws -> Self { - let __buffer = context.buffer - let __reservedRefID: UInt32? - if refMode != .none { - let rawFlag = try __buffer.readInt8() - guard let flag = RefFlag(rawValue: rawFlag) else { - throw ForyError.refError("invalid ref flag \\(rawFlag)") - } - - switch flag { - case .null: - return Self.foryDefault() - case .ref: - let refID = try __buffer.readVarUInt32() - return try context.refReader.readRef(refID, as: Self.self) - case .refValue: - __reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil - case .notNullValue: - __reservedRefID = nil - } - } else { - __reservedRefID = nil - } - - return try Self.foryReadPayload( - context, - readTypeInfo: readTypeInfo, - readData: { - try Self.__foryReadDataImpl(context, reservedRefID: __reservedRefID) - }, - readCompatibleData: { remoteTypeInfo in - try Self.__foryReadCompatibleDataImpl( - context, - remoteTypeInfo: remoteTypeInfo, - reservedRefID: __reservedRefID - ) - } - ) - } - """ -} - -private func buildClassReadDataDecl( - sortedFields: [ParsedField], - accessPrefix: String -) -> String { - let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) - let schemaAssignBody = buildClassAssignBody( - sortedFields: sortedFields, primitiveFastFields: primitiveFastFields, compatibleAligned: false) - - return """ + """ @inline(__always) - private static func __foryReadDataImpl(_ context: ReadContext, reservedRefID: UInt32?) throws -> Self { + \(accessPrefix)static func foryRead( + _ context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> Self { let __buffer = context.buffer - \(schemaHashCheckExpr()) - \(reserveClassGraphOwnerLine(fields: sortedFields, indent: " ")) - let value = Self.init() - if let reservedRefID { - context.refReader.storeRef(value, at: reservedRefID) + let __reservedRefID: UInt32? + if refMode != .none { + let rawFlag = try __buffer.readInt8() + guard let flag = RefFlag(rawValue: rawFlag) else { + throw ForyError.refError("invalid ref flag \\(rawFlag)") + } + + switch flag { + case .null: + return Self.foryDefault() + case .ref: + let refID = try __buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: Self.self) + case .refValue: + __reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil + case .notNullValue: + __reservedRefID = nil + } + } else { + __reservedRefID = nil } - \(schemaAssignBody) - return value - } - @inline(__always) - \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { - try Self.__foryReadDataImpl(context, reservedRefID: nil) + return try Self.foryReadPayload( + context, + readTypeInfo: readTypeInfo, + readData: { + try Self.__foryReadDataImpl(context, reservedRefID: __reservedRefID) + }, + readCompatibleData: { remoteTypeInfo in + try Self.__foryReadCompatibleDataImpl( + context, + remoteTypeInfo: remoteTypeInfo, + reservedRefID: __reservedRefID + ) + } + ) } """ } -private func buildEmptyStructReadDataDecl(accessPrefix: String) -> String { - """ - @inline(__always) - \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { - let __buffer = context.buffer - \(schemaHashCheckExpr()) - \(reserveValueGraphOwnerLine(indent: " ")) - return Self() - } - """ +private func buildClassReadDataDecl( + sortedFields: [ParsedField], + accessPrefix: String +) -> String { + let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) + let schemaAssignBody = buildClassAssignBody( + sortedFields: sortedFields, primitiveFastFields: primitiveFastFields, compatibleAligned: false) + + return """ + @inline(__always) + private static func __foryReadDataImpl(_ context: ReadContext, reservedRefID: UInt32?) throws -> Self { + let __buffer = context.buffer + \(schemaHashCheckExpr()) + \(reserveClassGraphOwnerLine(fields: sortedFields, indent: " ")) + let value = Self.init() + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } + \(schemaAssignBody) + return value + } + + @inline(__always) + \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { + try Self.__foryReadDataImpl(context, reservedRefID: nil) + } + """ } -private func buildStructReadDataDecl( - fields: [ParsedField], - sortedFields: [ParsedField], - accessPrefix: String -) -> String { - let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) - let schemaReadBody = buildStructReadBody( - sortedFields: sortedFields, - primitiveFastFields: primitiveFastFields, - compatibleAligned: false - ) - let ctorArgs = buildCtorArgs(fields) - - return """ +private func buildEmptyStructReadDataDecl(accessPrefix: String) -> String { + """ @inline(__always) \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { let __buffer = context.buffer \(schemaHashCheckExpr()) - \(reserveValueGraphOwnerLine(indent: " ")) - \(schemaReadBody) - return Self( - \(ctorArgs) - ) + \(reserveValueGraphOwnerLine(indent: " ")) + return Self() } """ } -private func buildClassReadCompatibleDataDecl( - sortedFields: [ParsedField], - accessPrefix: String +private func buildStructReadDataDecl( + fields: [ParsedField], + sortedFields: [ParsedField], + accessPrefix: String ) -> String { - let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) - let schemaAssignBody = buildClassAssignBody( - sortedFields: sortedFields, primitiveFastFields: primitiveFastFields, compatibleAligned: false) - let compatibleAlignedAssignBody = buildClassAssignBody( - sortedFields: sortedFields, - primitiveFastFields: primitiveFastFields, - compatibleAligned: true - ) - let compatibleCases = buildCompatibleReadCases( - sortedFields: sortedFields, indent: " " - ) { sortedIndex, field, valueExpr in - "case \(sortedIndex): value.\(field.name) = \(valueExpr)" - } - let bufferBinding = - (schemaAssignBody.contains("__buffer") || compatibleAlignedAssignBody.contains("__buffer") - || compatibleCases.contains("__buffer")) ? "let __buffer = context.buffer\n " : "" - let localFieldsBinding = - compatibleCases.contains("__foryLocalFields") - ? "let __foryLocalFields = remoteTypeInfo.typeMeta?.fields ?? Self.foryFieldsInfo(trackRef: context.trackRef)\n " - : "" - - return """ - @inline(never) - private static func __foryReadCompatibleDataImpl( - _ context: ReadContext, - remoteTypeInfo: TypeInfo, - reservedRefID: UInt32? - ) throws -> Self { - \(bufferBinding)guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { - throw ForyError.invalidData("compatible type metadata is required") - } - \(reserveClassGraphOwnerLine(fields: sortedFields, indent: " ")) - let value = Self.init() - if let reservedRefID { - context.refReader.storeRef(value, at: reservedRefID) + let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) + let schemaReadBody = buildStructReadBody( + sortedFields: sortedFields, + primitiveFastFields: primitiveFastFields, + compatibleAligned: false + ) + let ctorArgs = buildCtorArgs(fields) + + return """ + @inline(__always) + \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { + let __buffer = context.buffer + \(schemaHashCheckExpr()) + \(reserveValueGraphOwnerLine(indent: " ")) + \(schemaReadBody) + return Self( + \(ctorArgs) + ) } - if let localTypeMeta = remoteTypeInfo.typeMeta, - let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, - typeMeta.headerHash == localHeaderHash, - typeMeta.fields == localTypeMeta.fields { - if !remoteTypeInfo.typeDefHasUserTypeFields { - \(schemaAssignBody) + """ +} + +private func buildClassReadCompatibleDataDecl( + sortedFields: [ParsedField], + accessPrefix: String +) -> String { + let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) + let schemaAssignBody = buildClassAssignBody( + sortedFields: sortedFields, primitiveFastFields: primitiveFastFields, compatibleAligned: false) + let compatibleAlignedAssignBody = buildClassAssignBody( + sortedFields: sortedFields, + primitiveFastFields: primitiveFastFields, + compatibleAligned: true + ) + let compatibleCases = buildCompatibleReadCases( + sortedFields: sortedFields, indent: " " + ) { sortedIndex, field, valueExpr in + "case \(sortedIndex): value.\(field.name) = \(valueExpr)" + } + let bufferBinding = + (schemaAssignBody.contains("__buffer") || compatibleAlignedAssignBody.contains("__buffer") + || compatibleCases.contains("__buffer")) ? "let __buffer = context.buffer\n " : "" + let localFieldsBinding = + compatibleCases.contains("__foryLocalFields") + ? "let __foryLocalFields = remoteTypeInfo.typeMeta?.fields ?? Self.foryFieldsInfo(trackRef: context.trackRef)\n " + : "" + + return """ + @inline(never) + private static func __foryReadCompatibleDataImpl( + _ context: ReadContext, + remoteTypeInfo: TypeInfo, + reservedRefID: UInt32? + ) throws -> Self { + \(bufferBinding)guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { + throw ForyError.invalidData("compatible type metadata is required") + } + \(reserveClassGraphOwnerLine(fields: sortedFields, indent: " ")) + let value = Self.init() + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } + if let localTypeMeta = remoteTypeInfo.typeMeta, + let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, + typeMeta.headerHash == localHeaderHash, + typeMeta.fields == localTypeMeta.fields { + if !remoteTypeInfo.typeDefHasUserTypeFields { + \(schemaAssignBody) + return value + } + \(compatibleAlignedAssignBody) return value } - \(compatibleAlignedAssignBody) - return value - } - \(localFieldsBinding)for remoteField in typeMeta.fields { - switch Int(remoteField.fieldID ?? -1) { - \(compatibleCases) - case -1: - try context.skipFieldValue(remoteField.fieldType) - default: - throw ForyError.invalidData("invalid compatible matched id \\(remoteField.fieldID ?? -2)") + \(localFieldsBinding)for remoteField in typeMeta.fields { + switch Int(remoteField.fieldID ?? -1) { + \(compatibleCases) + case -1: + try context.skipFieldValue(remoteField.fieldType) + default: + throw ForyError.invalidData("invalid compatible matched id \\(remoteField.fieldID ?? -2)") + } } + return value } - return value - } - @inline(never) - \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { - try Self.__foryReadCompatibleDataImpl(context, remoteTypeInfo: remoteTypeInfo, reservedRefID: nil) - } - """ + @inline(never) + \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { + try Self.__foryReadCompatibleDataImpl(context, remoteTypeInfo: remoteTypeInfo, reservedRefID: nil) + } + """ } private func buildEmptyStructReadCompatibleDataDecl(accessPrefix: String) -> String { - """ - @inline(never) - \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { - guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { - throw ForyError.invalidData("compatible type metadata is required") - } - \(reserveValueGraphOwnerLine(indent: " ")) - if let localTypeMeta = remoteTypeInfo.typeMeta, - let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, - typeMeta.headerHash == localHeaderHash, - typeMeta.fields == localTypeMeta.fields { - return Self() - } - for remoteField in typeMeta.fields { - try context.skipFieldValue(remoteField.fieldType) - } - return Self() - } - """ -} - -private func buildStructReadCompatibleDataDecl( - fields: [ParsedField], - sortedFields: [ParsedField], - accessPrefix: String -) -> String { - let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) - let schemaReadBody = buildStructReadBody( - sortedFields: sortedFields, - primitiveFastFields: primitiveFastFields, - compatibleAligned: false - ) - let compatibleAlignedReadBody = buildStructReadBody( - sortedFields: sortedFields, - primitiveFastFields: primitiveFastFields, - compatibleAligned: true - ) - let ctorArgs = buildCtorArgs(fields) - let compatibleDefaults = buildStructCompatibleDefaults(fields) - let compatibleCases = buildCompatibleReadCases( - sortedFields: sortedFields, indent: " " - ) { sortedIndex, field, valueExpr in - "case \(sortedIndex): __\(field.name) = \(valueExpr)" - } - let changedFallbackDecl = buildStructChangedFallbackDecl( - defaults: compatibleDefaults, - cases: compatibleCases, - ctorArgs: ctorArgs - ) - let bufferBinding = - (schemaReadBody.contains("__buffer") || compatibleAlignedReadBody.contains("__buffer")) - ? "let __buffer = context.buffer\n " : "" - - return """ - \(changedFallbackDecl) - + """ @inline(never) \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { - \(bufferBinding)guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { + guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { throw ForyError.invalidData("compatible type metadata is required") } - \(reserveValueGraphOwnerLine(indent: " ")) + \(reserveValueGraphOwnerLine(indent: " ")) if let localTypeMeta = remoteTypeInfo.typeMeta, let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, typeMeta.headerHash == localHeaderHash, typeMeta.fields == localTypeMeta.fields { - if !remoteTypeInfo.typeDefHasUserTypeFields { - \(schemaReadBody) + return Self() + } + for remoteField in typeMeta.fields { + try context.skipFieldValue(remoteField.fieldType) + } + return Self() + } + """ +} + +private func buildStructReadCompatibleDataDecl( + fields: [ParsedField], + sortedFields: [ParsedField], + accessPrefix: String +) -> String { + let primitiveFastFields = leadingPrimitiveFastPathFields(sortedFields) + let schemaReadBody = buildStructReadBody( + sortedFields: sortedFields, + primitiveFastFields: primitiveFastFields, + compatibleAligned: false + ) + let compatibleAlignedReadBody = buildStructReadBody( + sortedFields: sortedFields, + primitiveFastFields: primitiveFastFields, + compatibleAligned: true + ) + let ctorArgs = buildCtorArgs(fields) + let compatibleDefaults = buildStructCompatibleDefaults(fields) + let compatibleCases = buildCompatibleReadCases( + sortedFields: sortedFields, indent: " " + ) { sortedIndex, field, valueExpr in + "case \(sortedIndex): __\(field.name) = \(valueExpr)" + } + let changedFallbackDecl = buildStructChangedFallbackDecl( + defaults: compatibleDefaults, + cases: compatibleCases, + ctorArgs: ctorArgs + ) + let bufferBinding = + (schemaReadBody.contains("__buffer") || compatibleAlignedReadBody.contains("__buffer")) + ? "let __buffer = context.buffer\n " : "" + + return """ + \(changedFallbackDecl) + + @inline(never) + \(accessPrefix)static func foryReadCompatibleData(_ context: ReadContext, remoteTypeInfo: TypeInfo) throws -> Self { + \(bufferBinding)guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { + throw ForyError.invalidData("compatible type metadata is required") + } + \(reserveValueGraphOwnerLine(indent: " ")) + if let localTypeMeta = remoteTypeInfo.typeMeta, + let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, + typeMeta.headerHash == localHeaderHash, + typeMeta.fields == localTypeMeta.fields { + if !remoteTypeInfo.typeDefHasUserTypeFields { + \(schemaReadBody) + return Self( + \(ctorArgs) + ) + } + \(compatibleAlignedReadBody) return Self( \(ctorArgs) ) } - \(compatibleAlignedReadBody) - return Self( - \(ctorArgs) + return try Self.__foryReadChangedData( + context, + typeMeta: typeMeta ) } - return try Self.__foryReadChangedData( - context, - typeMeta: typeMeta - ) - } - """ + """ } private func buildStructChangedFallbackDecl( - defaults: String, - cases: String, - ctorArgs: String + defaults: String, + cases: String, + ctorArgs: String ) -> String { - let bufferBinding = cases.contains("__buffer") ? "let __buffer = context.buffer\n " : "" - let localFieldsBinding = - cases.contains("__foryLocalFields") - ? "let __foryLocalFields = Self.foryFieldsInfo(trackRef: context.trackRef)\n " : "" - return """ - @inline(never) - private static func __foryReadChangedData( - _ context: ReadContext, - typeMeta: TypeMeta - ) throws -> Self { - \(bufferBinding) - \(defaults) - \(localFieldsBinding)for remoteField in typeMeta.fields { - switch Int(remoteField.fieldID ?? -1) { - \(cases) - case -1: - try context.skipFieldValue(remoteField.fieldType) - default: - throw ForyError.invalidData("invalid compatible matched id \\(remoteField.fieldID ?? -2)") + let bufferBinding = cases.contains("__buffer") ? "let __buffer = context.buffer\n " : "" + let localFieldsBinding = + cases.contains("__foryLocalFields") + ? "let __foryLocalFields = Self.foryFieldsInfo(trackRef: context.trackRef)\n " : "" + return """ + @inline(never) + private static func __foryReadChangedData( + _ context: ReadContext, + typeMeta: TypeMeta + ) throws -> Self { + \(bufferBinding) + \(defaults) + \(localFieldsBinding)for remoteField in typeMeta.fields { + switch Int(remoteField.fieldID ?? -1) { + \(cases) + case -1: + try context.skipFieldValue(remoteField.fieldType) + default: + throw ForyError.invalidData("invalid compatible matched id \\(remoteField.fieldID ?? -2)") + } } + return Self( + \(ctorArgs) + ) } - return Self( - \(ctorArgs) - ) - } - """ + """ } private func buildClassAssignBody( - sortedFields: [ParsedField], - primitiveFastFields: [ParsedField], - compatibleAligned: Bool + sortedFields: [ParsedField], + primitiveFastFields: [ParsedField], + compatibleAligned: Bool ) -> String { - let remainingAssignLines = sortedFields.dropFirst(primitiveFastFields.count).map { - field -> String in - let valueExpr: String - if compatibleAligned { - valueExpr = compatibleSchemaReadFieldExpr(field) - } else { - valueExpr = readFieldExpr( - field, - refModeExpr: fieldRefModeExpression(field), - readTypeInfoExpr: "false" - ) - } - return "value.\(field.name) = \(valueExpr)" - } - - var sections: [String] = [] - if let primitiveReadBlock = buildPrimitiveFastClassReadBlock(primitiveFastFields) { - sections.append(primitiveReadBlock) - } - if !remainingAssignLines.isEmpty { - sections.append(remainingAssignLines.joined(separator: "\n ")) - } - if sections.isEmpty { - sections.append("_ = context") - } - return sections.joined(separator: "\n ") + let remainingAssignLines = sortedFields.dropFirst(primitiveFastFields.count).map { + field -> String in + let valueExpr: String + if compatibleAligned { + valueExpr = compatibleSchemaReadFieldExpr(field) + } else { + valueExpr = readFieldExpr( + field, + refModeExpr: fieldRefModeExpression(field), + readTypeInfoExpr: "false" + ) + } + return "value.\(field.name) = \(valueExpr)" + } + + var sections: [String] = [] + if let primitiveReadBlock = buildPrimitiveFastClassReadBlock(primitiveFastFields) { + sections.append(primitiveReadBlock) + } + if !remainingAssignLines.isEmpty { + sections.append(remainingAssignLines.joined(separator: "\n ")) + } + if sections.isEmpty { + sections.append("_ = context") + } + return sections.joined(separator: "\n ") } private func buildStructReadBody( - sortedFields: [ParsedField], - primitiveFastFields: [ParsedField], - compatibleAligned: Bool + sortedFields: [ParsedField], + primitiveFastFields: [ParsedField], + compatibleAligned: Bool ) -> String { - let remainingReadLines = sortedFields.dropFirst(primitiveFastFields.count).map { - field -> String in - let valueExpr = - compatibleAligned ? compatibleSchemaReadFieldExpr(field) : schemaReadFieldExpr(field) - return "let __\(field.name) = \(valueExpr)" - } - - var sections: [String] = [] - if let primitiveDeclarations = buildPrimitiveFastStructReadDeclarations(primitiveFastFields) { - sections.append(primitiveDeclarations) - } - if let primitiveReadBlock = buildPrimitiveFastStructReadBlock(primitiveFastFields) { - sections.append(primitiveReadBlock) - } - if !remainingReadLines.isEmpty { - sections.append(remainingReadLines.joined(separator: "\n ")) - } - return sections.joined(separator: "\n ") + let remainingReadLines = sortedFields.dropFirst(primitiveFastFields.count).map { + field -> String in + let valueExpr = + compatibleAligned ? compatibleSchemaReadFieldExpr(field) : schemaReadFieldExpr(field) + return "let __\(field.name) = \(valueExpr)" + } + + var sections: [String] = [] + if let primitiveDeclarations = buildPrimitiveFastStructReadDeclarations(primitiveFastFields) { + sections.append(primitiveDeclarations) + } + if let primitiveReadBlock = buildPrimitiveFastStructReadBlock(primitiveFastFields) { + sections.append(primitiveReadBlock) + } + if !remainingReadLines.isEmpty { + sections.append(remainingReadLines.joined(separator: "\n ")) + } + return sections.joined(separator: "\n ") } private func buildCtorArgs(_ fields: [ParsedField]) -> String { - fields - .sorted(by: { $0.originalIndex < $1.originalIndex }) - .map { "\($0.name): __\($0.name)" } - .joined(separator: ",\n ") + fields + .sorted(by: { $0.originalIndex < $1.originalIndex }) + .map { "\($0.name): __\($0.name)" } + .joined(separator: ",\n ") } private func buildStructCompatibleDefaults(_ fields: [ParsedField]) -> String { - fields - .sorted(by: { $0.originalIndex < $1.originalIndex }) - .map(compatibleDefaultDecl) - .joined(separator: "\n ") + fields + .sorted(by: { $0.originalIndex < $1.originalIndex }) + .map(compatibleDefaultDecl) + .joined(separator: "\n ") } private func schemaHashCheckExpr(indent: String = " ") -> String { - """ - \(indent)if context.checkClassVersion { - \(indent) let __schemaHash = UInt32(bitPattern: try __buffer.readInt32()) - \(indent) let __expectedHash = Self.__forySchemaHash(context.trackRef) - \(indent) if __schemaHash != __expectedHash { - \(indent) throw ForyError.invalidData("class version hash mismatch: expected \\(__expectedHash), got \\(__schemaHash)") - \(indent) } - \(indent)} - """ + """ + \(indent)if context.checkClassVersion { + \(indent) let __schemaHash = UInt32(bitPattern: try __buffer.readInt32()) + \(indent) let __expectedHash = Self.__forySchemaHash(context.trackRef) + \(indent) if __schemaHash != __expectedHash { + \(indent) throw ForyError.invalidData("class version hash mismatch: expected \\(__expectedHash), got \\(__schemaHash)") + \(indent) } + \(indent)} + """ } private func buildCompatibleReadCases( - sortedFields: [ParsedField], - indent: String, - assignCase: (Int, ParsedField, String) -> String + sortedFields: [ParsedField], + indent: String, + assignCase: (Int, ParsedField, String) -> String ) -> String { - sortedFields.enumerated().map { sortedIndex, field -> String in - let directValueExpr = compatibleSchemaReadFieldExpr(field) - let compatibleValueExpr = readFieldExpr( - field, - refModeExpr: - "RefMode.from(nullable: remoteField.fieldType.nullable, trackRef: remoteField.fieldType.trackRef)", - readTypeInfoExpr: - "TypeId.needsTypeInfoForField(TypeId(rawValue: remoteField.fieldType.typeID) ?? .unknown)" - ) - let compatibleCaseExpr = compatibleScalarReadExpr( - field, - sortedIndex: sortedIndex, - compatibleValueExpr: compatibleValueExpr - ) - return [ - assignCase(sortedIndex * 2, field, directValueExpr), - assignCase(sortedIndex * 2 + 1, field, compatibleCaseExpr), - ].joined(separator: "\n\(indent)") - }.joined(separator: "\n\(indent)") + sortedFields.enumerated().map { sortedIndex, field -> String in + let directValueExpr = compatibleSchemaReadFieldExpr(field) + let compatibleValueExpr = readFieldExpr( + field, + refModeExpr: + "RefMode.from(nullable: remoteField.fieldType.nullable, trackRef: remoteField.fieldType.trackRef)", + readTypeInfoExpr: + "TypeId.needsTypeInfoForField(TypeId(rawValue: remoteField.fieldType.typeID) ?? .unknown)" + ) + let compatibleCaseExpr = compatibleScalarReadExpr( + field, + sortedIndex: sortedIndex, + compatibleValueExpr: compatibleValueExpr + ) + return [ + assignCase(sortedIndex * 2, field, directValueExpr), + assignCase(sortedIndex * 2 + 1, field, compatibleCaseExpr) + ].joined(separator: "\n\(indent)") + }.joined(separator: "\n\(indent)") } private func compatibleScalarReadExpr( - _ field: ParsedField, - sortedIndex: Int, - compatibleValueExpr: String + _ field: ParsedField, + sortedIndex: Int, + compatibleValueExpr: String ) -> String { - guard - field.dynamicAnyCodec == nil, - let helperTarget = compatibleScalarReaderTarget(field) - else { - return compatibleValueExpr - } - let helperName = - field.isOptional - ? "foryReadCompatibleOptional\(helperTarget)Field" - : "foryReadCompatible\(helperTarget)Field" - return """ - try \(helperName)( - context, - remoteField: remoteField, - localField: __foryLocalFields[\(sortedIndex)] - ) - """ + guard + field.dynamicAnyCodec == nil, + let helperTarget = compatibleScalarReaderTarget(field) + else { + return compatibleValueExpr + } + let helperName = + field.isOptional + ? "foryReadCompatibleOptional\(helperTarget)Field" + : "foryReadCompatible\(helperTarget)Field" + return """ + try \(helperName)( + context, + remoteField: remoteField, + localField: __foryLocalFields[\(sortedIndex)] + ) + """ } private func compatibleScalarReaderTarget(_ field: ParsedField) -> String? { - guard compatibleScalarTypeID(field.typeID) else { - return nil - } - switch compatibleScalarPayloadType(field.typeText) { - case "Bool": - return "Bool" - case "Int8": - return "Int8" - case "Int16": - return "Int16" - case "Int32": - return "Int32" - case "Int64": - return "Int64" - case "Int": - return "Int" - case "UInt8": - return "UInt8" - case "UInt16": - return "UInt16" - case "UInt32": - return "UInt32" - case "UInt64": - return "UInt64" - case "UInt": - return "UInt" - case "Float16": - return "Float16" - case "BFloat16": - return "BFloat16" - case "Float": - return "Float" - case "Double": - return "Double" - case "String": - return "String" - case "Decimal": - return "Decimal" - default: - return nil - } + guard compatibleScalarTypeID(field.typeID) else { + return nil + } + switch compatibleScalarPayloadType(field.typeText) { + case "Bool": + return "Bool" + case "Int8": + return "Int8" + case "Int16": + return "Int16" + case "Int32": + return "Int32" + case "Int64": + return "Int64" + case "Int": + return "Int" + case "UInt8": + return "UInt8" + case "UInt16": + return "UInt16" + case "UInt32": + return "UInt32" + case "UInt64": + return "UInt64" + case "UInt": + return "UInt" + case "Float16": + return "Float16" + case "BFloat16": + return "BFloat16" + case "Float": + return "Float" + case "Double": + return "Double" + case "String": + return "String" + case "Decimal": + return "Decimal" + default: + return nil + } } private func compatibleScalarPayloadType(_ typeText: String) -> String { - var type = trimType(typeText) - if type.hasSuffix("?") { - type.removeLast() - } else if type.hasPrefix("Optional<"), type.hasSuffix(">") { - type = String(type.dropFirst("Optional<".count).dropLast()) - } - for prefix in ["Swift.", "Foundation.", "Fory."] where type.hasPrefix(prefix) { - return String(type.dropFirst(prefix.count)) - } - return type + var type = trimType(typeText) + if type.hasSuffix("?") { + type.removeLast() + } else if type.hasPrefix("Optional<"), type.hasSuffix(">") { + type = String(type.dropFirst("Optional<".count).dropLast()) + } + for prefix in ["Swift.", "Foundation.", "Fory."] where type.hasPrefix(prefix) { + return String(type.dropFirst(prefix.count)) + } + return type } private func compatibleScalarTypeID(_ typeID: UInt32) -> Bool { - switch typeID { - case 1...15, 17...21, 40: - return true - default: - return false - } + switch typeID { + case 1...15, 17...21, 40: + return true + default: + return false + } } private func swiftStringLiteral(_ value: String) -> String { - let escaped = - value - .replacingOccurrences(of: "\\", with: "\\\\") - .replacingOccurrences(of: "\"", with: "\\\"") - return "\"\(escaped)\"" + let escaped = + value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "\"", with: "\\\"") + return "\"\(escaped)\"" } private func readFieldExpr( - _ field: ParsedField, - refModeExpr: String, - readTypeInfoExpr: String + _ field: ParsedField, + refModeExpr: String, + readTypeInfoExpr: String ) -> String { - if let dynamicAnyCodec = field.dynamicAnyCodec { - return dynamicAnyReadExpr( - field: field, - dynamicAnyCodec: dynamicAnyCodec, - refModeExpr: refModeExpr - ) - } - if let codecType = field.customCodecType { - let fieldCodec = field.isOptional ? "OptionalFieldCodec<\(codecType)>" : codecType - if readTypeInfoExpr.contains("remoteField.fieldType") { - return """ - try \(fieldCodec).readCompatibleField( - context, - remoteFieldType: remoteField.fieldType, - refMode: \(refModeExpr) + if let dynamicAnyCodec = field.dynamicAnyCodec { + return dynamicAnyReadExpr( + field: field, + dynamicAnyCodec: dynamicAnyCodec, + refModeExpr: refModeExpr ) - """ } - return "try \(fieldCodec).read(context, refMode: \(refModeExpr), readTypeInfo: false)" - } - return - "try \(field.typeText).foryRead(context, refMode: \(refModeExpr), readTypeInfo: \(readTypeInfoExpr))" + if let codecType = field.customCodecType { + let fieldCodec = field.isOptional ? "OptionalFieldCodec<\(codecType)>" : codecType + if readTypeInfoExpr.contains("remoteField.fieldType") { + return """ + try \(fieldCodec).readCompatibleField( + context, + remoteFieldType: remoteField.fieldType, + refMode: \(refModeExpr) + ) + """ + } + return "try \(fieldCodec).read(context, refMode: \(refModeExpr), readTypeInfo: false)" + } + return + "try \(field.typeText).foryRead(context, refMode: \(refModeExpr), readTypeInfo: \(readTypeInfoExpr))" } private func schemaReadFieldExpr(_ field: ParsedField) -> String { - if fieldNeedsGeneralSchemaRead(field) { - return readFieldExpr( - field, - refModeExpr: fieldRefModeExpression(field), - readTypeInfoExpr: "false" - ) - } - if let primitiveExpr = primitiveSchemaReadExpr(field) { - return primitiveExpr - } - return "try \(field.typeText).foryReadData(context)" + if fieldNeedsGeneralSchemaRead(field) { + return readFieldExpr( + field, + refModeExpr: fieldRefModeExpression(field), + readTypeInfoExpr: "false" + ) + } + if let primitiveExpr = primitiveSchemaReadExpr(field) { + return primitiveExpr + } + return "try \(field.typeText).foryReadData(context)" } private func compatibleSchemaReadFieldExpr(_ field: ParsedField) -> String { - if fieldNeedsGeneralCompatibleRead(field) { - return readFieldExpr( - field, - refModeExpr: fieldRefModeExpression(field), - readTypeInfoExpr: "TypeId.needsTypeInfoForField(\(field.typeText).staticTypeId)" - ) - } - if let primitiveExpr = primitiveSchemaReadExpr(field) { - return primitiveExpr - } - return "try \(field.typeText).foryReadData(context)" + if fieldNeedsGeneralCompatibleRead(field) { + return readFieldExpr( + field, + refModeExpr: fieldRefModeExpression(field), + readTypeInfoExpr: "TypeId.needsTypeInfoForField(\(field.typeText).staticTypeId)" + ) + } + if let primitiveExpr = primitiveSchemaReadExpr(field) { + return primitiveExpr + } + return "try \(field.typeText).foryReadData(context)" } private func primitiveSchemaReadExpr(_ field: ParsedField) -> String? { - let type = trimType(field.typeText) - switch type { - case "Bool": - return "try __buffer.readUInt8() != 0" - case "Int8": - return "try __buffer.readInt8()" - case "Int16": - return "try __buffer.readInt16()" - case "Int32": - return "try __buffer.readVarInt32()" - case "Int64": - return "try __buffer.readVarInt64()" - case "Int": - return "Int(try __buffer.readVarInt64())" - case "UInt8": - return "try __buffer.readUInt8()" - case "UInt16": - return "try __buffer.readUInt16()" - case "UInt32": - return "try __buffer.readVarUInt32()" - case "UInt64": - return "try __buffer.readVarUInt64()" - case "UInt": - return "UInt(try __buffer.readVarUInt64())" - case "Float": - return "try __buffer.readFloat32()" - case "Double": - return "try __buffer.readFloat64()" - default: - return nil - } + let type = trimType(field.typeText) + switch type { + case "Bool": + return "try __buffer.readUInt8() != 0" + case "Int8": + return "try __buffer.readInt8()" + case "Int16": + return "try __buffer.readInt16()" + case "Int32": + return "try __buffer.readVarInt32()" + case "Int64": + return "try __buffer.readVarInt64()" + case "Int": + return "Int(try __buffer.readVarInt64())" + case "UInt8": + return "try __buffer.readUInt8()" + case "UInt16": + return "try __buffer.readUInt16()" + case "UInt32": + return "try __buffer.readVarUInt32()" + case "UInt64": + return "try __buffer.readVarUInt64()" + case "UInt": + return "UInt(try __buffer.readVarUInt64())" + case "Float": + return "try __buffer.readFloat32()" + case "Double": + return "try __buffer.readFloat64()" + default: + return nil + } } private func dynamicAnyReadExpr( - field: ParsedField, - dynamicAnyCodec: DynamicAnyCodecKind, - refModeExpr: String + field: ParsedField, + dynamicAnyCodec: DynamicAnyCodecKind, + refModeExpr: String ) -> String { - let metatypeExpr = "(\(field.typeText)).self" - let method = dynamicAnyReadMethodName(dynamicAnyCodec) - let readTypeInfoExpr = - dynamicAnyReadsTypeInfo(dynamicAnyCodec) - ? ", readTypeInfo: true" - : "" - return - "try castAnyDynamicValue(\(method)(context: context, refMode: \(refModeExpr)\(readTypeInfoExpr)), to: \(metatypeExpr))" + let metatypeExpr = "(\(field.typeText)).self" + let method = dynamicAnyReadMethodName(dynamicAnyCodec) + let readTypeInfoExpr = + dynamicAnyReadsTypeInfo(dynamicAnyCodec) + ? ", readTypeInfo: true" + : "" + return + "try castAnyDynamicValue(\(method)(context: context, refMode: \(refModeExpr)\(readTypeInfoExpr)), to: \(metatypeExpr))" } private func compatibleDefaultDecl(_ field: ParsedField) -> String { - let explicitType = - (field.dynamicAnyCodec != nil || field.customCodecType != nil) ? ": \(field.typeText)" : "" - return "var __\(field.name)\(explicitType) = \(fieldDefaultExpr(field))" + let explicitType = + (field.dynamicAnyCodec != nil || field.customCodecType != nil) ? ": \(field.typeText)" : "" + return "var __\(field.name)\(explicitType) = \(fieldDefaultExpr(field))" } private func fieldNeedsGeneralSchemaRead(_ field: ParsedField) -> Bool { - field.dynamicAnyCodec != nil || field.customCodecType != nil || field.isOptional - || field.typeID == 27 + field.dynamicAnyCodec != nil || field.customCodecType != nil || field.isOptional + || field.typeID == 27 } private func fieldNeedsGeneralCompatibleRead(_ field: ParsedField) -> Bool { - fieldNeedsGeneralSchemaRead(field) || compatibleFieldNeedsTypeInfo(field) + fieldNeedsGeneralSchemaRead(field) || compatibleFieldNeedsTypeInfo(field) } diff --git a/swift/Tests/ForyTests/ForySwiftTests.swift b/swift/Tests/ForyTests/ForySwiftTests.swift index 15ff5b1ab9..ea26571f17 100644 --- a/swift/Tests/ForyTests/ForySwiftTests.swift +++ b/swift/Tests/ForyTests/ForySwiftTests.swift @@ -22,1720 +22,1720 @@ import Testing @ForyStruct struct Address: Equatable { - var street: String - var zip: Int32 + var street: String + var zip: Int32 } @ForyStruct struct Person: Equatable { - var id: Int64 - var name: String - var nickname: String? - var scores: [Int32] - var tags: Set - var addresses: [Address] - var metadata: [Int8: Int32?] + var id: Int64 + var name: String + var nickname: String? + var scores: [Int32] + var tags: Set + var addresses: [Address] + var metadata: [Int8: Int32?] } @ForyStruct struct FieldOrder: Equatable { - var textTail: String - var longValue: Int64 - var shortValue: Int16 - var intValue: Int32 + var textTail: String + var longValue: Int64 + var shortValue: Int16 + var intValue: Int32 } @ForyStruct struct TaggedFieldOrder: Equatable { - @ForyField(id: 1) - var textTail: String + @ForyField(id: 1) + var textTail: String - @ForyField(id: 10) - var intValue: Int32 + @ForyField(id: 10) + var intValue: Int32 } @ForyStruct struct NonPrimitiveFieldOrder: Equatable { - @ForyField(id: 20) - var stringValue: String + @ForyField(id: 20) + var stringValue: String - @ForyField(id: 10) - var mapValue: [String: Int32] + @ForyField(id: 10) + var mapValue: [String: Int32] - var binaryValue: Data - var addressValue: Address - var intValue: Int32 + var binaryValue: Data + var addressValue: Address + var intValue: Int32 } @ForyStruct struct EncodedNumberFields: Equatable { - @ForyField(encoding: .fixed) - var u32Fixed: UInt32 + @ForyField(encoding: .fixed) + var u32Fixed: UInt32 - @ForyField(encoding: .tagged) - var u64Tagged: UInt64 + @ForyField(encoding: .tagged) + var u64Tagged: UInt64 } @ForyStruct struct ReducedPrecisionMacroFields: Equatable { - var float16Value: Float16 - var bfloat16Value: BFloat16 - @ArrayField(element: .float16) - var float16Array: [Float16] - @ArrayField(element: .bfloat16) - var bfloat16Array: [BFloat16] + var float16Value: Float16 + var bfloat16Value: BFloat16 + @ArrayField(element: .float16) + var float16Array: [Float16] + @ArrayField(element: .bfloat16) + var bfloat16Array: [BFloat16] } @ForyStruct struct FieldIdConfigured: Equatable { - @ForyField(id: 2) - var stableID: Int32 + @ForyField(id: 2) + var stableID: Int32 - @ForyField(id: 5, encoding: .fixed) - var fixedValue: Int32 + @ForyField(id: 5, encoding: .fixed) + var fixedValue: Int32 } @ForyStruct struct FieldIdSource: Equatable { - @ForyField(id: 1) - var value: Int32 + @ForyField(id: 1) + var value: Int32 - @ForyField(id: 4) - var label: String + @ForyField(id: 4) + var label: String } @ForyStruct struct FieldIdTarget: Equatable { - @ForyField(id: 1) - var renamedValue: Int32 + @ForyField(id: 1) + var renamedValue: Int32 - @ForyField(id: 4) - var renamedLabel: String + @ForyField(id: 4) + var renamedLabel: String } @ForyEnum enum SparseStatus: Int32, CaseIterable { - case unknown = 4096 - case ok = 8192 + case unknown = 4096 + case ok = 8192 } @ForyStruct struct EvolvingOverrideValue: Equatable { - var f1: String = "" + var f1: String = "" } @ForyStruct(evolving: false) struct FixedOverrideValue: Equatable { - var f1: String = "" + var f1: String = "" } @ForyUnion enum FieldIdUnionSource: Equatable { - @ForyUnknownCase - case unknown(UnknownCase) + @ForyUnknownCase + case unknown(UnknownCase) - @ForyCase(id: 3) - case number(Int32) + @ForyCase(id: 3) + case number(Int32) - @ForyCase(id: 9) - case text(String) + @ForyCase(id: 9) + case text(String) } @ForyUnion enum FieldIdUnionTarget: Equatable { - @ForyUnknownCase - case unknown(UnknownCase) + @ForyUnknownCase + case unknown(UnknownCase) - @ForyCase(id: 3) - case renamedNumber(Int32) + @ForyCase(id: 3) + case renamedNumber(Int32) - @ForyCase(id: 9) - case renamedText(String) + @ForyCase(id: 9) + case renamedText(String) } @ForyStruct struct CompatibleNestedItem: Equatable { - var id: Int32 - var name: String + var id: Int32 + var name: String } @ForyStruct struct CompatibleNestedArrayHolder: Equatable { - var items: [CompatibleNestedItem] + var items: [CompatibleNestedItem] } @ForyStruct struct CompatibleNestedOptionalArrayHolder: Equatable { - var items: [CompatibleNestedItem?] + var items: [CompatibleNestedItem?] } @ForyStruct struct CompatibleNestedMapHolder: Equatable { - var items: [Int32: CompatibleNestedItem] + var items: [Int32: CompatibleNestedItem] } struct LateMetaExt: Serializer, Equatable { - var value: Int32 = 0 + var value: Int32 = 0 - static func foryDefault() -> LateMetaExt { - LateMetaExt() - } + static func foryDefault() -> LateMetaExt { + LateMetaExt() + } - static var staticTypeId: TypeId { - .ext - } + static var staticTypeId: TypeId { + .ext + } - func foryWriteData(_ context: WriteContext, hasGenerics _: Bool) throws { - context.buffer.writeVarInt32(value) - } + func foryWriteData(_ context: WriteContext, hasGenerics _: Bool) throws { + context.buffer.writeVarInt32(value) + } - static func foryReadData(_ context: ReadContext) throws -> LateMetaExt { - LateMetaExt(value: try context.buffer.readVarInt32()) - } + static func foryReadData(_ context: ReadContext) throws -> LateMetaExt { + LateMetaExt(value: try context.buffer.readVarInt32()) + } } @ForyStruct struct LateMetaHolder: Equatable { - var ext: LateMetaExt + var ext: LateMetaExt } @ForyStruct final class Node { - var value: Int32 = 0 - var next: Node? + var value: Int32 = 0 + var next: Node? - required init() {} + required init() {} - init(value: Int32, next: Node? = nil) { - self.value = value - self.next = next - } + init(value: Int32, next: Node? = nil) { + self.value = value + self.next = next + } } @ForyStruct final class WeakNode { - var value: Int32 = 0 - weak var next: WeakNode? + var value: Int32 = 0 + weak var next: WeakNode? - required init() {} + required init() {} - init(value: Int32, next: WeakNode? = nil) { - self.value = value - self.next = next - } + init(value: Int32, next: WeakNode? = nil) { + self.value = value + self.next = next + } } @ForyStruct struct AnyObjectHolder { - var value: AnyObject - var optionalValue: AnyObject? - var items: [AnyObject] + var value: AnyObject + var optionalValue: AnyObject? + var items: [AnyObject] } @ForyStruct struct AnySerializerHolder { - var value: any Serializer - var items: [any Serializer] - var map: [String: any Serializer] + var value: any Serializer + var items: [any Serializer] + var map: [String: any Serializer] } @ForyStruct struct AnyFieldHolder { - var value: Any - var optionalValue: Any? - var list: [Any] - var stringMap: [String: Any] - var int32Map: [Int32: Any] + var value: Any + var optionalValue: Any? + var list: [Any] + var stringMap: [String: Any] + var int32Map: [Int32: Any] } @Test func primitiveRoundTrip() throws { - let fory = Fory() + let fory = Fory() - let boolData = try fory.serialize(true) - let boolValue: Bool = try fory.deserialize(boolData) - #expect(boolValue == true) + let boolData = try fory.serialize(true) + let boolValue: Bool = try fory.deserialize(boolData) + #expect(boolValue == true) - let int32Data = try fory.serialize(Int32(-123456)) - let int32Value: Int32 = try fory.deserialize(int32Data) - #expect(int32Value == -123456) + let int32Data = try fory.serialize(Int32(-123456)) + let int32Value: Int32 = try fory.deserialize(int32Data) + #expect(int32Value == -123456) - let int64Data = try fory.serialize(Int64(9_223_372_036_854_775_000)) - let int64Value: Int64 = try fory.deserialize(int64Data) - #expect(int64Value == 9_223_372_036_854_775_000) + let int64Data = try fory.serialize(Int64(9_223_372_036_854_775_000)) + let int64Value: Int64 = try fory.deserialize(int64Data) + #expect(int64Value == 9_223_372_036_854_775_000) - let uint32Data = try fory.serialize(UInt32(123456)) - let uint32Value: UInt32 = try fory.deserialize(uint32Data) - #expect(uint32Value == 123456) + let uint32Data = try fory.serialize(UInt32(123456)) + let uint32Value: UInt32 = try fory.deserialize(uint32Data) + #expect(uint32Value == 123456) - let uint64Data = try fory.serialize(UInt64(9_223_372_036_854_775_000)) - let uint64Value: UInt64 = try fory.deserialize(uint64Data) - #expect(uint64Value == 9_223_372_036_854_775_000) + let uint64Data = try fory.serialize(UInt64(9_223_372_036_854_775_000)) + let uint64Value: UInt64 = try fory.deserialize(uint64Data) + #expect(uint64Value == 9_223_372_036_854_775_000) - let floatData = try fory.serialize(Float(3.25)) - let floatValue: Float = try fory.deserialize(floatData) - #expect(floatValue == 3.25) + let floatData = try fory.serialize(Float(3.25)) + let floatValue: Float = try fory.deserialize(floatData) + #expect(floatValue == 3.25) - let doubleData = try fory.serialize(Double(3.1415926)) - let doubleValue: Double = try fory.deserialize(doubleData) - #expect(doubleValue == 3.1415926) + let doubleData = try fory.serialize(Double(3.1415926)) + let doubleValue: Double = try fory.deserialize(doubleData) + #expect(doubleValue == 3.1415926) - let stringData = try fory.serialize("hello_fory") - let stringValue: String = try fory.deserialize(stringData) - #expect(stringValue == "hello_fory") + let stringData = try fory.serialize("hello_fory") + let stringValue: String = try fory.deserialize(stringData) + #expect(stringValue == "hello_fory") - let binary = Data([0x01, 0x02, 0x03, 0xFF]) - let binaryData = try fory.serialize(binary) - let binaryValue: Data = try fory.deserialize(binaryData) - #expect(binaryValue == binary) + let binary = Data([0x01, 0x02, 0x03, 0xFF]) + let binaryData = try fory.serialize(binary) + let binaryValue: Data = try fory.deserialize(binaryData) + #expect(binaryValue == binary) } @Test func extendedWireTypesRoundTrip() throws { - let fory = Fory() + let fory = Fory() - let float16Value = Float16(3.5) - let float16Data = try fory.serialize(float16Value) - let float16Decoded: Float16 = try fory.deserialize(float16Data) - #expect(float16Decoded.bitPattern == float16Value.bitPattern) + let float16Value = Float16(3.5) + let float16Data = try fory.serialize(float16Value) + let float16Decoded: Float16 = try fory.deserialize(float16Data) + #expect(float16Decoded.bitPattern == float16Value.bitPattern) - let bfloatValue = BFloat16(rawValue: 0x3F80) - let bfloatData = try fory.serialize(bfloatValue) - let bfloatDecoded: BFloat16 = try fory.deserialize(bfloatData) - #expect(bfloatDecoded == bfloatValue) + let bfloatValue = BFloat16(rawValue: 0x3F80) + let bfloatData = try fory.serialize(bfloatValue) + let bfloatDecoded: BFloat16 = try fory.deserialize(bfloatData) + #expect(bfloatDecoded == bfloatValue) - let durationValue = Duration.seconds(-2) + Duration.nanoseconds(123_456_789) - let durationData = try fory.serialize(durationValue) - let durationDecoded: Duration = try fory.deserialize(durationData) - #expect(durationDecoded == durationValue) + let durationValue = Duration.seconds(-2) + Duration.nanoseconds(123_456_789) + let durationData = try fory.serialize(durationValue) + let durationDecoded: Duration = try fory.deserialize(durationData) + #expect(durationDecoded == durationValue) - let float16Array: [Float16] = [Float16(1), Float16(-2), Float16(4.5)] - let float16ArrayData = try fory.serialize(float16Array) - let float16ArrayDecoded: [Float16] = try fory.deserialize(float16ArrayData) - #expect(float16ArrayDecoded.map(\.bitPattern) == float16Array.map(\.bitPattern)) + let float16Array: [Float16] = [Float16(1), Float16(-2), Float16(4.5)] + let float16ArrayData = try fory.serialize(float16Array) + let float16ArrayDecoded: [Float16] = try fory.deserialize(float16ArrayData) + #expect(float16ArrayDecoded.map(\.bitPattern) == float16Array.map(\.bitPattern)) } @Test func floatingSpecialsRoundTrip() throws { - let fory = Fory() - - let floatValues: [Float] = [ - 0.0, - -0.0, - .infinity, - -.infinity, - .leastNonzeroMagnitude, - .greatestFiniteMagnitude, - Float(bitPattern: 0x7FC0_1234), - ] - for value in floatValues { - let decoded: Float = try fory.deserialize(try fory.serialize(value)) - #expect(decoded.bitPattern == value.bitPattern) - } - - let doubleValues: [Double] = [ - 0.0, - -0.0, - .infinity, - -.infinity, - .leastNonzeroMagnitude, - .greatestFiniteMagnitude, - Double(bitPattern: 0x7FF8_0000_0000_1234), - ] - for value in doubleValues { - let decoded: Double = try fory.deserialize(try fory.serialize(value)) - #expect(decoded.bitPattern == value.bitPattern) - } - - let float16Values: [Float16] = [ - .init(bitPattern: 0x0000), - .init(bitPattern: 0x8000), - .init(bitPattern: 0x7C00), - .init(bitPattern: 0xFC00), - .init(bitPattern: 0x0001), - .init(bitPattern: 0x7BFF), - .init(bitPattern: 0x7E11), - ] - for value in float16Values { - let decoded: Float16 = try fory.deserialize(try fory.serialize(value)) - #expect(decoded.bitPattern == value.bitPattern) - } - - let bfloat16Values: [BFloat16] = [ - .init(rawValue: 0x0000), - .init(rawValue: 0x8000), - .init(rawValue: 0x7F80), - .init(rawValue: 0xFF80), - .init(rawValue: 0x0001), - .init(rawValue: 0x7FC1), - ] - for value in bfloat16Values { - let decoded: BFloat16 = try fory.deserialize(try fory.serialize(value)) - #expect(decoded.rawValue == value.rawValue) - } + let fory = Fory() + + let floatValues: [Float] = [ + 0.0, + -0.0, + .infinity, + -.infinity, + .leastNonzeroMagnitude, + .greatestFiniteMagnitude, + Float(bitPattern: 0x7FC0_1234) + ] + for value in floatValues { + let decoded: Float = try fory.deserialize(try fory.serialize(value)) + #expect(decoded.bitPattern == value.bitPattern) + } + + let doubleValues: [Double] = [ + 0.0, + -0.0, + .infinity, + -.infinity, + .leastNonzeroMagnitude, + .greatestFiniteMagnitude, + Double(bitPattern: 0x7FF8_0000_0000_1234) + ] + for value in doubleValues { + let decoded: Double = try fory.deserialize(try fory.serialize(value)) + #expect(decoded.bitPattern == value.bitPattern) + } + + let float16Values: [Float16] = [ + .init(bitPattern: 0x0000), + .init(bitPattern: 0x8000), + .init(bitPattern: 0x7C00), + .init(bitPattern: 0xFC00), + .init(bitPattern: 0x0001), + .init(bitPattern: 0x7BFF), + .init(bitPattern: 0x7E11) + ] + for value in float16Values { + let decoded: Float16 = try fory.deserialize(try fory.serialize(value)) + #expect(decoded.bitPattern == value.bitPattern) + } + + let bfloat16Values: [BFloat16] = [ + .init(rawValue: 0x0000), + .init(rawValue: 0x8000), + .init(rawValue: 0x7F80), + .init(rawValue: 0xFF80), + .init(rawValue: 0x0001), + .init(rawValue: 0x7FC1) + ] + for value in bfloat16Values { + let decoded: BFloat16 = try fory.deserialize(try fory.serialize(value)) + #expect(decoded.rawValue == value.rawValue) + } } @Test func namedInitializerBuildsConfig() { - let defaultConfig = Fory() - #expect(defaultConfig.config.trackRef == false) - #expect(defaultConfig.config.compatible == true) - #expect(defaultConfig.config.checkClassVersion == false) - #expect(defaultConfig.config.maxDepth == 5) - #expect(defaultConfig.config.maxGraphMemoryBytes == 128 * 1024 * 1024) - #expect(defaultConfig.config.maxTypeFields == 512) - #expect(defaultConfig.config.maxTypeMetaBytes == 4096) - #expect(defaultConfig.config.maxSchemaVersionsPerType == 10) - #expect(defaultConfig.config.maxAverageSchemaVersionsPerType == 3) - - let explicitConfig = Fory( - ref: true, - compatible: true, - maxDepth: 7, - maxGraphMemoryBytes: 65_536, - maxTypeFields: 31, - maxTypeMetaBytes: 1234, - maxSchemaVersionsPerType: 12, - maxAverageSchemaVersionsPerType: 4 - ) - #expect(explicitConfig.config.trackRef == true) - #expect(explicitConfig.config.compatible == true) - #expect(explicitConfig.config.checkClassVersion == false) - #expect(explicitConfig.config.maxDepth == 7) - #expect(explicitConfig.config.maxGraphMemoryBytes == 65_536) - #expect(explicitConfig.config.maxTypeFields == 31) - #expect(explicitConfig.config.maxTypeMetaBytes == 1234) - #expect(explicitConfig.config.maxSchemaVersionsPerType == 12) - #expect(explicitConfig.config.maxAverageSchemaVersionsPerType == 4) - - let configInit = Fory( - config: .init( - trackRef: false, - compatible: true, - maxDepth: 9, - maxGraphMemoryBytes: 131_072, - maxTypeFields: 41, - maxTypeMetaBytes: 2048, - maxSchemaVersionsPerType: 14, - maxAverageSchemaVersionsPerType: 5 - )) - #expect(configInit.config.trackRef == false) - #expect(configInit.config.compatible == true) - #expect(configInit.config.checkClassVersion == false) - #expect(configInit.config.maxDepth == 9) - #expect(configInit.config.maxGraphMemoryBytes == 131_072) - #expect(configInit.config.maxTypeFields == 41) - #expect(configInit.config.maxTypeMetaBytes == 2048) - #expect(configInit.config.maxSchemaVersionsPerType == 14) - #expect(configInit.config.maxAverageSchemaVersionsPerType == 5) - - let schemaConsistentDirect = Fory(ref: true, compatible: false) - let schemaConsistentViaConfig = Fory(config: Config(trackRef: true, compatible: false)) - #expect(schemaConsistentDirect.config.checkClassVersion == true) - #expect(schemaConsistentViaConfig.config.checkClassVersion == true) + let defaultConfig = Fory() + #expect(defaultConfig.config.trackRef == false) + #expect(defaultConfig.config.compatible == true) + #expect(defaultConfig.config.checkClassVersion == false) + #expect(defaultConfig.config.maxDepth == 5) + #expect(defaultConfig.config.maxGraphMemoryBytes == 128 * 1024 * 1024) + #expect(defaultConfig.config.maxTypeFields == 512) + #expect(defaultConfig.config.maxTypeMetaBytes == 4096) + #expect(defaultConfig.config.maxSchemaVersionsPerType == 10) + #expect(defaultConfig.config.maxAverageSchemaVersionsPerType == 3) + + let explicitConfig = Fory( + ref: true, + compatible: true, + maxDepth: 7, + maxGraphMemoryBytes: 65_536, + maxTypeFields: 31, + maxTypeMetaBytes: 1234, + maxSchemaVersionsPerType: 12, + maxAverageSchemaVersionsPerType: 4 + ) + #expect(explicitConfig.config.trackRef == true) + #expect(explicitConfig.config.compatible == true) + #expect(explicitConfig.config.checkClassVersion == false) + #expect(explicitConfig.config.maxDepth == 7) + #expect(explicitConfig.config.maxGraphMemoryBytes == 65_536) + #expect(explicitConfig.config.maxTypeFields == 31) + #expect(explicitConfig.config.maxTypeMetaBytes == 1234) + #expect(explicitConfig.config.maxSchemaVersionsPerType == 12) + #expect(explicitConfig.config.maxAverageSchemaVersionsPerType == 4) + + let configInit = Fory( + config: .init( + trackRef: false, + compatible: true, + maxDepth: 9, + maxGraphMemoryBytes: 131_072, + maxTypeFields: 41, + maxTypeMetaBytes: 2048, + maxSchemaVersionsPerType: 14, + maxAverageSchemaVersionsPerType: 5 + )) + #expect(configInit.config.trackRef == false) + #expect(configInit.config.compatible == true) + #expect(configInit.config.checkClassVersion == false) + #expect(configInit.config.maxDepth == 9) + #expect(configInit.config.maxGraphMemoryBytes == 131_072) + #expect(configInit.config.maxTypeFields == 41) + #expect(configInit.config.maxTypeMetaBytes == 2048) + #expect(configInit.config.maxSchemaVersionsPerType == 14) + #expect(configInit.config.maxAverageSchemaVersionsPerType == 5) + + let schemaConsistentDirect = Fory(ref: true, compatible: false) + let schemaConsistentViaConfig = Fory(config: Config(trackRef: true, compatible: false)) + #expect(schemaConsistentDirect.config.checkClassVersion == true) + #expect(schemaConsistentViaConfig.config.checkClassVersion == true) } @Test func structEvolvingOverrideUsesSmallerCompatiblePayload() throws { - let fory = Fory(compatible: true) - fory.register(EvolvingOverrideValue.self, id: 1001) - fory.register(FixedOverrideValue.self, id: 1002) + let fory = Fory(compatible: true) + fory.register(EvolvingOverrideValue.self, id: 1001) + fory.register(FixedOverrideValue.self, id: 1002) - let evolving = EvolvingOverrideValue(f1: "payload") - let fixed = FixedOverrideValue(f1: "payload") + let evolving = EvolvingOverrideValue(f1: "payload") + let fixed = FixedOverrideValue(f1: "payload") - let evolvingData = try fory.serialize(evolving) - let fixedData = try fory.serialize(fixed) + let evolvingData = try fory.serialize(evolving) + let fixedData = try fory.serialize(fixed) - #expect(fixedData.count < evolvingData.count) - let decodedEvolving: EvolvingOverrideValue = try fory.deserialize(evolvingData) - let decodedFixed: FixedOverrideValue = try fory.deserialize(fixedData) - #expect(decodedEvolving == evolving) - #expect(decodedFixed == fixed) + #expect(fixedData.count < evolvingData.count) + let decodedEvolving: EvolvingOverrideValue = try fory.deserialize(evolvingData) + let decodedFixed: FixedOverrideValue = try fory.deserialize(fixedData) + #expect(decodedEvolving == evolving) + #expect(decodedFixed == fixed) } @Test func deserializeRejectsTrailingBytes() throws { - let fory = Fory() - let payload = try fory.serialize(Int32(7)) - var bytes = [UInt8](payload) - bytes.append(0xFF) - let withTrailing = Data(bytes) + let fory = Fory() + let payload = try fory.serialize(Int32(7)) + var bytes = [UInt8](payload) + bytes.append(0xFF) + let withTrailing = Data(bytes) - do { - let _: Int32 = try fory.deserialize(withTrailing) - #expect(Bool(false)) - } catch {} + do { + let _: Int32 = try fory.deserialize(withTrailing) + #expect(Bool(false)) + } catch {} } @Test func optionalRoundTrip() throws { - let fory = Fory() + let fory = Fory() - let some: String? = "present" - let someData = try fory.serialize(some) - let someValue: String? = try fory.deserialize(someData) - #expect(someValue == "present") + let some: String? = "present" + let someData = try fory.serialize(some) + let someValue: String? = try fory.deserialize(someData) + #expect(someValue == "present") - let none: String? = nil - let noneData = try fory.serialize(none) - let noneValue: String? = try fory.deserialize(noneData) - #expect(noneValue == nil) + let none: String? = nil + let noneData = try fory.serialize(none) + let noneValue: String? = try fory.deserialize(noneData) + #expect(noneValue == nil) } @Test func collectionsRoundTrip() throws { - let fory = Fory() + let fory = Fory() - let list: [String?] = ["a", nil, "b"] - let listData = try fory.serialize(list) - let listValue: [String?] = try fory.deserialize(listData) - #expect(listValue == list) + let list: [String?] = ["a", nil, "b"] + let listData = try fory.serialize(list) + let listValue: [String?] = try fory.deserialize(listData) + #expect(listValue == list) - let intArray: [Int32] = [1, 2, 3, 4] - let intArrayData = try fory.serialize(intArray) - let intArrayValue: [Int32] = try fory.deserialize(intArrayData) - #expect(intArrayValue == intArray) + let intArray: [Int32] = [1, 2, 3, 4] + let intArrayData = try fory.serialize(intArray) + let intArrayValue: [Int32] = try fory.deserialize(intArrayData) + #expect(intArrayValue == intArray) - let uint8Array: [UInt8] = [1, 2, 3, 250] - let uint8ArrayData = try fory.serialize(uint8Array) - let uint8ArrayValue: [UInt8] = try fory.deserialize(uint8ArrayData) - #expect(uint8ArrayValue == uint8Array) + let uint8Array: [UInt8] = [1, 2, 3, 250] + let uint8ArrayData = try fory.serialize(uint8Array) + let uint8ArrayValue: [UInt8] = try fory.deserialize(uint8ArrayData) + #expect(uint8ArrayValue == uint8Array) - let set: Set = [1, 5, 8] - let setData = try fory.serialize(set) - let setValue: Set = try fory.deserialize(setData) - #expect(setValue == set) + let set: Set = [1, 5, 8] + let setData = try fory.serialize(set) + let setValue: Set = try fory.deserialize(setData) + #expect(setValue == set) - let map: [Int8: Int32?] = [1: 100, 2: nil, 3: -7] - let mapData = try fory.serialize(map) - let mapValue: [Int8: Int32?] = try fory.deserialize(mapData) - #expect(mapValue == map) + let map: [Int8: Int32?] = [1: 100, 2: nil, 3: -7] + let mapData = try fory.serialize(map) + let mapValue: [Int8: Int32?] = try fory.deserialize(mapData) + #expect(mapValue == map) - let nullableKeyMap: [Int8?: Int32?] = [1: 10, nil: nil] - let nullableMapData = try fory.serialize(nullableKeyMap) - let nullableMapValue: [Int8?: Int32?] = try fory.deserialize(nullableMapData) - #expect(nullableMapValue == nullableKeyMap) + let nullableKeyMap: [Int8?: Int32?] = [1: 10, nil: nil] + let nullableMapData = try fory.serialize(nullableKeyMap) + let nullableMapValue: [Int8?: Int32?] = try fory.deserialize(nullableMapData) + #expect(nullableMapValue == nullableKeyMap) } @Test func primitiveArrayTypeIDs() throws { - let fory = Fory() + let fory = Fory() - let int32Data = try fory.serialize([Int32(7), 9]) - let int32Bytes = [UInt8](int32Data) - #expect(int32Bytes[0] == ForyHeaderFlag.isXlang) - #expect(Int8(bitPattern: int32Bytes[1]) == RefFlag.notNullValue.rawValue) - #expect(UInt32(int32Bytes[2]) == TypeId.list.rawValue) + let int32Data = try fory.serialize([Int32(7), 9]) + let int32Bytes = [UInt8](int32Data) + #expect(int32Bytes[0] == ForyHeaderFlag.isXlang) + #expect(Int8(bitPattern: int32Bytes[1]) == RefFlag.notNullValue.rawValue) + #expect(UInt32(int32Bytes[2]) == TypeId.list.rawValue) - let uint8Data = try fory.serialize([UInt8(1), 2, 3]) - let uint8Bytes = [UInt8](uint8Data) - #expect(UInt32(uint8Bytes[2]) == TypeId.list.rawValue) + let uint8Data = try fory.serialize([UInt8(1), 2, 3]) + let uint8Bytes = [UInt8](uint8Data) + #expect(UInt32(uint8Bytes[2]) == TypeId.list.rawValue) } @Test func typeMetaFieldLimitRejectsLargeStruct() throws { - let fieldType = TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) - let meta = try TypeMeta( - typeID: TypeId.structType.rawValue, - userTypeID: 901, - namespace: .empty(specialChar1: ".", specialChar2: "_"), - typeName: .empty(specialChar1: "$", specialChar2: "_"), - registerByName: false, - fields: [ - TypeMeta.FieldInfo(fieldID: nil, fieldName: "first", fieldType: fieldType), - TypeMeta.FieldInfo(fieldID: nil, fieldName: "second", fieldType: fieldType), - ] - ) - let encoded = try meta.encode() + let fieldType = TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + let meta = try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 901, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo(fieldID: nil, fieldName: "first", fieldType: fieldType), + TypeMeta.FieldInfo(fieldID: nil, fieldName: "second", fieldType: fieldType) + ] + ) + let encoded = try meta.encode() - #expect(throws: (any Error).self) { - _ = try TypeMeta.decode(encoded, maxTypeFields: 1) - } + #expect(throws: (any Error).self) { + _ = try TypeMeta.decode(encoded, maxTypeFields: 1) + } } @Test func typeMetaBodyLimitRejectsLargeMetadata() throws { - let meta = try TypeMeta( - typeID: TypeId.structType.rawValue, - userTypeID: 901, - namespace: .empty(specialChar1: ".", specialChar2: "_"), - typeName: .empty(specialChar1: "$", specialChar2: "_"), - registerByName: false, - fields: [ - TypeMeta.FieldInfo( - fieldID: nil, - fieldName: "value", - fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false)) - ] - ) - let encoded = try meta.encode() + let meta = try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 901, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: "value", + fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false)) + ] + ) + let encoded = try meta.encode() - #expect(throws: (any Error).self) { - _ = try TypeMeta.decode(encoded, maxTypeMetaBytes: 1) - } + #expect(throws: (any Error).self) { + _ = try TypeMeta.decode(encoded, maxTypeMetaBytes: 1) + } } @Test func schemaLimitTracksStructTypesSeparately() throws { - let config = Config(maxSchemaVersionsPerType: 1) - let resolver = TypeResolver(config: config) - resolver.register(Person.self, id: 901) - resolver.register(Address.self, id: 902) - try resolver.finishRegistration() - - func remoteTypeMeta(userTypeID: UInt32, fieldName: String) throws -> TypeMeta { - try TypeMeta( - typeID: TypeId.structType.rawValue, - userTypeID: userTypeID, - namespace: .empty(specialChar1: ".", specialChar2: "_"), - typeName: .empty(specialChar1: "$", specialChar2: "_"), - registerByName: false, - fields: [ - TypeMeta.FieldInfo( - fieldID: nil, - fieldName: fieldName, - fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + let config = Config(maxSchemaVersionsPerType: 1) + let resolver = TypeResolver(config: config) + resolver.register(Person.self, id: 901) + resolver.register(Address.self, id: 902) + try resolver.finishRegistration() + + func remoteTypeMeta(userTypeID: UInt32, fieldName: String) throws -> TypeMeta { + try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: userTypeID, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: fieldName, + fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + ) + ] ) - ] - ) - } + } - func cache(_ typeMeta: TypeMeta) throws { - let encoded = try typeMeta.encode() - let headerReader = ByteBuffer(bytes: encoded) - let header = try headerReader.readUInt64() - let buffer = ByteBuffer(bytes: encoded) - let decoded = try TypeMeta.decode(buffer) - let localTypeInfo = try resolver.requireTypeInfo(for: decoded) - _ = try resolver.cacheTypeInfo( - decoded, - forHeader: header, - localTypeInfo: localTypeInfo, - exactLocal: false, - config: config - ) - } + func cache(_ typeMeta: TypeMeta) throws { + let encoded = try typeMeta.encode() + let headerReader = ByteBuffer(bytes: encoded) + let header = try headerReader.readUInt64() + let buffer = ByteBuffer(bytes: encoded) + let decoded = try TypeMeta.decode(buffer) + let localTypeInfo = try resolver.requireTypeInfo(for: decoded) + _ = try resolver.cacheTypeInfo( + decoded, + forHeader: header, + localTypeInfo: localTypeInfo, + exactLocal: false, + config: config + ) + } - try cache(remoteTypeMeta(userTypeID: 901, fieldName: "remoteA")) - try cache(remoteTypeMeta(userTypeID: 902, fieldName: "remoteA")) - #expect(throws: (any Error).self) { - try cache(remoteTypeMeta(userTypeID: 901, fieldName: "remoteB")) - } + try cache(remoteTypeMeta(userTypeID: 901, fieldName: "remoteA")) + try cache(remoteTypeMeta(userTypeID: 902, fieldName: "remoteA")) + #expect(throws: (any Error).self) { + try cache(remoteTypeMeta(userTypeID: 901, fieldName: "remoteB")) + } } @Test func nonStructTypeMetaUsesSchemaLimit() throws { - let config = Config(maxSchemaVersionsPerType: 1) - let resolver = TypeResolver(config: config) - try resolver.register(SparseStatus.self, name: "example.SharedEnum") - try resolver.finishRegistration() - let namespace = try MetaStringEncoder.namespace.encode("example") - let typeName = try MetaStringEncoder.typeName.encode("SharedEnum") - - func remoteTypeMeta(_ typeID: TypeId) throws -> TypeMeta { - try TypeMeta( - typeID: typeID.rawValue, - userTypeID: nil, - namespace: namespace, - typeName: typeName, - registerByName: true, - fields: [] - ) - } + let config = Config(maxSchemaVersionsPerType: 1) + let resolver = TypeResolver(config: config) + try resolver.register(SparseStatus.self, name: "example.SharedEnum") + try resolver.finishRegistration() + let namespace = try MetaStringEncoder.namespace.encode("example") + let typeName = try MetaStringEncoder.typeName.encode("SharedEnum") + + func remoteTypeMeta(_ typeID: TypeId) throws -> TypeMeta { + try TypeMeta( + typeID: typeID.rawValue, + userTypeID: nil, + namespace: namespace, + typeName: typeName, + registerByName: true, + fields: [] + ) + } - func cache(_ typeMeta: TypeMeta) throws { - let encoded = try typeMeta.encode() + func cache(_ typeMeta: TypeMeta) throws { + let encoded = try typeMeta.encode() + let headerReader = ByteBuffer(bytes: encoded) + let header = try headerReader.readUInt64() + let buffer = ByteBuffer(bytes: encoded) + let decoded = try TypeMeta.decode(buffer) + let localTypeInfo = try resolver.requireTypeInfo(for: decoded) + _ = try resolver.cacheTypeInfo( + decoded, + forHeader: header, + localTypeInfo: localTypeInfo, + exactLocal: false, + config: config + ) + } + + try cache(remoteTypeMeta(.namedExt)) + #expect(throws: (any Error).self) { + try cache(remoteTypeMeta(.namedUnion)) + } +} + +@Test +func exactLocalNonStructTypeMetaBypassesSchemaLimit() throws { + let config = Config(compatible: true, maxSchemaVersionsPerType: 1) + let resolver = TypeResolver(config: config) + try resolver.register(SparseStatus.self, name: "example.SharedEnum") + try resolver.finishRegistration() + let localTypeInfo = try resolver.requireTypeInfo(for: SparseStatus.self) + let namespace = try MetaStringEncoder.namespace.encode("example") + let typeName = try MetaStringEncoder.typeName.encode("SharedEnum") + + let exactBuffer = ByteBuffer() + exactBuffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.namedEnum.rawValue)) + exactBuffer.writeUInt8(0) + exactBuffer.writeBytes(localTypeInfo.typeDefBytes!) + let exactContext = ReadContext(buffer: exactBuffer, typeResolver: resolver, config: config) + _ = try exactContext.readTypeInfo(for: SparseStatus.self) + + let remote = try TypeMeta( + typeID: TypeId.namedExt.rawValue, + userTypeID: nil, + namespace: namespace, + typeName: typeName, + registerByName: true, + fields: [] + ) + let encoded = try remote.encode() let headerReader = ByteBuffer(bytes: encoded) let header = try headerReader.readUInt64() let buffer = ByteBuffer(bytes: encoded) let decoded = try TypeMeta.decode(buffer) - let localTypeInfo = try resolver.requireTypeInfo(for: decoded) + let resolved = try resolver.requireTypeInfo(for: decoded) _ = try resolver.cacheTypeInfo( - decoded, - forHeader: header, - localTypeInfo: localTypeInfo, - exactLocal: false, - config: config + decoded, + forHeader: header, + localTypeInfo: resolved, + exactLocal: false, + config: config ) - } - - try cache(remoteTypeMeta(.namedExt)) - #expect(throws: (any Error).self) { - try cache(remoteTypeMeta(.namedUnion)) - } -} - -@Test -func exactLocalNonStructTypeMetaBypassesSchemaLimit() throws { - let config = Config(compatible: true, maxSchemaVersionsPerType: 1) - let resolver = TypeResolver(config: config) - try resolver.register(SparseStatus.self, name: "example.SharedEnum") - try resolver.finishRegistration() - let localTypeInfo = try resolver.requireTypeInfo(for: SparseStatus.self) - let namespace = try MetaStringEncoder.namespace.encode("example") - let typeName = try MetaStringEncoder.typeName.encode("SharedEnum") - - let exactBuffer = ByteBuffer() - exactBuffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.namedEnum.rawValue)) - exactBuffer.writeUInt8(0) - exactBuffer.writeBytes(localTypeInfo.typeDefBytes!) - let exactContext = ReadContext(buffer: exactBuffer, typeResolver: resolver, config: config) - _ = try exactContext.readTypeInfo(for: SparseStatus.self) - - let remote = try TypeMeta( - typeID: TypeId.namedExt.rawValue, - userTypeID: nil, - namespace: namespace, - typeName: typeName, - registerByName: true, - fields: [] - ) - let encoded = try remote.encode() - let headerReader = ByteBuffer(bytes: encoded) - let header = try headerReader.readUInt64() - let buffer = ByteBuffer(bytes: encoded) - let decoded = try TypeMeta.decode(buffer) - let resolved = try resolver.requireTypeInfo(for: decoded) - _ = try resolver.cacheTypeInfo( - decoded, - forHeader: header, - localTypeInfo: resolved, - exactLocal: false, - config: config - ) } @Test func typeMetaUsesFinalRegistration() throws { - func holderTypeDefBytes(registerFieldTypeFirst: Bool) throws -> [UInt8] { - let resolver = TypeResolver(config: Config(compatible: true)) - if registerFieldTypeFirst { - try resolver.register(LateMetaExt.self, name: "example.LateMetaExt") - try resolver.register(LateMetaHolder.self, name: "example.LateMetaHolder") - } else { - try resolver.register(LateMetaHolder.self, name: "example.LateMetaHolder") - try resolver.register(LateMetaExt.self, name: "example.LateMetaExt") + func holderTypeDefBytes(registerFieldTypeFirst: Bool) throws -> [UInt8] { + let resolver = TypeResolver(config: Config(compatible: true)) + if registerFieldTypeFirst { + try resolver.register(LateMetaExt.self, name: "example.LateMetaExt") + try resolver.register(LateMetaHolder.self, name: "example.LateMetaHolder") + } else { + try resolver.register(LateMetaHolder.self, name: "example.LateMetaHolder") + try resolver.register(LateMetaExt.self, name: "example.LateMetaExt") + } + try resolver.finishRegistration() + return try resolver.requireTypeInfo(for: LateMetaHolder.self).typeDefBytes! } - try resolver.finishRegistration() - return try resolver.requireTypeInfo(for: LateMetaHolder.self).typeDefBytes! - } - let fieldFirst = try holderTypeDefBytes(registerFieldTypeFirst: true) - let holderFirst = try holderTypeDefBytes(registerFieldTypeFirst: false) - #expect(fieldFirst == holderFirst) + let fieldFirst = try holderTypeDefBytes(registerFieldTypeFirst: true) + let holderFirst = try holderTypeDefBytes(registerFieldTypeFirst: false) + #expect(fieldFirst == holderFirst) - let typeMeta = try TypeMeta.decode(ByteBuffer(bytes: holderFirst)) - #expect(typeMeta.fields.count == 1) - #expect(typeMeta.fields[0].fieldType.typeID == TypeId.namedExt.rawValue) + let typeMeta = try TypeMeta.decode(ByteBuffer(bytes: holderFirst)) + #expect(typeMeta.fields.count == 1) + #expect(typeMeta.fields[0].fieldType.typeID == TypeId.namedExt.rawValue) } @Test func failedSchemaDoesNotConsumeLimit() throws { - let config = Config(maxSchemaVersionsPerType: 1) - let resolver = TypeResolver(config: config) - resolver.register(Person.self, id: 901) - resolver.register(Address.self, id: 902) - try resolver.finishRegistration() - - func remoteTypeMeta(fieldName: String, fieldType: TypeMeta.FieldType) throws -> TypeMeta { - try TypeMeta( - typeID: TypeId.structType.rawValue, - userTypeID: 901, - namespace: .empty(specialChar1: ".", specialChar2: "_"), - typeName: .empty(specialChar1: "$", specialChar2: "_"), - registerByName: false, - fields: [ - TypeMeta.FieldInfo( - fieldID: nil, - fieldName: fieldName, - fieldType: fieldType + let config = Config(maxSchemaVersionsPerType: 1) + let resolver = TypeResolver(config: config) + resolver.register(Person.self, id: 901) + resolver.register(Address.self, id: 902) + try resolver.finishRegistration() + + func remoteTypeMeta(fieldName: String, fieldType: TypeMeta.FieldType) throws -> TypeMeta { + try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 901, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: fieldName, + fieldType: fieldType + ) + ] ) - ] - ) - } + } - func cache(_ typeMeta: TypeMeta) throws { - let encoded = try typeMeta.encode() - let headerReader = ByteBuffer(bytes: encoded) - let header = try headerReader.readUInt64() - let buffer = ByteBuffer(bytes: encoded) - let decoded = try TypeMeta.decode(buffer) - let localTypeInfo = try resolver.requireTypeInfo(for: decoded) - _ = try resolver.cacheTypeInfo( - decoded, - forHeader: header, - localTypeInfo: localTypeInfo, - exactLocal: false, - config: config - ) - } + func cache(_ typeMeta: TypeMeta) throws { + let encoded = try typeMeta.encode() + let headerReader = ByteBuffer(bytes: encoded) + let header = try headerReader.readUInt64() + let buffer = ByteBuffer(bytes: encoded) + let decoded = try TypeMeta.decode(buffer) + let localTypeInfo = try resolver.requireTypeInfo(for: decoded) + _ = try resolver.cacheTypeInfo( + decoded, + forHeader: header, + localTypeInfo: localTypeInfo, + exactLocal: false, + config: config + ) + } - #expect(throws: (any Error).self) { + #expect(throws: (any Error).self) { + try cache( + remoteTypeMeta( + fieldName: "id", + fieldType: TypeMeta.FieldType( + typeID: TypeId.map.rawValue, + nullable: false, + generics: [ + TypeMeta.FieldType(typeID: TypeId.string.rawValue, nullable: false), + TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + ] + ) + )) + } try cache( - remoteTypeMeta( - fieldName: "id", - fieldType: TypeMeta.FieldType( - typeID: TypeId.map.rawValue, - nullable: false, - generics: [ - TypeMeta.FieldType(typeID: TypeId.string.rawValue, nullable: false), - TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false), - ] - ) - )) - } - try cache( - remoteTypeMeta( - fieldName: "remoteA", - fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) - )) + remoteTypeMeta( + fieldName: "remoteA", + fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + )) } @Test func staticTypeRejectsWrongMetaOwner() throws { - let config = Config(compatible: true) - let resolver = TypeResolver(config: config) - resolver.register(Person.self, id: 901) - resolver.register(Address.self, id: 902) - try resolver.finishRegistration() - let wrongTypeMeta = try TypeMeta( - typeID: TypeId.compatibleStruct.rawValue, - userTypeID: 901, - namespace: .empty(specialChar1: ".", specialChar2: "_"), - typeName: .empty(specialChar1: "$", specialChar2: "_"), - registerByName: false, - fields: [] - ) - let wrongBytes = try wrongTypeMeta.encode() - let wrongHeader = try ByteBuffer(bytes: wrongBytes).readUInt64() - let buffer = ByteBuffer() - buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.compatibleStruct.rawValue)) - buffer.writeUInt8(0) - buffer.writeBytes(wrongBytes) - let context = ReadContext(buffer: buffer, typeResolver: resolver, config: config) - - #expect(throws: (any Error).self) { - _ = try context.readTypeInfo(for: Address.self) - } - #expect(resolver.getTypeInfo(forHeader: wrongHeader) == nil) - - let addressInfo = try resolver.requireTypeInfo(for: Address.self) - let addressBytes = try #require(addressInfo.typeDefBytes) - let addressHeader = try ByteBuffer(bytes: addressBytes).readUInt64() - let exactBuffer = ByteBuffer() - exactBuffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.compatibleStruct.rawValue)) - exactBuffer.writeUInt8(0) - exactBuffer.writeBytes(addressBytes) - let exactContext = ReadContext(buffer: exactBuffer, typeResolver: resolver, config: config) - _ = try exactContext.readTypeInfo(for: Address.self) - #expect(resolver.getTypeInfo(forHeader: addressHeader) == nil) + let config = Config(compatible: true) + let resolver = TypeResolver(config: config) + resolver.register(Person.self, id: 901) + resolver.register(Address.self, id: 902) + try resolver.finishRegistration() + let wrongTypeMeta = try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: 901, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [] + ) + let wrongBytes = try wrongTypeMeta.encode() + let wrongHeader = try ByteBuffer(bytes: wrongBytes).readUInt64() + let buffer = ByteBuffer() + buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.compatibleStruct.rawValue)) + buffer.writeUInt8(0) + buffer.writeBytes(wrongBytes) + let context = ReadContext(buffer: buffer, typeResolver: resolver, config: config) + + #expect(throws: (any Error).self) { + _ = try context.readTypeInfo(for: Address.self) + } + #expect(resolver.getTypeInfo(forHeader: wrongHeader) == nil) + + let addressInfo = try resolver.requireTypeInfo(for: Address.self) + let addressBytes = try #require(addressInfo.typeDefBytes) + let addressHeader = try ByteBuffer(bytes: addressBytes).readUInt64() + let exactBuffer = ByteBuffer() + exactBuffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.compatibleStruct.rawValue)) + exactBuffer.writeUInt8(0) + exactBuffer.writeBytes(addressBytes) + let exactContext = ReadContext(buffer: exactBuffer, typeResolver: resolver, config: config) + _ = try exactContext.readTypeInfo(for: Address.self) + #expect(resolver.getTypeInfo(forHeader: addressHeader) == nil) } @Test func failedStaticMetaDoesNotCount() throws { - let config = Config(compatible: true, maxSchemaVersionsPerType: 1) - let resolver = TypeResolver(config: config) - resolver.register(Person.self, id: 901) - resolver.register(Address.self, id: 902) - try resolver.finishRegistration() - - func typeMeta(userTypeID: UInt32, fieldName: String) throws -> TypeMeta { - try TypeMeta( - typeID: TypeId.compatibleStruct.rawValue, - userTypeID: userTypeID, - namespace: .empty(specialChar1: ".", specialChar2: "_"), - typeName: .empty(specialChar1: "$", specialChar2: "_"), - registerByName: false, - fields: [ - TypeMeta.FieldInfo( - fieldID: nil, - fieldName: fieldName, - fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + let config = Config(compatible: true, maxSchemaVersionsPerType: 1) + let resolver = TypeResolver(config: config) + resolver.register(Person.self, id: 901) + resolver.register(Address.self, id: 902) + try resolver.finishRegistration() + + func typeMeta(userTypeID: UInt32, fieldName: String) throws -> TypeMeta { + try TypeMeta( + typeID: TypeId.compatibleStruct.rawValue, + userTypeID: userTypeID, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: fieldName, + fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + ) + ] ) - ] - ) - } + } - func writeTypeInfo(_ buffer: ByteBuffer, marker: UInt8, typeMeta: TypeMeta) throws { - buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.compatibleStruct.rawValue)) - buffer.writeUInt8(marker) - buffer.writeBytes(try typeMeta.encode()) - } - - let failedBuffer = ByteBuffer() - try writeTypeInfo(failedBuffer, marker: 0, typeMeta: typeMeta(userTypeID: 902, fieldName: "zip2")) - try writeTypeInfo(failedBuffer, marker: 2, typeMeta: typeMeta(userTypeID: 901, fieldName: "id2")) - let failedContext = ReadContext(buffer: failedBuffer, typeResolver: resolver, config: config) - _ = try failedContext.readTypeInfo(for: Address.self) - #expect(throws: (any Error).self) { + func writeTypeInfo(_ buffer: ByteBuffer, marker: UInt8, typeMeta: TypeMeta) throws { + buffer.writeUInt8(UInt8(truncatingIfNeeded: TypeId.compatibleStruct.rawValue)) + buffer.writeUInt8(marker) + buffer.writeBytes(try typeMeta.encode()) + } + + let failedBuffer = ByteBuffer() + try writeTypeInfo(failedBuffer, marker: 0, typeMeta: typeMeta(userTypeID: 902, fieldName: "zip2")) + try writeTypeInfo(failedBuffer, marker: 2, typeMeta: typeMeta(userTypeID: 901, fieldName: "id2")) + let failedContext = ReadContext(buffer: failedBuffer, typeResolver: resolver, config: config) _ = try failedContext.readTypeInfo(for: Address.self) - } + #expect(throws: (any Error).self) { + _ = try failedContext.readTypeInfo(for: Address.self) + } - let validBuffer = ByteBuffer() - try writeTypeInfo(validBuffer, marker: 0, typeMeta: typeMeta(userTypeID: 901, fieldName: "id3")) - let validContext = ReadContext(buffer: validBuffer, typeResolver: resolver, config: config) - _ = try validContext.readTypeInfo(for: Person.self) + let validBuffer = ByteBuffer() + try writeTypeInfo(validBuffer, marker: 0, typeMeta: typeMeta(userTypeID: 901, fieldName: "id3")) + let validContext = ReadContext(buffer: validBuffer, typeResolver: resolver, config: config) + _ = try validContext.readTypeInfo(for: Person.self) } @Test func macroStructRoundTrip() throws { - let fory = Fory() - fory.register(Address.self, id: 100) - fory.register(Person.self, id: 101) - - let person = Person( - id: 42, - name: "Alice", - nickname: nil, - scores: [10, 20, 30], - tags: ["swift", "xlang"], - addresses: [Address(street: "Main", zip: 94107)], - metadata: [1: 100, 2: nil] - ) + let fory = Fory() + fory.register(Address.self, id: 100) + fory.register(Person.self, id: 101) + + let person = Person( + id: 42, + name: "Alice", + nickname: nil, + scores: [10, 20, 30], + tags: ["swift", "xlang"], + addresses: [Address(street: "Main", zip: 94107)], + metadata: [1: 100, 2: nil] + ) - let data = try fory.serialize(person) - let decoded: Person = try fory.deserialize(data) - #expect(decoded == person) + let data = try fory.serialize(person) + let decoded: Person = try fory.deserialize(data) + #expect(decoded == person) } @Test func macroClassRefTracking() throws { - let fory = Fory(config: .init(trackRef: true, compatible: false)) - fory.register(Node.self, id: 200) + let fory = Fory(config: .init(trackRef: true, compatible: false)) + fory.register(Node.self, id: 200) - let node = Node(value: 7) - node.next = node + let node = Node(value: 7) + node.next = node - let data = try fory.serialize(node) - let decoded: Node = try fory.deserialize(data) + let data = try fory.serialize(node) + let decoded: Node = try fory.deserialize(data) - #expect(decoded.value == 7) - #expect(decoded.next === decoded) + #expect(decoded.value == 7) + #expect(decoded.next === decoded) } @Test func macroClassWeakRefTracking() throws { - let fory = Fory(config: .init(trackRef: true, compatible: false)) - fory.register(WeakNode.self, id: 201) + let fory = Fory(config: .init(trackRef: true, compatible: false)) + fory.register(WeakNode.self, id: 201) - let node = WeakNode(value: 13) - node.next = node + let node = WeakNode(value: 13) + node.next = node - let data = try fory.serialize(node) - let decoded: WeakNode = try fory.deserialize(data) + let data = try fory.serialize(node) + let decoded: WeakNode = try fory.deserialize(data) - #expect(decoded.value == 13) - #expect(decoded.next === decoded) + #expect(decoded.value == 13) + #expect(decoded.next === decoded) } @Test func topLevelAnyRoundTrip() throws { - let fory = Fory() - fory.register(Address.self, id: 209) + let fory = Fory() + fory.register(Address.self, id: 209) - let value: Any = Address(street: "AnyTop", zip: 8080) - let data = try fory.serialize(value) - let decoded: Any = try fory.deserialize(data) - #expect(decoded as? Address == Address(street: "AnyTop", zip: 8080)) + let value: Any = Address(street: "AnyTop", zip: 8080) + let data = try fory.serialize(value) + let decoded: Any = try fory.deserialize(data) + #expect(decoded as? Address == Address(street: "AnyTop", zip: 8080)) - var buffer = Data() - try fory.serialize(value, to: &buffer) - let decodedFrom: Any = try fory.deserialize(from: ByteBuffer(data: buffer)) - #expect(decodedFrom as? Address == Address(street: "AnyTop", zip: 8080)) + var buffer = Data() + try fory.serialize(value, to: &buffer) + let decodedFrom: Any = try fory.deserialize(from: ByteBuffer(data: buffer)) + #expect(decodedFrom as? Address == Address(street: "AnyTop", zip: 8080)) - let nullAny: Any = Optional.none as Any - let nullData = try fory.serialize(nullAny) - let nullDecoded: Any = try fory.deserialize(nullData) - #expect(nullDecoded is ForyAnyNullValue) + let nullAny: Any = Optional.none as Any + let nullData = try fory.serialize(nullAny) + let nullDecoded: Any = try fory.deserialize(nullData) + #expect(nullDecoded is ForyAnyNullValue) } @Test func dynamicUserTypesDecodeByID() throws { - let fory = Fory() - fory.register(Address.self, id: 600) - try fory.register(Person.self, name: "demo.person") + let fory = Fory() + fory.register(Address.self, id: 600) + try fory.register(Person.self, name: "demo.person") - let value: Any = Address(street: "mixed", zip: 7788) - let data = try fory.serialize(value) - let decoded: Any = try fory.deserialize(data) - #expect(decoded as? Address == Address(street: "mixed", zip: 7788)) + let value: Any = Address(street: "mixed", zip: 7788) + let data = try fory.serialize(value) + let decoded: Any = try fory.deserialize(data) + #expect(decoded as? Address == Address(street: "mixed", zip: 7788)) } @Test func duplicateNameRegistrationIsRejected() throws { - let resolver = TypeResolver(config: Config(trackRef: false)) - try resolver.register(Address.self, namespace: "demo", typeName: "entity") + let resolver = TypeResolver(config: Config(trackRef: false)) + try resolver.register(Address.self, namespace: "demo", typeName: "entity") - do { - try resolver.register(Person.self, namespace: "demo", typeName: "entity") - #expect(Bool(false)) - } catch {} + do { + try resolver.register(Person.self, namespace: "demo", typeName: "entity") + #expect(Bool(false)) + } catch {} } @Test func nameRegistrationSplitsLastDot() throws { - let resolver = TypeResolver(config: Config(trackRef: false)) - try resolver.register(Address.self, name: "com.example.Address") + let resolver = TypeResolver(config: Config(trackRef: false)) + try resolver.register(Address.self, name: "com.example.Address") - let info = try resolver.requireTypeInfo(namespace: "com.example", typeName: "Address") - #expect(info.namespace.value == "com.example") - #expect(info.typeName.value == "Address") + let info = try resolver.requireTypeInfo(namespace: "com.example", typeName: "Address") + #expect(info.namespace.value == "com.example") + #expect(info.typeName.value == "Address") } @Test func nameRegistrationAllowsSimpleName() throws { - let resolver = TypeResolver(config: Config(trackRef: false)) - try resolver.register(Address.self, name: "Address") + let resolver = TypeResolver(config: Config(trackRef: false)) + try resolver.register(Address.self, name: "Address") - let info = try resolver.requireTypeInfo(namespace: "", typeName: "Address") - #expect(info.namespace.value == "") - #expect(info.typeName.value == "Address") + let info = try resolver.requireTypeInfo(namespace: "", typeName: "Address") + #expect(info.namespace.value == "") + #expect(info.typeName.value == "Address") } @Test func nameRegistrationRejectsEmptyName() throws { - let fory = Fory() + let fory = Fory() - #expect(throws: ForyError.self) { - try fory.register(Address.self, name: "") - } + #expect(throws: ForyError.self) { + try fory.register(Address.self, name: "") + } } @Test func nameRegistrationRejectsTrailingDot() throws { - let fory = Fory() + let fory = Fory() - #expect(throws: ForyError.self) { - try fory.register(Address.self, name: "com.example.") - } + #expect(throws: ForyError.self) { + try fory.register(Address.self, name: "com.example.") + } } @Test func splitNameRegistrationRejectsDottedTypeName() throws { - let resolver = TypeResolver(config: Config(trackRef: false)) + let resolver = TypeResolver(config: Config(trackRef: false)) - #expect(throws: ForyError.self) { - try resolver.register(Address.self, namespace: "com", typeName: "example.Address") - } + #expect(throws: ForyError.self) { + try resolver.register(Address.self, namespace: "com", typeName: "example.Address") + } } @Test func registrationIsRejectedAfterFirstTopLevelUse() throws { - let fory = Fory() - _ = try fory.serialize(Int32(7)) - - do { - try fory.register(Address.self, name: "demo.address") - #expect(Bool(false)) - } catch { - #expect("\(error)".contains("cannot register more types")) - } + let fory = Fory() + _ = try fory.serialize(Int32(7)) + + do { + try fory.register(Address.self, name: "demo.address") + #expect(Bool(false)) + } catch { + #expect("\(error)".contains("cannot register more types")) + } } @Test func serializeToAppendsRoots() throws { - let fory = Fory() - let first = Int32(7) - let second = "swift-buffer" - let third: String? = nil + let fory = Fory() + let first = Int32(7) + let second = "swift-buffer" + let third: String? = nil - let firstData = try fory.serialize(first) - let secondData = try fory.serialize(second) - let thirdData = try fory.serialize(third) + let firstData = try fory.serialize(first) + let secondData = try fory.serialize(second) + let thirdData = try fory.serialize(third) - var stream = Data() - try fory.serialize(first, to: &stream) - try fory.serialize(second, to: &stream) - try fory.serialize(third, to: &stream) + var stream = Data() + try fory.serialize(first, to: &stream) + try fory.serialize(second, to: &stream) + try fory.serialize(third, to: &stream) - var expected = Data() - expected.append(firstData) - expected.append(secondData) - expected.append(thirdData) - #expect(stream == expected) + var expected = Data() + expected.append(firstData) + expected.append(secondData) + expected.append(thirdData) + #expect(stream == expected) - let buffer = ByteBuffer(data: stream) - let decodedFirst: Int32 = try fory.deserialize(from: buffer) - #expect(decodedFirst == first) - #expect(buffer.getCursor() == firstData.count) + let buffer = ByteBuffer(data: stream) + let decodedFirst: Int32 = try fory.deserialize(from: buffer) + #expect(decodedFirst == first) + #expect(buffer.getCursor() == firstData.count) - let decodedSecond: String = try fory.deserialize(from: buffer) - #expect(decodedSecond == second) - #expect(buffer.getCursor() == firstData.count + secondData.count) + let decodedSecond: String = try fory.deserialize(from: buffer) + #expect(decodedSecond == second) + #expect(buffer.getCursor() == firstData.count + secondData.count) - let decodedThird: String? = try fory.deserialize(from: buffer) - #expect(decodedThird == nil) - #expect(buffer.remaining == 0) + let decodedThird: String? = try fory.deserialize(from: buffer) + #expect(decodedThird == nil) + #expect(buffer.remaining == 0) } @Test func rootBufferHonorsCursor() throws { - let fory = Fory() - let prefix: [UInt8] = [0xAA, 0xBB, 0xCC] - let payload = try fory.serialize("offset") + let fory = Fory() + let prefix: [UInt8] = [0xAA, 0xBB, 0xCC] + let payload = try fory.serialize("offset") - let buffer = ByteBuffer() - buffer.writeBytes(prefix) - buffer.writeBytes(Array(payload)) - buffer.setCursor(prefix.count) + let buffer = ByteBuffer() + buffer.writeBytes(prefix) + buffer.writeBytes(Array(payload)) + buffer.setCursor(prefix.count) - let decoded: String = try fory.deserialize(from: buffer) - #expect(decoded == "offset") - #expect(buffer.getCursor() == buffer.count) - #expect(Array(buffer.storage.prefix(prefix.count)) == prefix) + let decoded: String = try fory.deserialize(from: buffer) + #expect(decoded == "offset") + #expect(buffer.getCursor() == buffer.count) + #expect(Array(buffer.storage.prefix(prefix.count)) == prefix) } @Test func topLevelAnyObjectRoundTrip() throws { - let fory = Fory(config: .init(trackRef: true, compatible: false)) - fory.register(Node.self, id: 210) + let fory = Fory(config: .init(trackRef: true, compatible: false)) + fory.register(Node.self, id: 210) - let value: AnyObject = Node(value: 123) - let data = try fory.serialize(value) - let decoded: AnyObject = try fory.deserialize(data) + let value: AnyObject = Node(value: 123) + let data = try fory.serialize(value) + let decoded: AnyObject = try fory.deserialize(data) - let node = decoded as? Node - #expect(node != nil) - #expect(node?.value == 123) + let node = decoded as? Node + #expect(node != nil) + #expect(node?.value == 123) - var buffer = Data() - try fory.serialize(value, to: &buffer) - let decodedFrom: AnyObject = try fory.deserialize(from: ByteBuffer(data: buffer)) - #expect((decodedFrom as? Node)?.value == 123) + var buffer = Data() + try fory.serialize(value, to: &buffer) + let decodedFrom: AnyObject = try fory.deserialize(from: ByteBuffer(data: buffer)) + #expect((decodedFrom as? Node)?.value == 123) } @Test func topLevelAnySerializerRoundTrip() throws { - let fory = Fory() - fory.register(Address.self, id: 211) + let fory = Fory() + fory.register(Address.self, id: 211) - let value: any Serializer = Address(street: "AnyStreet", zip: 9090) - let data = try fory.serialize(value) - let decoded: any Serializer = try fory.deserialize(data) + let value: any Serializer = Address(street: "AnyStreet", zip: 9090) + let data = try fory.serialize(value) + let decoded: any Serializer = try fory.deserialize(data) - let address = decoded as? Address - #expect(address == Address(street: "AnyStreet", zip: 9090)) + let address = decoded as? Address + #expect(address == Address(street: "AnyStreet", zip: 9090)) - var buffer = Data() - try fory.serialize(value, to: &buffer) - let decodedFrom: any Serializer = try fory.deserialize(from: ByteBuffer(data: buffer)) - #expect(decodedFrom as? Address == Address(street: "AnyStreet", zip: 9090)) + var buffer = Data() + try fory.serialize(value, to: &buffer) + let decodedFrom: any Serializer = try fory.deserialize(from: ByteBuffer(data: buffer)) + #expect(decodedFrom as? Address == Address(street: "AnyStreet", zip: 9090)) } @Test func macroDynamicAnyObjectAndAnySerializerFieldsRoundTrip() throws { - let fory = Fory(config: .init(trackRef: true, compatible: false)) - fory.register(Node.self, id: 220) - fory.register(Address.self, id: 221) - fory.register(AnyObjectHolder.self, id: 222) - fory.register(AnySerializerHolder.self, id: 223) - - let sharedNode = Node(value: 77) - let objectHolder = AnyObjectHolder( - value: sharedNode, - optionalValue: nil, - items: [sharedNode, NSNull()] - ) - let objectData = try fory.serialize(objectHolder) - let objectDecoded: AnyObjectHolder = try fory.deserialize(objectData) - #expect((objectDecoded.value as? Node)?.value == 77) - #expect(objectDecoded.optionalValue == nil) - #expect(objectDecoded.items.count == 2) - #expect((objectDecoded.items[0] as? Node)?.value == 77) - #expect(objectDecoded.items[1] is NSNull) - - let serializerHolder = AnySerializerHolder( - value: Address(street: "Root", zip: 10001), - items: [Int32(11), Address(street: "Nested", zip: 10002)], - map: [ - "age": Int64(19), - "address": Address(street: "Mapped", zip: 10003), - ] - ) - let serializerData = try fory.serialize(serializerHolder) - let serializerDecoded: AnySerializerHolder = try fory.deserialize(serializerData) + let fory = Fory(config: .init(trackRef: true, compatible: false)) + fory.register(Node.self, id: 220) + fory.register(Address.self, id: 221) + fory.register(AnyObjectHolder.self, id: 222) + fory.register(AnySerializerHolder.self, id: 223) + + let sharedNode = Node(value: 77) + let objectHolder = AnyObjectHolder( + value: sharedNode, + optionalValue: nil, + items: [sharedNode, NSNull()] + ) + let objectData = try fory.serialize(objectHolder) + let objectDecoded: AnyObjectHolder = try fory.deserialize(objectData) + #expect((objectDecoded.value as? Node)?.value == 77) + #expect(objectDecoded.optionalValue == nil) + #expect(objectDecoded.items.count == 2) + #expect((objectDecoded.items[0] as? Node)?.value == 77) + #expect(objectDecoded.items[1] is NSNull) + + let serializerHolder = AnySerializerHolder( + value: Address(street: "Root", zip: 10001), + items: [Int32(11), Address(street: "Nested", zip: 10002)], + map: [ + "age": Int64(19), + "address": Address(street: "Mapped", zip: 10003) + ] + ) + let serializerData = try fory.serialize(serializerHolder) + let serializerDecoded: AnySerializerHolder = try fory.deserialize(serializerData) - #expect(serializerDecoded.value as? Address == Address(street: "Root", zip: 10001)) - #expect(serializerDecoded.items.count == 2) - #expect(serializerDecoded.items[0] as? Int32 == 11) - #expect(serializerDecoded.items[1] as? Address == Address(street: "Nested", zip: 10002)) - #expect(serializerDecoded.map["age"] as? Int64 == 19) - #expect(serializerDecoded.map["address"] as? Address == Address(street: "Mapped", zip: 10003)) + #expect(serializerDecoded.value as? Address == Address(street: "Root", zip: 10001)) + #expect(serializerDecoded.items.count == 2) + #expect(serializerDecoded.items[0] as? Int32 == 11) + #expect(serializerDecoded.items[1] as? Address == Address(street: "Nested", zip: 10002)) + #expect(serializerDecoded.map["age"] as? Int64 == 19) + #expect(serializerDecoded.map["address"] as? Address == Address(street: "Mapped", zip: 10003)) } @Test func dynamicAnySerializerTracksRefs() throws { - let fory = Fory(config: .init(trackRef: true, compatible: false)) - fory.register(Node.self, id: 226) - fory.register(AnySerializerHolder.self, id: 227) - - let shared = Node(value: 88) - shared.next = shared - let value = AnySerializerHolder( - value: shared, - items: [shared], - map: ["shared": shared] - ) + let fory = Fory(config: .init(trackRef: true, compatible: false)) + fory.register(Node.self, id: 226) + fory.register(AnySerializerHolder.self, id: 227) + + let shared = Node(value: 88) + shared.next = shared + let value = AnySerializerHolder( + value: shared, + items: [shared], + map: ["shared": shared] + ) - let decoded: AnySerializerHolder = try fory.deserialize(try fory.serialize(value)) - let root = decoded.value as? Node - let item = decoded.items.first as? Node - let mapped = decoded.map["shared"] as? Node + let decoded: AnySerializerHolder = try fory.deserialize(try fory.serialize(value)) + let root = decoded.value as? Node + let item = decoded.items.first as? Node + let mapped = decoded.map["shared"] as? Node - #expect(root != nil) - #expect(root === item) - #expect(item === mapped) - #expect(root?.next === root) + #expect(root != nil) + #expect(root === item) + #expect(item === mapped) + #expect(root?.next === root) } @Test func macroAnyFieldsRoundTrip() throws { - let fory = Fory() - fory.register(Address.self, id: 224) - fory.register(AnyFieldHolder.self, id: 225) - - let value = AnyFieldHolder( - value: Address(street: "AnyRoot", zip: 11001), - optionalValue: nil, - list: [Int32(7), "hello", Address(street: "AnyList", zip: 11002), NSNull()], - stringMap: [ - "count": Int64(3), - "name": "map", - "address": Address(street: "AnyMap", zip: 11003), - "empty": NSNull(), - ], - int32Map: [ - 1: Int32(-9), - 2: "v2", - 3: Address(street: "AnyIntMap", zip: 11004), - 4: NSNull(), - ] - ) - let data = try fory.serialize(value) - let decoded: AnyFieldHolder = try fory.deserialize(data) - - #expect(decoded.value as? Address == Address(street: "AnyRoot", zip: 11001)) - #expect(decoded.optionalValue == nil) - #expect(decoded.list.count == 4) - #expect(decoded.list[0] as? Int32 == 7) - #expect(decoded.list[1] as? String == "hello") - #expect(decoded.list[2] as? Address == Address(street: "AnyList", zip: 11002)) - #expect(decoded.list[3] is NSNull) - #expect(decoded.stringMap["count"] as? Int64 == 3) - #expect(decoded.stringMap["name"] as? String == "map") - #expect(decoded.stringMap["address"] as? Address == Address(street: "AnyMap", zip: 11003)) - #expect(decoded.stringMap["empty"] is NSNull) - #expect(decoded.int32Map[1] as? Int32 == -9) - #expect(decoded.int32Map[2] as? String == "v2") - #expect(decoded.int32Map[3] as? Address == Address(street: "AnyIntMap", zip: 11004)) - #expect(decoded.int32Map[4] is NSNull) + let fory = Fory() + fory.register(Address.self, id: 224) + fory.register(AnyFieldHolder.self, id: 225) + + let value = AnyFieldHolder( + value: Address(street: "AnyRoot", zip: 11001), + optionalValue: nil, + list: [Int32(7), "hello", Address(street: "AnyList", zip: 11002), NSNull()], + stringMap: [ + "count": Int64(3), + "name": "map", + "address": Address(street: "AnyMap", zip: 11003), + "empty": NSNull() + ], + int32Map: [ + 1: Int32(-9), + 2: "v2", + 3: Address(street: "AnyIntMap", zip: 11004), + 4: NSNull() + ] + ) + let data = try fory.serialize(value) + let decoded: AnyFieldHolder = try fory.deserialize(data) + + #expect(decoded.value as? Address == Address(street: "AnyRoot", zip: 11001)) + #expect(decoded.optionalValue == nil) + #expect(decoded.list.count == 4) + #expect(decoded.list[0] as? Int32 == 7) + #expect(decoded.list[1] as? String == "hello") + #expect(decoded.list[2] as? Address == Address(street: "AnyList", zip: 11002)) + #expect(decoded.list[3] is NSNull) + #expect(decoded.stringMap["count"] as? Int64 == 3) + #expect(decoded.stringMap["name"] as? String == "map") + #expect(decoded.stringMap["address"] as? Address == Address(street: "AnyMap", zip: 11003)) + #expect(decoded.stringMap["empty"] is NSNull) + #expect(decoded.int32Map[1] as? Int32 == -9) + #expect(decoded.int32Map[2] as? String == "v2") + #expect(decoded.int32Map[3] as? Address == Address(street: "AnyIntMap", zip: 11004)) + #expect(decoded.int32Map[4] is NSNull) } @Test func collectionAndMapRefTracking() throws { - let fory = Fory(config: .init(trackRef: true, compatible: false)) - fory.register(Node.self, id: 200) - - let shared = Node(value: 11) - let list: [Node?] = [shared, shared, nil] - let listData = try fory.serialize(list) - let listReader = ByteBuffer(data: listData) - _ = try fory.readHead(buffer: listReader) - _ = try listReader.readInt8() - _ = try listReader.readVarUInt32() - _ = try listReader.readVarUInt32() - let listHeader = try listReader.readUInt8() - #expect((listHeader & 0b0000_0001) != 0) - - let decodedList: [Node?] = try fory.deserialize(listData) - #expect(decodedList.count == 3) - #expect(decodedList[0] === decodedList[1]) - #expect(decodedList[2] == nil) - - let sharedValue = Node(value: 21) - let map: [Int8: Node?] = [1: sharedValue, 2: sharedValue] - let mapData = try fory.serialize(map) - let mapReader = ByteBuffer(data: mapData) - _ = try fory.readHead(buffer: mapReader) - _ = try mapReader.readInt8() - _ = try mapReader.readVarUInt32() - _ = try mapReader.readVarUInt32() - let mapChunkHeader = try mapReader.readUInt8() - #expect((mapChunkHeader & 0b0000_1000) != 0) - - let decodedMap: [Int8: Node?] = try fory.deserialize(mapData) - let v1 = decodedMap[1] ?? nil - let v2 = decodedMap[2] ?? nil - #expect(v1 != nil) - #expect(v1 === v2) + let fory = Fory(config: .init(trackRef: true, compatible: false)) + fory.register(Node.self, id: 200) + + let shared = Node(value: 11) + let list: [Node?] = [shared, shared, nil] + let listData = try fory.serialize(list) + let listReader = ByteBuffer(data: listData) + _ = try fory.readHead(buffer: listReader) + _ = try listReader.readInt8() + _ = try listReader.readVarUInt32() + _ = try listReader.readVarUInt32() + let listHeader = try listReader.readUInt8() + #expect((listHeader & 0b0000_0001) != 0) + + let decodedList: [Node?] = try fory.deserialize(listData) + #expect(decodedList.count == 3) + #expect(decodedList[0] === decodedList[1]) + #expect(decodedList[2] == nil) + + let sharedValue = Node(value: 21) + let map: [Int8: Node?] = [1: sharedValue, 2: sharedValue] + let mapData = try fory.serialize(map) + let mapReader = ByteBuffer(data: mapData) + _ = try fory.readHead(buffer: mapReader) + _ = try mapReader.readInt8() + _ = try mapReader.readVarUInt32() + _ = try mapReader.readVarUInt32() + let mapChunkHeader = try mapReader.readUInt8() + #expect((mapChunkHeader & 0b0000_1000) != 0) + + let decodedMap: [Int8: Node?] = try fory.deserialize(mapData) + let v1 = decodedMap[1] ?? nil + let v2 = decodedMap[2] ?? nil + #expect(v1 != nil) + #expect(v1 === v2) } @Test func macroFieldOrderFollowsForyRules() throws { - let fory = Fory(compatible: false) - fory.register(FieldOrder.self, id: 300) + let fory = Fory(compatible: false) + fory.register(FieldOrder.self, id: 300) - let value = FieldOrder(textTail: "tail", longValue: 123_456_789, shortValue: 17, intValue: 99) - let data = try fory.serialize(value) + let value = FieldOrder(textTail: "tail", longValue: 123_456_789, shortValue: 17, intValue: 99) + let data = try fory.serialize(value) - let buffer = ByteBuffer(data: data) - _ = try fory.readHead(buffer: buffer) - _ = try buffer.readInt8() // root ref flag - _ = try buffer.readVarUInt32() // type id - _ = try buffer.readVarUInt32() // user type id - _ = try buffer.readInt32() // schema hash + let buffer = ByteBuffer(data: data) + _ = try fory.readHead(buffer: buffer) + _ = try buffer.readInt8() // root ref flag + _ = try buffer.readVarUInt32() // type id + _ = try buffer.readVarUInt32() // user type id + _ = try buffer.readInt32() // schema hash - let first = try buffer.readInt16() - let second = try buffer.readVarInt64() - let third = try buffer.readVarInt32() + let first = try buffer.readInt16() + let second = try buffer.readVarInt64() + let third = try buffer.readVarInt32() - let tailContext = ReadContext( - buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) - let fourth = try String.foryReadData(tailContext) + let tailContext = ReadContext( + buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) + let fourth = try String.foryReadData(tailContext) - #expect(first == value.shortValue) - #expect(second == value.longValue) - #expect(third == value.intValue) - #expect(fourth == value.textTail) + #expect(first == value.shortValue) + #expect(second == value.longValue) + #expect(third == value.intValue) + #expect(fourth == value.textTail) } @Test func macroTaggedFieldsKeepGroupedPayloadOrder() throws { - let fory = Fory(compatible: false) - fory.register(TaggedFieldOrder.self, id: 303) + let fory = Fory(compatible: false) + fory.register(TaggedFieldOrder.self, id: 303) - let fields = TaggedFieldOrder.foryFieldsInfo(trackRef: false) - #expect(fields.map(\.fieldName) == ["intValue", "textTail"]) - #expect(fields.map(\.fieldID) == [10, 1]) + let fields = TaggedFieldOrder.foryFieldsInfo(trackRef: false) + #expect(fields.map(\.fieldName) == ["intValue", "textTail"]) + #expect(fields.map(\.fieldID) == [10, 1]) - let value = TaggedFieldOrder(textTail: "tail", intValue: 99) - let data = try fory.serialize(value) - let buffer = ByteBuffer(data: data) - _ = try fory.readHead(buffer: buffer) - _ = try buffer.readInt8() - _ = try buffer.readVarUInt32() - _ = try buffer.readVarUInt32() - _ = try buffer.readInt32() + let value = TaggedFieldOrder(textTail: "tail", intValue: 99) + let data = try fory.serialize(value) + let buffer = ByteBuffer(data: data) + _ = try fory.readHead(buffer: buffer) + _ = try buffer.readInt8() + _ = try buffer.readVarUInt32() + _ = try buffer.readVarUInt32() + _ = try buffer.readInt32() - #expect(try buffer.readVarInt32() == value.intValue) - let tailContext = ReadContext( - buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) - #expect(try String.foryReadData(tailContext) == value.textTail) + #expect(try buffer.readVarInt32() == value.intValue) + let tailContext = ReadContext( + buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) + #expect(try String.foryReadData(tailContext) == value.textTail) } @Test func macroNonPrimitiveFieldsSortByFieldIdentifier() throws { - let fields = NonPrimitiveFieldOrder.foryFieldsInfo(trackRef: false) + let fields = NonPrimitiveFieldOrder.foryFieldsInfo(trackRef: false) - #expect( - fields.map(\.fieldName) == [ - "intValue", "mapValue", "stringValue", "addressValue", "binaryValue", - ]) - #expect(fields.map(\.fieldID) == [nil, 10, 20, nil, nil]) + #expect( + fields.map(\.fieldName) == [ + "intValue", "mapValue", "stringValue", "addressValue", "binaryValue" + ]) + #expect(fields.map(\.fieldID) == [nil, 10, 20, nil, nil]) } @Test func macroFieldEncodingOverridesForUnsignedTypes() throws { - let fory = Fory(compatible: false) - fory.register(EncodedNumberFields.self, id: 301) + let fory = Fory(compatible: false) + fory.register(EncodedNumberFields.self, id: 301) - let value = EncodedNumberFields( - u32Fixed: 0x1122_3344, - u64Tagged: UInt64(Int32.max) + 99 - ) - let data = try fory.serialize(value) - let decoded: EncodedNumberFields = try fory.deserialize(data) - #expect(decoded == value) + let value = EncodedNumberFields( + u32Fixed: 0x1122_3344, + u64Tagged: UInt64(Int32.max) + 99 + ) + let data = try fory.serialize(value) + let decoded: EncodedNumberFields = try fory.deserialize(data) + #expect(decoded == value) - let buffer = ByteBuffer(data: data) - _ = try fory.readHead(buffer: buffer) - _ = try buffer.readInt8() - _ = try buffer.readVarUInt32() - _ = try buffer.readVarUInt32() - _ = try buffer.readInt32() + let buffer = ByteBuffer(data: data) + _ = try fory.readHead(buffer: buffer) + _ = try buffer.readInt8() + _ = try buffer.readVarUInt32() + _ = try buffer.readVarUInt32() + _ = try buffer.readInt32() - #expect(try buffer.readUInt32() == value.u32Fixed) - #expect(try buffer.readTaggedUInt64() == value.u64Tagged) + #expect(try buffer.readUInt32() == value.u32Fixed) + #expect(try buffer.readTaggedUInt64() == value.u64Tagged) } @Test func macroEnumUsesExplicitIntegerRawValue() throws { - let fory = Fory(config: .init(trackRef: false, compatible: false)) - fory.register(SparseStatus.self, id: 302) + let fory = Fory(config: .init(trackRef: false, compatible: false)) + fory.register(SparseStatus.self, id: 302) - let data = try fory.serialize(SparseStatus.ok) - let buffer = ByteBuffer(data: data) - _ = try fory.readHead(buffer: buffer) - _ = try buffer.readInt8() - _ = try buffer.readVarUInt32() - _ = try buffer.readVarUInt32() - #expect(try buffer.readVarUInt32() == 8192) + let data = try fory.serialize(SparseStatus.ok) + let buffer = ByteBuffer(data: data) + _ = try fory.readHead(buffer: buffer) + _ = try buffer.readInt8() + _ = try buffer.readVarUInt32() + _ = try buffer.readVarUInt32() + #expect(try buffer.readVarUInt32() == 8192) - let decoded: SparseStatus = try fory.deserialize(data) - #expect(decoded == .ok) + let decoded: SparseStatus = try fory.deserialize(data) + #expect(decoded == .ok) } @Test func macroFieldEncodingOverridesCompatibleTypeMeta() throws { - let fields = EncodedNumberFields.foryFieldsInfo(trackRef: false) - #expect(fields.count == 2) - #expect(fields[0].fieldName == "u32Fixed") - #expect(fields[0].fieldType.typeID == TypeId.uint32.rawValue) - #expect(fields[1].fieldName == "u64Tagged") - #expect(fields[1].fieldType.typeID == TypeId.taggedUInt64.rawValue) + let fields = EncodedNumberFields.foryFieldsInfo(trackRef: false) + #expect(fields.count == 2) + #expect(fields[0].fieldName == "u32Fixed") + #expect(fields[0].fieldType.typeID == TypeId.uint32.rawValue) + #expect(fields[1].fieldName == "u64Tagged") + #expect(fields[1].fieldType.typeID == TypeId.taggedUInt64.rawValue) } @Test func macroReducedPrecisionFieldsUseXlangTypeIDs() { - let fields = ReducedPrecisionMacroFields.foryFieldsInfo(trackRef: false) - #expect(fields.count == 4) - #expect( - fields.map(\.fieldName) == ["float16Value", "bfloat16Value", "bfloat16Array", "float16Array"]) - #expect( - fields.map(\.fieldType.typeID) == [ - TypeId.float16.rawValue, - TypeId.bfloat16.rawValue, - TypeId.bfloat16Array.rawValue, - TypeId.float16Array.rawValue, - ]) + let fields = ReducedPrecisionMacroFields.foryFieldsInfo(trackRef: false) + #expect(fields.count == 4) + #expect( + fields.map(\.fieldName) == ["float16Value", "bfloat16Value", "bfloat16Array", "float16Array"]) + #expect( + fields.map(\.fieldType.typeID) == [ + TypeId.float16.rawValue, + TypeId.bfloat16.rawValue, + TypeId.bfloat16Array.rawValue, + TypeId.float16Array.rawValue + ]) } @Test func macroFieldIDsPopulateCompatibleTypeMeta() { - let fields = FieldIdConfigured.foryFieldsInfo(trackRef: false) - #expect(fields.count == 2) - - var byID: [Int16: TypeMeta.FieldInfo] = [:] - for field in fields { - if let id = field.fieldID { - byID[id] = field + let fields = FieldIdConfigured.foryFieldsInfo(trackRef: false) + #expect(fields.count == 2) + + var byID: [Int16: TypeMeta.FieldInfo] = [:] + for field in fields { + if let id = field.fieldID { + byID[id] = field + } } - } - #expect(byID[2]?.fieldName == "stableID") - #expect(byID[2]?.fieldType.typeID == TypeId.varint32.rawValue) - #expect(byID[5]?.fieldName == "fixedValue") - #expect(byID[5]?.fieldType.typeID == TypeId.int32.rawValue) + #expect(byID[2]?.fieldName == "stableID") + #expect(byID[2]?.fieldType.typeID == TypeId.varint32.rawValue) + #expect(byID[5]?.fieldName == "fixedValue") + #expect(byID[5]?.fieldType.typeID == TypeId.int32.rawValue) } @Test func macroFieldIDsDriveCompatibleStructDecodeAcrossRenames() throws { - let writer = Fory(config: .init(trackRef: false, compatible: true)) - writer.register(FieldIdSource.self, id: 9101) + let writer = Fory(config: .init(trackRef: false, compatible: true)) + writer.register(FieldIdSource.self, id: 9101) - let reader = Fory(config: .init(trackRef: false, compatible: true)) - reader.register(FieldIdTarget.self, id: 9101) + let reader = Fory(config: .init(trackRef: false, compatible: true)) + reader.register(FieldIdTarget.self, id: 9101) - let source = FieldIdSource(value: 42, label: "alpha") - let bytes = try writer.serialize(source) - let decoded: FieldIdTarget = try reader.deserialize(bytes) + let source = FieldIdSource(value: 42, label: "alpha") + let bytes = try writer.serialize(source) + let decoded: FieldIdTarget = try reader.deserialize(bytes) - #expect(decoded.renamedValue == source.value) - #expect(decoded.renamedLabel == source.label) + #expect(decoded.renamedValue == source.value) + #expect(decoded.renamedLabel == source.label) - let roundTrip = try reader.serialize(decoded) - let back: FieldIdSource = try writer.deserialize(roundTrip) - #expect(back == source) + let roundTrip = try reader.serialize(decoded) + let back: FieldIdSource = try writer.deserialize(roundTrip) + #expect(back == source) } @Test func macroFieldIDsDriveTaggedUnionDecodeAcrossRenames() throws { - let writer = Fory(config: .init(trackRef: false, compatible: true)) - writer.register(FieldIdUnionSource.self, id: 9102) + let writer = Fory(config: .init(trackRef: false, compatible: true)) + writer.register(FieldIdUnionSource.self, id: 9102) - let reader = Fory(config: .init(trackRef: false, compatible: true)) - reader.register(FieldIdUnionTarget.self, id: 9102) + let reader = Fory(config: .init(trackRef: false, compatible: true)) + reader.register(FieldIdUnionTarget.self, id: 9102) - let source = FieldIdUnionSource.number(123) - let bytes = try writer.serialize(source) - let decoded: FieldIdUnionTarget = try reader.deserialize(bytes) + let source = FieldIdUnionSource.number(123) + let bytes = try writer.serialize(source) + let decoded: FieldIdUnionTarget = try reader.deserialize(bytes) - switch decoded { - case .renamedNumber(let value): - #expect(value == 123) - default: - #expect(Bool(false)) - } + switch decoded { + case .renamedNumber(let value): + #expect(value == 123) + default: + #expect(Bool(false)) + } } @Test func compatibleNestedStructArrayRoundTrip() throws { - let writer = Fory(config: .init(trackRef: false, compatible: true)) - writer.register(CompatibleNestedItem.self, id: 9103) - writer.register(CompatibleNestedArrayHolder.self, id: 9104) - - let reader = Fory(config: .init(trackRef: false, compatible: true)) - reader.register(CompatibleNestedItem.self, id: 9103) - reader.register(CompatibleNestedArrayHolder.self, id: 9104) - - let value = CompatibleNestedArrayHolder( - items: [ - CompatibleNestedItem(id: 1, name: "alpha"), - CompatibleNestedItem(id: 2, name: "beta"), - ] - ) - let bytes = try writer.serialize(value) - let decoded: CompatibleNestedArrayHolder = try reader.deserialize(bytes) - #expect(decoded == value) + let writer = Fory(config: .init(trackRef: false, compatible: true)) + writer.register(CompatibleNestedItem.self, id: 9103) + writer.register(CompatibleNestedArrayHolder.self, id: 9104) + + let reader = Fory(config: .init(trackRef: false, compatible: true)) + reader.register(CompatibleNestedItem.self, id: 9103) + reader.register(CompatibleNestedArrayHolder.self, id: 9104) + + let value = CompatibleNestedArrayHolder( + items: [ + CompatibleNestedItem(id: 1, name: "alpha"), + CompatibleNestedItem(id: 2, name: "beta") + ] + ) + let bytes = try writer.serialize(value) + let decoded: CompatibleNestedArrayHolder = try reader.deserialize(bytes) + #expect(decoded == value) } @Test func compatibleNestedStructOptionalArrayRoundTrip() throws { - let writer = Fory(config: .init(trackRef: false, compatible: true)) - writer.register(CompatibleNestedItem.self, id: 9103) - writer.register(CompatibleNestedOptionalArrayHolder.self, id: 9105) - - let reader = Fory(config: .init(trackRef: false, compatible: true)) - reader.register(CompatibleNestedItem.self, id: 9103) - reader.register(CompatibleNestedOptionalArrayHolder.self, id: 9105) - - let value = CompatibleNestedOptionalArrayHolder( - items: [ - CompatibleNestedItem(id: 1, name: "alpha"), - nil, - CompatibleNestedItem(id: 2, name: "beta"), - ] - ) - let bytes = try writer.serialize(value) - let decoded: CompatibleNestedOptionalArrayHolder = try reader.deserialize(bytes) - #expect(decoded == value) + let writer = Fory(config: .init(trackRef: false, compatible: true)) + writer.register(CompatibleNestedItem.self, id: 9103) + writer.register(CompatibleNestedOptionalArrayHolder.self, id: 9105) + + let reader = Fory(config: .init(trackRef: false, compatible: true)) + reader.register(CompatibleNestedItem.self, id: 9103) + reader.register(CompatibleNestedOptionalArrayHolder.self, id: 9105) + + let value = CompatibleNestedOptionalArrayHolder( + items: [ + CompatibleNestedItem(id: 1, name: "alpha"), + nil, + CompatibleNestedItem(id: 2, name: "beta") + ] + ) + let bytes = try writer.serialize(value) + let decoded: CompatibleNestedOptionalArrayHolder = try reader.deserialize(bytes) + #expect(decoded == value) } @Test func compatibleNestedStructMapRoundTrip() throws { - let writer = Fory(config: .init(trackRef: false, compatible: true)) - writer.register(CompatibleNestedItem.self, id: 9103) - writer.register(CompatibleNestedMapHolder.self, id: 9106) - - let reader = Fory(config: .init(trackRef: false, compatible: true)) - reader.register(CompatibleNestedItem.self, id: 9103) - reader.register(CompatibleNestedMapHolder.self, id: 9106) - - let value = CompatibleNestedMapHolder( - items: [ - 1: CompatibleNestedItem(id: 10, name: "first"), - 2: CompatibleNestedItem(id: 20, name: "second"), - ] - ) - let bytes = try writer.serialize(value) - let decoded: CompatibleNestedMapHolder = try reader.deserialize(bytes) - #expect(decoded == value) + let writer = Fory(config: .init(trackRef: false, compatible: true)) + writer.register(CompatibleNestedItem.self, id: 9103) + writer.register(CompatibleNestedMapHolder.self, id: 9106) + + let reader = Fory(config: .init(trackRef: false, compatible: true)) + reader.register(CompatibleNestedItem.self, id: 9103) + reader.register(CompatibleNestedMapHolder.self, id: 9106) + + let value = CompatibleNestedMapHolder( + items: [ + 1: CompatibleNestedItem(id: 10, name: "first"), + 2: CompatibleNestedItem(id: 20, name: "second") + ] + ) + let bytes = try writer.serialize(value) + let decoded: CompatibleNestedMapHolder = try reader.deserialize(bytes) + #expect(decoded == value) } @Test func pvlVarInt64AndVarUInt64Extremes() throws { - let uintValues: [UInt64] = [ - 0, - 1, - 127, - 128, - 16_383, - 16_384, - 2_097_151, - 2_097_152, - 268_435_455, - 268_435_456, - 34_359_738_367, - 34_359_738_368, - 4_398_046_511_103, - 4_398_046_511_104, - 562_949_953_421_311, - 562_949_953_421_312, - 72_057_594_037_927_935, - 72_057_594_037_927_936, - UInt64(Int64.max), - UInt64.max, - ] - let intValues: [Int64] = [ - Int64.min, - Int64.min + 1, - -1_000_000_000_000, - -1_000_000, - -1_000, - -128, - -1, - 0, - 1, - 127, - 1_000, - 1_000_000, - 1_000_000_000_000, - Int64.max - 1, - Int64.max, - ] - - let writeBuffer = ByteBuffer() - for value in uintValues { - writeBuffer.writeVarUInt64(value) - } - for value in intValues { - writeBuffer.writeVarInt64(value) - } - let minBuffer = ByteBuffer() - minBuffer.writeVarInt64(Int64.min) - #expect(minBuffer.count == 9) - #expect(minBuffer.storage.prefix(minBuffer.count).allSatisfy { $0 == 0xFF }) - - let encoded = Array(writeBuffer.storage.prefix(writeBuffer.count)) - - let readBuffer = ByteBuffer(bytes: encoded) - for value in uintValues { - #expect(try readBuffer.readVarUInt64() == value) - } - for value in intValues { - #expect(try readBuffer.readVarInt64() == value) - } - #expect(readBuffer.remaining == 0) + let uintValues: [UInt64] = [ + 0, + 1, + 127, + 128, + 16_383, + 16_384, + 2_097_151, + 2_097_152, + 268_435_455, + 268_435_456, + 34_359_738_367, + 34_359_738_368, + 4_398_046_511_103, + 4_398_046_511_104, + 562_949_953_421_311, + 562_949_953_421_312, + 72_057_594_037_927_935, + 72_057_594_037_927_936, + UInt64(Int64.max), + UInt64.max + ] + let intValues: [Int64] = [ + Int64.min, + Int64.min + 1, + -1_000_000_000_000, + -1_000_000, + -1_000, + -128, + -1, + 0, + 1, + 127, + 1_000, + 1_000_000, + 1_000_000_000_000, + Int64.max - 1, + Int64.max + ] + + let writeBuffer = ByteBuffer() + for value in uintValues { + writeBuffer.writeVarUInt64(value) + } + for value in intValues { + writeBuffer.writeVarInt64(value) + } + let minBuffer = ByteBuffer() + minBuffer.writeVarInt64(Int64.min) + #expect(minBuffer.count == 9) + #expect(minBuffer.storage.prefix(minBuffer.count).allSatisfy { $0 == 0xFF }) + + let encoded = Array(writeBuffer.storage.prefix(writeBuffer.count)) + + let readBuffer = ByteBuffer(bytes: encoded) + for value in uintValues { + #expect(try readBuffer.readVarUInt64() == value) + } + for value in intValues { + #expect(try readBuffer.readVarInt64() == value) + } + #expect(readBuffer.remaining == 0) } @Test func metaStringEncodingRoundTrip() throws { - let encoder = MetaStringEncoder.fieldName - let decoder = MetaStringDecoder.fieldName + let encoder = MetaStringEncoder.fieldName + let decoder = MetaStringDecoder.fieldName - let lower = try encoder.encode("alpha_beta", encoding: .lowerSpecial) - #expect(lower.encoding == .lowerSpecial) - #expect(try decoder.decode(bytes: lower.bytes, encoding: lower.encoding).value == "alpha_beta") + let lower = try encoder.encode("alpha_beta", encoding: .lowerSpecial) + #expect(lower.encoding == .lowerSpecial) + #expect(try decoder.decode(bytes: lower.bytes, encoding: lower.encoding).value == "alpha_beta") - let firstLower = try encoder.encode("User_name", encoding: .firstToLowerSpecial) - #expect(firstLower.encoding == .firstToLowerSpecial) - #expect( - try decoder.decode(bytes: firstLower.bytes, encoding: firstLower.encoding).value == "User_name") + let firstLower = try encoder.encode("User_name", encoding: .firstToLowerSpecial) + #expect(firstLower.encoding == .firstToLowerSpecial) + #expect( + try decoder.decode(bytes: firstLower.bytes, encoding: firstLower.encoding).value == "User_name") - let allLower = try encoder.encode("MyHTTPType", encoding: .allToLowerSpecial) - #expect(allLower.encoding == .allToLowerSpecial) - #expect( - try decoder.decode(bytes: allLower.bytes, encoding: allLower.encoding).value == "MyHTTPType") + let allLower = try encoder.encode("MyHTTPType", encoding: .allToLowerSpecial) + #expect(allLower.encoding == .allToLowerSpecial) + #expect( + try decoder.decode(bytes: allLower.bytes, encoding: allLower.encoding).value == "MyHTTPType") - let lowerUpperDigit = try encoder.encode("userId2", encoding: .lowerUpperDigitSpecial) - #expect(lowerUpperDigit.encoding == .lowerUpperDigitSpecial) - #expect( - try decoder.decode(bytes: lowerUpperDigit.bytes, encoding: lowerUpperDigit.encoding).value - == "userId2") + let lowerUpperDigit = try encoder.encode("userId2", encoding: .lowerUpperDigitSpecial) + #expect(lowerUpperDigit.encoding == .lowerUpperDigitSpecial) + #expect( + try decoder.decode(bytes: lowerUpperDigit.bytes, encoding: lowerUpperDigit.encoding).value + == "userId2") - let autoUtf8 = try encoder.encode("naïve_meta") - #expect(autoUtf8.encoding == .utf8) - #expect( - try decoder.decode(bytes: autoUtf8.bytes, encoding: autoUtf8.encoding).value == "naïve_meta") + let autoUtf8 = try encoder.encode("naïve_meta") + #expect(autoUtf8.encoding == .utf8) + #expect( + try decoder.decode(bytes: autoUtf8.bytes, encoding: autoUtf8.encoding).value == "naïve_meta") } @Test func typeMetaRoundTripByName() throws { - let namespace = try MetaStringEncoder.namespace.encode("com.example") - let typeName = try MetaStringEncoder.typeName.encode("UserProfile") - - let fields: [TypeMeta.FieldInfo] = [ - .init( - fieldID: nil, - fieldName: "createdAt", - fieldType: .init(typeID: TypeId.varint64.rawValue, nullable: false) - ), - .init( - fieldID: nil, - fieldName: "tags", - fieldType: .init( - typeID: TypeId.list.rawValue, - nullable: false, - generics: [.init(typeID: TypeId.string.rawValue, nullable: true)] - ) - ), - .init( - fieldID: nil, - fieldName: "attributes", - fieldType: .init( - typeID: TypeId.map.rawValue, - nullable: true, - generics: [ - .init(typeID: TypeId.string.rawValue, nullable: false), - .init(typeID: TypeId.varint32.rawValue, nullable: true), - ] - ) - ), - .init( - fieldID: 7, - fieldName: "ignored_for_tag_mode", - fieldType: .init(typeID: TypeId.varint32.rawValue, nullable: false) - ), - ] - - let meta = try TypeMeta( - typeID: TypeId.namedStruct.rawValue, - userTypeID: nil, - namespace: namespace, - typeName: typeName, - registerByName: true, - fields: fields - ) - - let encoded = try meta.encode() - let decoded = try TypeMeta.decode(encoded) - - #expect(decoded.registerByName == true) - #expect(decoded.namespace.value == "com.example") - #expect(decoded.typeName.value == "UserProfile") - #expect(decoded.typeID == TypeId.namedStruct.rawValue) - #expect(decoded.userTypeID == nil) - #expect(decoded.fields.count == 4) - #expect(decoded.fields[0].fieldName == "created_at") - #expect(decoded.fields[3].fieldID == 7) + let namespace = try MetaStringEncoder.namespace.encode("com.example") + let typeName = try MetaStringEncoder.typeName.encode("UserProfile") + + let fields: [TypeMeta.FieldInfo] = [ + .init( + fieldID: nil, + fieldName: "createdAt", + fieldType: .init(typeID: TypeId.varint64.rawValue, nullable: false) + ), + .init( + fieldID: nil, + fieldName: "tags", + fieldType: .init( + typeID: TypeId.list.rawValue, + nullable: false, + generics: [.init(typeID: TypeId.string.rawValue, nullable: true)] + ) + ), + .init( + fieldID: nil, + fieldName: "attributes", + fieldType: .init( + typeID: TypeId.map.rawValue, + nullable: true, + generics: [ + .init(typeID: TypeId.string.rawValue, nullable: false), + .init(typeID: TypeId.varint32.rawValue, nullable: true) + ] + ) + ), + .init( + fieldID: 7, + fieldName: "ignored_for_tag_mode", + fieldType: .init(typeID: TypeId.varint32.rawValue, nullable: false) + ) + ] + + let meta = try TypeMeta( + typeID: TypeId.namedStruct.rawValue, + userTypeID: nil, + namespace: namespace, + typeName: typeName, + registerByName: true, + fields: fields + ) + + let encoded = try meta.encode() + let decoded = try TypeMeta.decode(encoded) + + #expect(decoded.registerByName == true) + #expect(decoded.namespace.value == "com.example") + #expect(decoded.typeName.value == "UserProfile") + #expect(decoded.typeID == TypeId.namedStruct.rawValue) + #expect(decoded.userTypeID == nil) + #expect(decoded.fields.count == 4) + #expect(decoded.fields[0].fieldName == "created_at") + #expect(decoded.fields[3].fieldID == 7) } @Test func typeMetaRoundTripByID() throws { - let emptyNamespace = MetaString.empty(specialChar1: ".", specialChar2: "_") - let emptyTypeName = MetaString.empty(specialChar1: "$", specialChar2: "_") - - let meta = try TypeMeta( - typeID: TypeId.structType.rawValue, - userTypeID: 101, - namespace: emptyNamespace, - typeName: emptyTypeName, - registerByName: false, - fields: [] - ) + let emptyNamespace = MetaString.empty(specialChar1: ".", specialChar2: "_") + let emptyTypeName = MetaString.empty(specialChar1: "$", specialChar2: "_") + + let meta = try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 101, + namespace: emptyNamespace, + typeName: emptyTypeName, + registerByName: false, + fields: [] + ) - let encoded = try meta.encode() - let decoded = try TypeMeta.decode(encoded) + let encoded = try meta.encode() + let decoded = try TypeMeta.decode(encoded) - #expect(decoded.registerByName == false) - #expect(decoded.typeID == TypeId.structType.rawValue) - #expect(decoded.userTypeID == 101) - #expect(decoded.fields.isEmpty) + #expect(decoded.registerByName == false) + #expect(decoded.typeID == TypeId.structType.rawValue) + #expect(decoded.userTypeID == 101) + #expect(decoded.fields.isEmpty) } @Test func typeMetaHeaderHashIncludesHeaderLowBits() throws { - let emptyNamespace = MetaString.empty(specialChar1: ".", specialChar2: "_") - let emptyTypeName = MetaString.empty(specialChar1: "$", specialChar2: "_") - - let meta = try TypeMeta( - typeID: TypeId.structType.rawValue, - userTypeID: 102, - namespace: emptyNamespace, - typeName: emptyTypeName, - registerByName: false, - fields: [] - ) - - var encoded = try meta.encode() - let header = try ByteBuffer(bytes: encoded).readUInt64() - let hashMask = UInt64.max << 12 - let bodyOnlyHash = bodyOnlyTypeMetaHeaderHash(Array(encoded.dropFirst(8))) - #expect((header & hashMask) != bodyOnlyHash) - let rewrittenHeader = bodyOnlyHash | (header & ~hashMask) - for index in 0..<8 { - encoded[index] = UInt8(truncatingIfNeeded: rewrittenHeader >> (index * 8)) - } - - #expect(throws: ForyError.self) { - _ = try TypeMeta.decode(encoded) - } + let emptyNamespace = MetaString.empty(specialChar1: ".", specialChar2: "_") + let emptyTypeName = MetaString.empty(specialChar1: "$", specialChar2: "_") + + let meta = try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 102, + namespace: emptyNamespace, + typeName: emptyTypeName, + registerByName: false, + fields: [] + ) + + var encoded = try meta.encode() + let header = try ByteBuffer(bytes: encoded).readUInt64() + let hashMask = UInt64.max << 12 + let bodyOnlyHash = bodyOnlyTypeMetaHeaderHash(Array(encoded.dropFirst(8))) + #expect((header & hashMask) != bodyOnlyHash) + let rewrittenHeader = bodyOnlyHash | (header & ~hashMask) + for index in 0..<8 { + encoded[index] = UInt8(truncatingIfNeeded: rewrittenHeader >> (index * 8)) + } + + #expect(throws: ForyError.self) { + _ = try TypeMeta.decode(encoded) + } } private func bodyOnlyTypeMetaHeaderHash(_ body: [UInt8]) -> UInt64 { - let shifted = MurmurHash3.x64_128(body, seed: 47).0 << 12 - let signed = Int64(bitPattern: shifted) - let absSigned = signed == Int64.min ? signed : Swift.abs(signed) - return UInt64(bitPattern: absSigned) & (UInt64.max << 12) + let shifted = MurmurHash3.x64_128(body, seed: 47).0 << 12 + let signed = Int64(bitPattern: shifted) + let absSigned = signed == Int64.min ? signed : Swift.abs(signed) + return UInt64(bitPattern: absSigned) & (UInt64.max << 12) } diff --git a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift index 1c3f88b8f6..ed5c885d83 100644 --- a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift +++ b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift @@ -22,363 +22,363 @@ import Testing @ForyStruct private final class BudgetNode { - var id: Int32 = 0 + var id: Int32 = 0 - required init() {} + required init() {} - init(id: Int32) { - self.id = id - } + init(id: Int32) { + self.id = id + } } @ForyStruct private struct BudgetSiblings { - var left: [BudgetNode] = [] - var right: [BudgetNode] = [] + var left: [BudgetNode] = [] + var right: [BudgetNode] = [] } @ForyStruct private struct BudgetDenseHolder: Equatable { - var text: String = "" - var data: Data = Data() - @ArrayField(element: .int32()) - var dense: [Int32] = [] + var text: String = "" + var data: Data = Data() + @ArrayField(element: .int32()) + var dense: [Int32] = [] } private let defaultGraphMemoryBytes: Int64 = 128 * 1024 * 1024 private func makeBudgetFory(maxGraphMemoryBytes: Int64 = defaultGraphMemoryBytes) -> Fory { - let fory = Fory( - config: .init( - trackRef: false, - compatible: false, - maxGraphMemoryBytes: maxGraphMemoryBytes - )) - fory.register(BudgetNode.self, id: 9801) - fory.register(BudgetSiblings.self, id: 9802) - fory.register(BudgetDenseHolder.self, id: 9803) - return fory + let fory = Fory( + config: .init( + trackRef: false, + compatible: false, + maxGraphMemoryBytes: maxGraphMemoryBytes + )) + fory.register(BudgetNode.self, id: 9801) + fory.register(BudgetSiblings.self, id: 9802) + fory.register(BudgetDenseHolder.self, id: 9803) + return fory } private let testReferenceBytes = 4 private let budgetNodeGraphBytes = 1 + 4 private func elementBytes(_ type: Element.Type) -> Int { - type.isRefType ? testReferenceBytes : max(1, MemoryLayout.stride) + type.isRefType ? testReferenceBytes : max(1, MemoryLayout.stride) } private func ownerBytes(_ type: T.Type) -> Int { - max(1, MemoryLayout.stride) + max(1, MemoryLayout.stride) } private func arrayBudget(_ type: Element.Type, count: Int) -> Int { - count * elementBytes(type) + count * elementBytes(type) } private func rootArrayBudget( - _ type: Element.Type, - count: Int, - elementOwnerBytes: Int = 0 + _ type: Element.Type, + count: Int, + elementOwnerBytes: Int = 0 ) -> Int { - ownerBytes([Element].self) + arrayBudget(type, count: count) + count * elementOwnerBytes + ownerBytes([Element].self) + arrayBudget(type, count: count) + count * elementOwnerBytes } private func mapBudget( - key: Key.Type, - value: Value.Type, - count: Int + key: Key.Type, + value: Value.Type, + count: Int ) -> Int { - count * (elementBytes(key) + elementBytes(value)) + count * (elementBytes(key) + elementBytes(value)) } private func rootMapBudget( - key: Key.Type, - value: Value.Type, - count: Int + key: Key.Type, + value: Value.Type, + count: Int ) -> Int { - ownerBytes(Dictionary.self) + mapBudget(key: key, value: value, count: count) + ownerBytes(Dictionary.self) + mapBudget(key: key, value: value, count: count) } private func expectInvalidData(_ body: () throws -> Void) { - do { - try body() - Issue.record("expected invalid data") - } catch ForyError.invalidData { - } catch { - Issue.record("expected invalid data, got \(error)") - } + do { + try body() + Issue.record("expected invalid data") + } catch ForyError.invalidData { + } catch { + Issue.record("expected invalid data, got \(error)") + } } @Test func fixedDefaultBudgetAndDisable() throws { - let config = Config(trackRef: false, compatible: false) - let context = ReadContext( - buffer: ByteBuffer(), - typeResolver: TypeResolver(config: config), - config: config - ) - - try context.initGraphMemoryBudget() - try context.reserveGraphMemory(Int(defaultGraphMemoryBytes)) - expectInvalidData { - try context.reserveGraphMemory(testReferenceBytes) - } - - let disabledConfig = Config(trackRef: false, compatible: false, maxGraphMemoryBytes: 0) - let disabled = ReadContext( - buffer: ByteBuffer(), - typeResolver: TypeResolver(config: disabledConfig), - config: disabledConfig - ) - try disabled.initGraphMemoryBudget() - try disabled.reserveGraphMemory(Int(defaultGraphMemoryBytes) + 1) + let config = Config(trackRef: false, compatible: false) + let context = ReadContext( + buffer: ByteBuffer(), + typeResolver: TypeResolver(config: config), + config: config + ) + + try context.initGraphMemoryBudget() + try context.reserveGraphMemory(Int(defaultGraphMemoryBytes)) + expectInvalidData { + try context.reserveGraphMemory(testReferenceBytes) + } + + let disabledConfig = Config(trackRef: false, compatible: false, maxGraphMemoryBytes: 0) + let disabled = ReadContext( + buffer: ByteBuffer(), + typeResolver: TypeResolver(config: disabledConfig), + config: disabledConfig + ) + try disabled.initGraphMemoryBudget() + try disabled.reserveGraphMemory(Int(defaultGraphMemoryBytes) + 1) } @Test func byteBufferRootUsesFixedDefaultBudget() throws { - let count = 6 - let value = Array(repeating: [String](), count: count) - let bytes = try makeBudgetFory().serialize(value) - let buffer = ByteBuffer(data: bytes) + let count = 6 + let value = Array(repeating: [String](), count: count) + let bytes = try makeBudgetFory().serialize(value) + let buffer = ByteBuffer(data: bytes) - let decoded: [[String]] = try makeBudgetFory().deserialize(from: buffer) - #expect(decoded.count == count) + let decoded: [[String]] = try makeBudgetFory().deserialize(from: buffer) + #expect(decoded.count == count) } @Test func explicitConfigOverridesDefault() throws { - let values = (0..<16).map { "value-\($0)" } - let bytes = try makeBudgetFory().serialize(values) - let required = rootArrayBudget(String.self, count: values.count) - - expectInvalidData { - let _: [String] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)).deserialize( - bytes) - } - let decoded: [String] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)).deserialize( - bytes) - #expect(decoded == values) + let values = (0..<16).map { "value-\($0)" } + let bytes = try makeBudgetFory().serialize(values) + let required = rootArrayBudget(String.self, count: values.count) + + expectInvalidData { + let _: [String] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)).deserialize( + bytes) + } + let decoded: [String] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)).deserialize( + bytes) + #expect(decoded == values) } @Test func siblingContainersShareOneBudget() throws { - let value = BudgetSiblings( - left: (0..<16).map { BudgetNode(id: Int32($0)) }, - right: (16..<32).map { BudgetNode(id: Int32($0)) } - ) - let bytes = try makeBudgetFory().serialize(value) - let oneList = arrayBudget(BudgetNode.self, count: 16) + 16 * budgetNodeGraphBytes - let required = ownerBytes(BudgetSiblings.self) + oneList * 2 - - expectInvalidData { - let _: BudgetSiblings = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) - .deserialize(bytes) - } - let decoded: BudgetSiblings = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) - .deserialize(bytes) - #expect(decoded.left.count == 16) - #expect(decoded.right.count == 16) + let value = BudgetSiblings( + left: (0..<16).map { BudgetNode(id: Int32($0)) }, + right: (16..<32).map { BudgetNode(id: Int32($0)) } + ) + let bytes = try makeBudgetFory().serialize(value) + let oneList = arrayBudget(BudgetNode.self, count: 16) + 16 * budgetNodeGraphBytes + let required = ownerBytes(BudgetSiblings.self) + oneList * 2 + + expectInvalidData { + let _: BudgetSiblings = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: BudgetSiblings = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded.left.count == 16) + #expect(decoded.right.count == 16) } @Test func mapBudgetIsCharged() throws { - let value: [String: Int32] = ["a": 1, "b": 2, "c": 3] - let bytes = try makeBudgetFory().serialize(value) - let required = rootMapBudget(key: String.self, value: Int32.self, count: value.count) - - expectInvalidData { - let _: [String: Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) - .deserialize(bytes) - } - let decoded: [String: Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) - .deserialize(bytes) - #expect(decoded == value) + let value: [String: Int32] = ["a": 1, "b": 2, "c": 3] + let bytes = try makeBudgetFory().serialize(value) + let required = rootMapBudget(key: String.self, value: Int32.self, count: value.count) + + expectInvalidData { + let _: [String: Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: [String: Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded == value) } @Test func referenceAndInlineValueArraysAreCharged() throws { - let nodes = (0..<4).map { BudgetNode(id: Int32($0)) } - let nodeBytes = try makeBudgetFory().serialize(nodes) - let nodeBudget = rootArrayBudget( - BudgetNode.self, - count: nodes.count, - elementOwnerBytes: budgetNodeGraphBytes - ) - expectInvalidData { - let _: [BudgetNode] = try makeBudgetFory(maxGraphMemoryBytes: Int64(nodeBudget - 1)) - .deserialize(nodeBytes) - } - let decodedNodes: [BudgetNode] = try makeBudgetFory(maxGraphMemoryBytes: Int64(nodeBudget)) - .deserialize(nodeBytes) - #expect(decodedNodes.count == nodes.count) - - let ints: [Int32] = [1, 2, 3, 4] - let intBytes = try makeBudgetFory().serialize(ints) - let intBudget = rootArrayBudget(Int32.self, count: ints.count) - expectInvalidData { - let _: [Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(intBudget - 1)) - .deserialize(intBytes) - } - #expect(try makeBudgetFory(maxGraphMemoryBytes: Int64(intBudget)).deserialize(intBytes) == ints) + let nodes = (0..<4).map { BudgetNode(id: Int32($0)) } + let nodeBytes = try makeBudgetFory().serialize(nodes) + let nodeBudget = rootArrayBudget( + BudgetNode.self, + count: nodes.count, + elementOwnerBytes: budgetNodeGraphBytes + ) + expectInvalidData { + let _: [BudgetNode] = try makeBudgetFory(maxGraphMemoryBytes: Int64(nodeBudget - 1)) + .deserialize(nodeBytes) + } + let decodedNodes: [BudgetNode] = try makeBudgetFory(maxGraphMemoryBytes: Int64(nodeBudget)) + .deserialize(nodeBytes) + #expect(decodedNodes.count == nodes.count) + + let ints: [Int32] = [1, 2, 3, 4] + let intBytes = try makeBudgetFory().serialize(ints) + let intBudget = rootArrayBudget(Int32.self, count: ints.count) + expectInvalidData { + let _: [Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(intBudget - 1)) + .deserialize(intBytes) + } + #expect(try makeBudgetFory(maxGraphMemoryBytes: Int64(intBudget)).deserialize(intBytes) == ints) } @Test func setConversionOwnerChargedOnce() throws { - let values: Set = [1, 2, 3] - let bytes = try makeBudgetFory().serialize(values) - let required = ownerBytes(Set.self) + arrayBudget(Int32.self, count: values.count) - - expectInvalidData { - let _: Set = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) - .deserialize(bytes) - } - let decoded: Set = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) - .deserialize(bytes) - #expect(decoded == values) + let values: Set = [1, 2, 3] + let bytes = try makeBudgetFory().serialize(values) + let required = ownerBytes(Set.self) + arrayBudget(Int32.self, count: values.count) + + expectInvalidData { + let _: Set = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: Set = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded == values) } @Test func stringBinaryAndPrimitiveDenseArrayOwnersAreSkipped() throws { - let value = BudgetDenseHolder( - text: "budget", - data: Data([1, 2, 3]), - dense: [1, 2, 3] - ) - let bytes = try makeBudgetFory().serialize(value) - let required = ownerBytes(BudgetDenseHolder.self) - - expectInvalidData { - let _: BudgetDenseHolder = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) - .deserialize(bytes) - } - let decoded: BudgetDenseHolder = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) - .deserialize(bytes) - #expect(decoded == value) + let value = BudgetDenseHolder( + text: "budget", + data: Data([1, 2, 3]), + dense: [1, 2, 3] + ) + let bytes = try makeBudgetFory().serialize(value) + let required = ownerBytes(BudgetDenseHolder.self) + + expectInvalidData { + let _: BudgetDenseHolder = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: BudgetDenseHolder = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded == value) } @Test func dynamicAnyEmptyMapOwnerSelf() throws { - let value = [:] as [AnyHashable: Any] - let bytes = try makeBudgetFory().serialize(value as Any) - let required = - ownerBytes(Dictionary.self) - + ownerBytes(Dictionary.self) - - expectInvalidData { - let _: Any = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) - .deserialize(bytes) - } - let decoded: Any = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) - .deserialize(bytes) - #expect((decoded as? [String: Any])?.isEmpty == true) + let value = [:] as [AnyHashable: Any] + let bytes = try makeBudgetFory().serialize(value as Any) + let required = + ownerBytes(Dictionary.self) + + ownerBytes(Dictionary.self) + + expectInvalidData { + let _: Any = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: Any = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect((decoded as? [String: Any])?.isEmpty == true) } @Test func publicAnyArrayBudget() throws { - let value: [Any] = [Int32(1), Int32(2), Int32(3)] - let bytes = try makeBudgetFory().serialize(value) - let wrappedBudget = arrayBudget(SerializableAny.self, count: value.count) - let finalBudget = ownerBytes([Any].self) + value.count * testReferenceBytes - - expectInvalidData { - let _: [Any] = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget)) - .deserialize(bytes, as: [Any].self) - } - let decoded = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget + finalBudget)) - .deserialize(bytes, as: [Any].self) - #expect(decoded.count == value.count) + let value: [Any] = [Int32(1), Int32(2), Int32(3)] + let bytes = try makeBudgetFory().serialize(value) + let wrappedBudget = arrayBudget(SerializableAny.self, count: value.count) + let finalBudget = ownerBytes([Any].self) + value.count * testReferenceBytes + + expectInvalidData { + let _: [Any] = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget)) + .deserialize(bytes, as: [Any].self) + } + let decoded = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget + finalBudget)) + .deserialize(bytes, as: [Any].self) + #expect(decoded.count == value.count) } @Test func publicAnyMapBudget() throws { - let stringMap: [String: Any] = ["a": Int32(1), "b": Int32(2), "c": Int32(3)] - let stringBytes = try makeBudgetFory().serialize(stringMap) - let stringWrapped = mapBudget( - key: String.self, - value: SerializableAny.self, - count: stringMap.count - ) - let stringFinal = - ownerBytes(Dictionary.self) + stringMap.count * 2 * testReferenceBytes - expectInvalidData { - let _: [String: Any] = try makeBudgetFory(maxGraphMemoryBytes: Int64(stringWrapped)) - .deserialize(stringBytes, as: [String: Any].self) - } - let decodedString = try makeBudgetFory(maxGraphMemoryBytes: Int64(stringWrapped + stringFinal)) - .deserialize(stringBytes, as: [String: Any].self) - #expect(decodedString.count == stringMap.count) - - let intMap: [Int32: Any] = [1: Int32(10), 2: Int32(20), 3: Int32(30)] - let intBytes = try makeBudgetFory().serialize(intMap) - let intWrapped = mapBudget( - key: Int32.self, - value: SerializableAny.self, - count: intMap.count - ) - let intFinal = ownerBytes(Dictionary.self) + intMap.count * 2 * testReferenceBytes - expectInvalidData { - let _: [Int32: Any] = try makeBudgetFory(maxGraphMemoryBytes: Int64(intWrapped)) - .deserialize(intBytes, as: [Int32: Any].self) - } - let decodedInt = try makeBudgetFory(maxGraphMemoryBytes: Int64(intWrapped + intFinal)) - .deserialize(intBytes, as: [Int32: Any].self) - #expect(decodedInt.count == intMap.count) - - let anyHashableMap: [AnyHashable: Any] = [ - AnyHashable("a"): Int32(1), - AnyHashable(Int32(2)): Int32(2), - AnyHashable(true): Int32(3), - ] - let anyHashableBytes = try makeBudgetFory().serialize(anyHashableMap) - let anyHashableWrapped = mapBudget( - key: AnyHashable.self, - value: SerializableAny.self, - count: anyHashableMap.count - ) - let anyHashableFinal = - ownerBytes(Dictionary.self) + anyHashableMap.count * 2 * testReferenceBytes - expectInvalidData { - let _: [AnyHashable: Any] = try makeBudgetFory( - maxGraphMemoryBytes: Int64(anyHashableWrapped) + let stringMap: [String: Any] = ["a": Int32(1), "b": Int32(2), "c": Int32(3)] + let stringBytes = try makeBudgetFory().serialize(stringMap) + let stringWrapped = mapBudget( + key: String.self, + value: SerializableAny.self, + count: stringMap.count + ) + let stringFinal = + ownerBytes(Dictionary.self) + stringMap.count * 2 * testReferenceBytes + expectInvalidData { + let _: [String: Any] = try makeBudgetFory(maxGraphMemoryBytes: Int64(stringWrapped)) + .deserialize(stringBytes, as: [String: Any].self) + } + let decodedString = try makeBudgetFory(maxGraphMemoryBytes: Int64(stringWrapped + stringFinal)) + .deserialize(stringBytes, as: [String: Any].self) + #expect(decodedString.count == stringMap.count) + + let intMap: [Int32: Any] = [1: Int32(10), 2: Int32(20), 3: Int32(30)] + let intBytes = try makeBudgetFory().serialize(intMap) + let intWrapped = mapBudget( + key: Int32.self, + value: SerializableAny.self, + count: intMap.count + ) + let intFinal = ownerBytes(Dictionary.self) + intMap.count * 2 * testReferenceBytes + expectInvalidData { + let _: [Int32: Any] = try makeBudgetFory(maxGraphMemoryBytes: Int64(intWrapped)) + .deserialize(intBytes, as: [Int32: Any].self) + } + let decodedInt = try makeBudgetFory(maxGraphMemoryBytes: Int64(intWrapped + intFinal)) + .deserialize(intBytes, as: [Int32: Any].self) + #expect(decodedInt.count == intMap.count) + + let anyHashableMap: [AnyHashable: Any] = [ + AnyHashable("a"): Int32(1), + AnyHashable(Int32(2)): Int32(2), + AnyHashable(true): Int32(3) + ] + let anyHashableBytes = try makeBudgetFory().serialize(anyHashableMap) + let anyHashableWrapped = mapBudget( + key: AnyHashable.self, + value: SerializableAny.self, + count: anyHashableMap.count + ) + let anyHashableFinal = + ownerBytes(Dictionary.self) + anyHashableMap.count * 2 * testReferenceBytes + expectInvalidData { + let _: [AnyHashable: Any] = try makeBudgetFory( + maxGraphMemoryBytes: Int64(anyHashableWrapped) + ).deserialize(anyHashableBytes, as: [AnyHashable: Any].self) + } + let decodedAnyHashable = try makeBudgetFory( + maxGraphMemoryBytes: Int64(anyHashableWrapped + anyHashableFinal) ).deserialize(anyHashableBytes, as: [AnyHashable: Any].self) - } - let decodedAnyHashable = try makeBudgetFory( - maxGraphMemoryBytes: Int64(anyHashableWrapped + anyHashableFinal) - ).deserialize(anyHashableBytes, as: [AnyHashable: Any].self) - #expect(decodedAnyHashable.count == anyHashableMap.count) + #expect(decodedAnyHashable.count == anyHashableMap.count) } @Test func dynamicAnyArrayBudget() throws { - let list: [Any] = [Int32(1), "two", Int32(3)] - let value: Any = list - let bytes = try makeBudgetFory().serialize(value) - let count = list.count - let wrappedBudget = arrayBudget(SerializableAny.self, count: count) - let finalBudget = ownerBytes([Any].self) + count * testReferenceBytes - - expectInvalidData { - let _: Any = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget)) - .deserialize(bytes, as: Any.self) - } - let decoded = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget + finalBudget)) - .deserialize(bytes, as: Any.self) - #expect((decoded as? [Any])?.count == count) + let list: [Any] = [Int32(1), "two", Int32(3)] + let value: Any = list + let bytes = try makeBudgetFory().serialize(value) + let count = list.count + let wrappedBudget = arrayBudget(SerializableAny.self, count: count) + let finalBudget = ownerBytes([Any].self) + count * testReferenceBytes + + expectInvalidData { + let _: Any = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget)) + .deserialize(bytes, as: Any.self) + } + let decoded = try makeBudgetFory(maxGraphMemoryBytes: Int64(wrappedBudget + finalBudget)) + .deserialize(bytes, as: Any.self) + #expect((decoded as? [Any])?.count == count) } @Test func byteAvailabilityCheckStillRejectsLargeLength() throws { - let buffer = ByteBuffer() - buffer.writeVarUInt32(64) - buffer.writeUInt8(CollectionHeader.sameType | CollectionHeader.declaredElementType) - let config = Config(trackRef: false, compatible: false) - let context = ReadContext( - buffer: buffer, - typeResolver: TypeResolver(config: config), - config: config - ) - - expectInvalidData { - let _: [String] = try [String].foryReadData(context) - } + let buffer = ByteBuffer() + buffer.writeVarUInt32(64) + buffer.writeUInt8(CollectionHeader.sameType | CollectionHeader.declaredElementType) + let config = Config(trackRef: false, compatible: false) + let context = ReadContext( + buffer: buffer, + typeResolver: TypeResolver(config: config), + config: config + ) + + expectInvalidData { + let _: [String] = try [String].foryReadData(context) + } } From 51bf5f91c8f52a49621ebde1f2963c8d9783c8b4 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Thu, 2 Jul 2026 23:59:10 +0800 Subject: [PATCH 28/54] style(js): remove unrelated operator formatting diff --- javascript/packages/core/lib/type.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/javascript/packages/core/lib/type.ts b/javascript/packages/core/lib/type.ts index 1a58b2a3e8..eb65815987 100644 --- a/javascript/packages/core/lib/type.ts +++ b/javascript/packages/core/lib/type.ts @@ -174,12 +174,12 @@ export const TypeId = { }, userDefinedType(id: number) { return ( - this.structType(id) - || this.extType(id) - || this.enumType(id) - || id == TypeId.UNION - || id == TypeId.TYPED_UNION - || id == TypeId.NAMED_UNION + this.structType(id) || + this.extType(id) || + this.enumType(id) || + id == TypeId.UNION || + id == TypeId.TYPED_UNION || + id == TypeId.NAMED_UNION ); }, isBuiltin(id: number) { From e765fad754a0c6234d801a73f8501d0df1c74ed9 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 00:08:39 +0800 Subject: [PATCH 29/54] style(js): apply javascript formatter --- javascript/package-lock.json | 277 ++---------- javascript/package.json | 11 +- javascript/packages/core/lib/context.ts | 394 ++++++------------ javascript/packages/core/lib/fory.ts | 67 +-- .../packages/core/lib/gen/collection.ts | 98 ++--- javascript/packages/core/lib/gen/ext.ts | 36 +- javascript/packages/core/lib/gen/map.ts | 140 +++---- javascript/packages/core/lib/gen/struct.ts | 288 ++++--------- javascript/test/graphMemoryBudget.test.ts | 58 +-- 9 files changed, 389 insertions(+), 980 deletions(-) diff --git a/javascript/package-lock.json b/javascript/package-lock.json index c095ab19a2..1674116f60 100644 --- a/javascript/package-lock.json +++ b/javascript/package-lock.json @@ -9,13 +9,14 @@ "packages/core" ], "devDependencies": { - "@stylistic/eslint-plugin": "^1.5.1", "@types/js-beautify": "^1.14.3", - "@types/node": "^18.19.68", + "@types/node": "18.19.130", "eslint": "^8.55.0", + "eslint-config-prettier": "^10.1.8", "jest": "^29.5.0", "jest-junit": "^17.0.0", "js-beautify": "^1.14.11", + "prettier": "^3.9.4", "ts-jest": "^29.0.2", "typescript": "^4.8.4" } @@ -1449,97 +1450,6 @@ "@sinonjs/commons": "^3.0.0" } }, - "node_modules/@stylistic/eslint-plugin": { - "version": "1.8.1", - "resolved": "https://registry.npmjs.org/@stylistic/eslint-plugin/-/eslint-plugin-1.8.1.tgz", - "integrity": "sha512-64My6I7uCcmSQ//427Pfg2vjSf9SDzfsGIWohNFgISMLYdC5BzJqDo647iDDJzSxINh3WTC0Ql46ifiKuOoTyA==", - "dev": true, - "license": "MIT", - "dependencies": { - "@stylistic/eslint-plugin-js": "1.8.1", - "@stylistic/eslint-plugin-jsx": "1.8.1", - "@stylistic/eslint-plugin-plus": "1.8.1", - "@stylistic/eslint-plugin-ts": "1.8.1", - "@types/eslint": "^8.56.10" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "peerDependencies": { - "eslint": ">=8.40.0" - } - }, - "node_modules/@stylistic/eslint-plugin-js": { - "version": "1.8.1", - "resolved": "https://registry.npmjs.org/@stylistic/eslint-plugin-js/-/eslint-plugin-js-1.8.1.tgz", - "integrity": "sha512-c5c2C8Mos5tTQd+NWpqwEu7VT6SSRooAguFPMj1cp2RkTYl1ynKoXo8MWy3k4rkbzoeYHrqC2UlUzsroAN7wtQ==", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/eslint": "^8.56.10", - "acorn": "^8.11.3", - "escape-string-regexp": "^4.0.0", - "eslint-visitor-keys": "^3.4.3", - "espree": "^9.6.1" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "peerDependencies": { - "eslint": ">=8.40.0" - } - }, - "node_modules/@stylistic/eslint-plugin-jsx": { - "version": "1.8.1", - "resolved": "https://registry.npmjs.org/@stylistic/eslint-plugin-jsx/-/eslint-plugin-jsx-1.8.1.tgz", - "integrity": "sha512-k1Eb6rcjMP+mmjvj+vd9y5KUdWn1OBkkPLHXhsrHt5lCDFZxJEs0aVQzE5lpYrtVZVkpc5esTtss/cPJux0lfA==", - "dev": true, - "license": "MIT", - "dependencies": { - "@stylistic/eslint-plugin-js": "^1.8.1", - "@types/eslint": "^8.56.10", - "estraverse": "^5.3.0", - "picomatch": "^4.0.2" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "peerDependencies": { - "eslint": ">=8.40.0" - } - }, - "node_modules/@stylistic/eslint-plugin-plus": { - "version": "1.8.1", - "resolved": "https://registry.npmjs.org/@stylistic/eslint-plugin-plus/-/eslint-plugin-plus-1.8.1.tgz", - "integrity": "sha512-4+40H3lHYTN8OWz+US8CamVkO+2hxNLp9+CAjorI7top/lHqemhpJvKA1LD9Uh+WMY9DYWiWpL2+SZ2wAXY9fQ==", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/eslint": "^8.56.10", - "@typescript-eslint/utils": "^6.21.0" - }, - "peerDependencies": { - "eslint": "*" - } - }, - "node_modules/@stylistic/eslint-plugin-ts": { - "version": "1.8.1", - "resolved": "https://registry.npmjs.org/@stylistic/eslint-plugin-ts/-/eslint-plugin-ts-1.8.1.tgz", - "integrity": "sha512-/q1m+ZuO1JHfiSF16EATFzv7XSJkc5W6DocfvH5o9oB6WWYFMF77fVoBWnKT3wGptPOc2hkRupRKhmeFROdfWA==", - "dev": true, - "license": "MIT", - "dependencies": { - "@stylistic/eslint-plugin-js": "1.8.1", - "@types/eslint": "^8.56.10", - "@typescript-eslint/utils": "^6.21.0" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "peerDependencies": { - "eslint": ">=8.40.0" - } - }, "node_modules/@types/babel__core": { "version": "7.20.5", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", @@ -1585,24 +1495,6 @@ "@babel/types": "^7.28.2" } }, - "node_modules/@types/eslint": { - "version": "8.56.12", - "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.56.12.tgz", - "integrity": "sha512-03ruubjWyOHlmljCVoxSuNDdmfZDzsrrz0P2LeJsOXr+ZwFQ+0yQIwNCwt/GYhV7Z31fgtXJTAEs+FYlEL851g==", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/estree": "*", - "@types/json-schema": "*" - } - }, - "node_modules/@types/estree": { - "version": "1.0.9", - "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.9.tgz", - "integrity": "sha512-GhdPgy1el4/ImP05X05Uw4cw2/M93BCUmnEvWZNStlCzEKME4Fkk+YpoA5OiHNQmoS7Cafb8Xa3Pya8m1Qrzeg==", - "dev": true, - "license": "MIT" - }, "node_modules/@types/graceful-fs": { "version": "4.1.9", "resolved": "https://registry.npmjs.org/@types/graceful-fs/-/graceful-fs-4.1.9.tgz", @@ -1655,9 +1547,9 @@ "license": "MIT" }, "node_modules/@types/node": { - "version": "18.19.68", - "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.68.tgz", - "integrity": "sha512-QGtpFH1vB99ZmTa63K4/FU8twThj4fuVSBkGddTp7uIL/cuoLWIUSL2RcOaigBhfR+hg5pgGkBnkoOxrTVBMKw==", + "version": "18.19.130", + "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.130.tgz", + "integrity": "sha512-GRaXQx6jGfL8sKfaIDD6OupbIHBr9jv7Jnaml9tB7l4v068PAOXqfcujMMo5PhbIs6ggR1XODELqahT2R8v0fg==", "dev": true, "license": "MIT", "dependencies": { @@ -1965,24 +1857,6 @@ "url": "https://opencollective.com/typescript-eslint" } }, - "node_modules/@typescript-eslint/scope-manager": { - "version": "6.21.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-6.21.0.tgz", - "integrity": "sha512-OwLUIWZJry80O99zvqXVEioyniJMa+d2GrqpUTqi5/v5D5rOrppJVBPa0yKCblcigC0/aYAzxxqQ1B+DS2RYsg==", - "dev": true, - "license": "MIT", - "dependencies": { - "@typescript-eslint/types": "6.21.0", - "@typescript-eslint/visitor-keys": "6.21.0" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/typescript-eslint" - } - }, "node_modules/@typescript-eslint/type-utils": { "version": "5.62.0", "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-5.62.0.tgz", @@ -2140,93 +2014,6 @@ "node": ">=4.0" } }, - "node_modules/@typescript-eslint/types": { - "version": "6.21.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-6.21.0.tgz", - "integrity": "sha512-1kFmZ1rOm5epu9NZEZm1kckCDGj5UJEf7P1kliH4LKu/RkwpsfqqGmY2OOcUs18lSlQBKLDYBOGxRVtrMN5lpg==", - "dev": true, - "license": "MIT", - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/typescript-eslint" - } - }, - "node_modules/@typescript-eslint/typescript-estree": { - "version": "6.21.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-6.21.0.tgz", - "integrity": "sha512-6npJTkZcO+y2/kr+z0hc4HwNfrrP4kNYh57ek7yCNlrBjWQ1Y0OS7jiZTkgumrvkX5HkEKXFZkkdFNkaW2wmUQ==", - "dev": true, - "license": "BSD-2-Clause", - "dependencies": { - "@typescript-eslint/types": "6.21.0", - "@typescript-eslint/visitor-keys": "6.21.0", - "debug": "^4.3.4", - "globby": "^11.1.0", - "is-glob": "^4.0.3", - "minimatch": "9.0.3", - "semver": "^7.5.4", - "ts-api-utils": "^1.0.1" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/typescript-eslint" - }, - "peerDependenciesMeta": { - "typescript": { - "optional": true - } - } - }, - "node_modules/@typescript-eslint/utils": { - "version": "6.21.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-6.21.0.tgz", - "integrity": "sha512-NfWVaC8HP9T8cbKQxHcsJBY5YE1O33+jpMwN45qzWWaPDZgLIbo12toGMWnmhvCpd3sIxkpDw3Wv1B3dYrbDQQ==", - "dev": true, - "license": "MIT", - "dependencies": { - "@eslint-community/eslint-utils": "^4.4.0", - "@types/json-schema": "^7.0.12", - "@types/semver": "^7.5.0", - "@typescript-eslint/scope-manager": "6.21.0", - "@typescript-eslint/types": "6.21.0", - "@typescript-eslint/typescript-estree": "6.21.0", - "semver": "^7.5.4" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/typescript-eslint" - }, - "peerDependencies": { - "eslint": "^7.0.0 || ^8.0.0" - } - }, - "node_modules/@typescript-eslint/visitor-keys": { - "version": "6.21.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-6.21.0.tgz", - "integrity": "sha512-JJtkDduxLi9bivAB+cYOVMtbkqdPOhZ+ZI5LC47MIRrDV4Yn2o+ZnW10Nkmr28xRpSpdJ6Sm42Hjf2+REYXm0A==", - "dev": true, - "license": "MIT", - "dependencies": { - "@typescript-eslint/types": "6.21.0", - "eslint-visitor-keys": "^3.4.1" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/typescript-eslint" - } - }, "node_modules/@ungap/structured-clone": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/@ungap/structured-clone/-/structured-clone-1.3.1.tgz", @@ -3181,6 +2968,22 @@ "url": "https://opencollective.com/eslint" } }, + "node_modules/eslint-config-prettier": { + "version": "10.1.8", + "resolved": "https://registry.npmjs.org/eslint-config-prettier/-/eslint-config-prettier-10.1.8.tgz", + "integrity": "sha512-82GZUjRS0p/jganf6q1rEO25VSoHH0hKPCTrgillPjdI/3bgBhAE1QzHrHTizjpRvy6pGAvKjDJtk2pF9NDq8w==", + "dev": true, + "license": "MIT", + "bin": { + "eslint-config-prettier": "bin/cli.js" + }, + "funding": { + "url": "https://opencollective.com/eslint-config-prettier" + }, + "peerDependencies": { + "eslint": ">=7.0.0" + } + }, "node_modules/eslint-scope": { "version": "7.2.2", "resolved": "https://registry.npmjs.org/eslint-scope/-/eslint-scope-7.2.2.tgz", @@ -5788,6 +5591,22 @@ "node": ">= 0.8.0" } }, + "node_modules/prettier": { + "version": "3.9.4", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.9.4.tgz", + "integrity": "sha512-yWG/o/4oJfo036EKAfK6ACAoDOfHeRHx4tuxkfBZiauURiaSmYwlpOr5LQqKtIkRD2z1PLteme2WoxEnj4tHTg==", + "dev": true, + "license": "MIT", + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, "node_modules/pretty-format": { "version": "29.7.0", "resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-29.7.0.tgz", @@ -6474,19 +6293,6 @@ "node": ">=8.0" } }, - "node_modules/ts-api-utils": { - "version": "1.4.3", - "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-1.4.3.tgz", - "integrity": "sha512-i3eMG77UTMD0hZhgRS562pv83RC6ukSAC2GMNWc+9dieh/+jDM5u5YG+NHX6VNDRHQcHwmsTHctP9LhbC3WxVw==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=16" - }, - "peerDependencies": { - "typescript": ">=4.2.0" - } - }, "node_modules/ts-jest": { "version": "29.4.11", "resolved": "https://registry.npmjs.org/ts-jest/-/ts-jest-29.4.11.tgz", @@ -6951,6 +6757,13 @@ "undici-types": "~5.26.4" } }, + "packages/core/node_modules/undici-types": { + "version": "5.26.5", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", + "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==", + "dev": true, + "license": "MIT" + }, "packages/hps": { "name": "@apache-fory/hps", "version": "1.4.0-alpha.0", diff --git a/javascript/package.json b/javascript/package.json index a856cd827a..bf4ae68e47 100644 --- a/javascript/package.json +++ b/javascript/package.json @@ -4,8 +4,10 @@ "test": "npm run build && jest", "clear": "rm -rf ./packages/core/dist && rm -rf ./packages/hps/dist", "build": "npm run clear && npm run build -w packages/core -w packages/hps", - "lint": "eslint .", - "lint-fix": "eslint . --fix" + "lint": "npm run format-check", + "lint-fix": "npm run format", + "format": "prettier --write \"{packages,test}/**/*.ts\" && eslint . --fix", + "format-check": "prettier --check \"{packages,test}/**/*.ts\" && eslint ." }, "repository": "git@github.com:apache/fory.git", "workspaces": [ @@ -13,13 +15,14 @@ "packages/core" ], "devDependencies": { - "@stylistic/eslint-plugin": "^1.5.1", "@types/js-beautify": "^1.14.3", - "@types/node": "^18.19.68", + "@types/node": "18.19.130", "eslint": "^8.55.0", + "eslint-config-prettier": "^10.1.8", "jest": "^29.5.0", "jest-junit": "^17.0.0", "js-beautify": "^1.14.11", + "prettier": "^3.9.4", "ts-jest": "^29.0.2", "typescript": "^4.8.4" }, diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index 69c21f6a5e..60f51de71e 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -19,11 +19,7 @@ import { BinaryReader } from "./reader"; import { BinaryWriter } from "./writer"; -import { - MetaString, - MetaStringDecoder, - MetaStringEncoder, -} from "./meta/MetaString"; +import { MetaString, MetaStringDecoder, MetaStringEncoder } from "./meta/MetaString"; import { InnerFieldInfo, TypeMeta } from "./meta/TypeMeta"; import { Type, TypeInfo } from "./typeInfo"; import { Config, RefFlags, Serializer, TypeId } from "./type"; @@ -52,9 +48,7 @@ type CompatibleReadSerializerCacheEntry = { serializer: Serializer; }; -function remoteListElementType( - fieldInfo: InnerFieldInfo, -): InnerFieldInfo | undefined { +function remoteListElementType(fieldInfo: InnerFieldInfo): InnerFieldInfo | undefined { if (fieldInfo.typeId !== TypeId.LIST) { return undefined; } @@ -541,18 +535,14 @@ export class ReadContext { private typeMetaCache: Map = new Map(); private totalAcceptedSchemaVersions = 0; private cachedTypeMeta: TypeMeta | undefined; - private compatibleReadSerializers = new Map< - number, - CompatibleReadSerializerCacheEntry - >(); + private compatibleReadSerializers = new Map(); private _depth = 0; private _maxDepth: number; private readonly maxGraphMemoryBytes: number; private effectiveGraphMemoryBytes = 0; private remainingGraphMemoryBytes = 0; - private remoteSchemaVersionsByType: Map | undefined - = undefined; + private remoteSchemaVersionsByType: Map | undefined = undefined; constructor( readonly typeResolver: TypeResolverLike, @@ -571,8 +561,7 @@ export class ReadContext { this.metaStringReader.reset(); this.typeMeta = []; this._depth = 0; - this.effectiveGraphMemoryBytes - = this.maxGraphMemoryBytes > 0 ? this.maxGraphMemoryBytes : 0; + this.effectiveGraphMemoryBytes = this.maxGraphMemoryBytes > 0 ? this.maxGraphMemoryBytes : 0; this.remainingGraphMemoryBytes = this.effectiveGraphMemoryBytes; } @@ -591,16 +580,14 @@ export class ReadContext { } private throwGraphMemoryOverflow(bytes: number): never { - throw new Error( - `maxGraphMemoryBytes overflow: requested ${bytes} estimated graph bytes`, - ); + throw new Error(`maxGraphMemoryBytes overflow: requested ${bytes} estimated graph bytes`); } private throwGraphBudgetExceeded(bytes: number): never { throw new Error( - `maxGraphMemoryBytes exceeded: requested ${bytes} estimated graph bytes, ` - + `${this.remainingGraphMemoryBytes} remaining, effective limit ` - + `${this.effectiveGraphMemoryBytes}`, + `maxGraphMemoryBytes exceeded: requested ${bytes} estimated graph bytes, ` + + `${this.remainingGraphMemoryBytes} remaining, effective limit ` + + `${this.effectiveGraphMemoryBytes}`, ); } @@ -612,8 +599,8 @@ export class ReadContext { this._depth++; if (this._depth > this._maxDepth) { throw new Error( - `Deserialization depth limit exceeded: ${this._depth} > ${this._maxDepth}. ` - + "The data may be malicious, or increase maxDepth if needed.", + `Deserialization depth limit exceeded: ${this._depth} > ${this._maxDepth}. ` + + "The data may be malicious, or increase maxDepth if needed.", ); } } @@ -673,12 +660,7 @@ export class ReadContext { const idOrLen = this.reader.readVarUInt32(); if (idOrLen & 1) { const typeMeta = this.readTypeMetaRef(idOrLen); - this.checkNamedTypeMeta( - typeMeta, - expectedTypeId, - expectedNamespace, - expectedTypeName, - ); + this.checkNamedTypeMeta(typeMeta, expectedTypeId, expectedNamespace, expectedTypeName); return typeMeta; } const dynamicTypeId = idOrLen >> 1; @@ -711,21 +693,14 @@ export class ReadContext { this.typeResolver.config.maxTypeMetaBytes, ); const typeMetaEnd = this.reader.readGetCursor(); - this.checkNamedTypeMeta( - typeMeta, - expectedTypeId, - expectedNamespace, - expectedTypeName, - ); + this.checkNamedTypeMeta(typeMeta, expectedTypeId, expectedNamespace, expectedTypeName); const localSerializer = this.serializerByTypeMeta(typeMeta); if (localSerializer === undefined) { throw new Error( `can't find serializer for TypeMeta ${typeMeta.getNs()}$${typeMeta.getTypeName()}`, ); } - if ( - this.matchesExactLocalTypeMeta(typeMeta, typeMetaStart, typeMetaEnd) - ) { + if (this.matchesExactLocalTypeMeta(typeMeta, typeMetaStart, typeMetaEnd)) { this.cacheTypeMeta(headerHash, typeMeta, undefined); } else { const typeKey = this.checkRemoteTypeMetaLimit(typeMeta); @@ -735,20 +710,12 @@ export class ReadContext { return typeMeta; } } - this.checkNamedTypeMeta( - typeMeta, - expectedTypeId, - expectedNamespace, - expectedTypeName, - ); + this.checkNamedTypeMeta(typeMeta, expectedTypeId, expectedNamespace, expectedTypeName); this.typeMeta[dynamicTypeId] = typeMeta; return typeMeta; } - readCompatibleStructSerializer( - localHash: number, - original?: Serializer, - ): Serializer | undefined { + readCompatibleStructSerializer(localHash: number, original?: Serializer): Serializer | undefined { const idOrLen = this.reader.readVarUInt32(); let typeMeta: TypeMeta; let remoteHash: number; @@ -774,12 +741,7 @@ export class ReadContext { remoteHash = headerHash; } if (localHash !== remoteHash) { - return this.ensureCompatibleReadSerializer( - typeMeta, - localHash, - remoteHash, - original, - ); + return this.ensureCompatibleReadSerializer(typeMeta, localHash, remoteHash, original); } return undefined; } @@ -800,14 +762,14 @@ export class ReadContext { expectedTypeName: string, ) { if ( - typeMeta.getTypeId() !== expectedTypeId - || typeMeta.getNs() !== expectedNamespace - || typeMeta.getTypeName() !== expectedTypeName + typeMeta.getTypeId() !== expectedTypeId || + typeMeta.getNs() !== expectedNamespace || + typeMeta.getTypeName() !== expectedTypeName ) { throw new Error( - `TypeMeta mismatch: expected ${expectedNamespace}$${expectedTypeName} ` - + `type ${expectedTypeId}, got ${typeMeta.getNs()}$${typeMeta.getTypeName()} ` - + `type ${typeMeta.getTypeId()}`, + `TypeMeta mismatch: expected ${expectedNamespace}$${expectedTypeName} ` + + `type ${expectedTypeId}, got ${typeMeta.getNs()}$${typeMeta.getTypeName()} ` + + `type ${typeMeta.getTypeId()}`, ); } } @@ -848,25 +810,17 @@ export class ReadContext { this.typeResolver.config.maxTypeMetaBytes, ); const typeMetaEnd = this.reader.readGetCursor(); - if ( - this.matchesExactLocalTypeMeta(typeMeta, typeMetaStart, typeMetaEnd) - ) { + if (this.matchesExactLocalTypeMeta(typeMeta, typeMetaStart, typeMetaEnd)) { this.cacheTypeMeta(headerHash, typeMeta, undefined); } else { const localSerializer = original ?? this.serializerByTypeMeta(typeMeta); - if ( - localSerializer === undefined - && !TypeId.structType(typeMeta.getTypeId()) - ) { + if (localSerializer === undefined && !TypeId.structType(typeMeta.getTypeId())) { throw new Error( `can't find serializer for TypeMeta ${typeMeta.getNs()}$${typeMeta.getTypeName()}`, ); } const typeKey = this.checkRemoteTypeMetaLimit(typeMeta); - if ( - localSerializer !== undefined - && TypeId.structType(typeMeta.getTypeId()) - ) { + if (localSerializer !== undefined && TypeId.structType(typeMeta.getTypeId())) { const expectedHash = localHash ?? localSerializer.getHash(); if (expectedHash !== typeMeta.getHash()) { this.ensureCompatibleReadSerializer( @@ -876,16 +830,8 @@ export class ReadContext { localSerializer, ); } - } else if ( - localHash !== undefined - && localHash !== typeMeta.getHash() - ) { - this.ensureCompatibleReadSerializer( - typeMeta, - localHash, - typeMeta.getHash(), - original, - ); + } else if (localHash !== undefined && localHash !== typeMeta.getHash()) { + this.ensureCompatibleReadSerializer(typeMeta, localHash, typeMeta.getHash(), original); } this.cacheTypeMeta(headerHash, typeMeta, typeKey); } @@ -918,33 +864,30 @@ export class ReadContext { : typeMeta.getUserTypeId(); const versionsByType = this.remoteSchemaVersionsByType; const versionsForType = versionsByType?.get(typeKey) ?? 0; - const maxSchemaVersionsPerType - = this.typeResolver.config.maxSchemaVersionsPerType; + const maxSchemaVersionsPerType = this.typeResolver.config.maxSchemaVersionsPerType; if (versionsForType >= maxSchemaVersionsPerType) { throw new Error( - `Remote schema version limit exceeded for type ${String(typeKey)}: ` - + `${versionsForType} >= ${maxSchemaVersionsPerType}. The data may ` - + "be malicious. If the data is not malicious, please increase " - + "maxSchemaVersionsPerType.", + `Remote schema version limit exceeded for type ${String(typeKey)}: ` + + `${versionsForType} >= ${maxSchemaVersionsPerType}. The data may ` + + "be malicious. If the data is not malicious, please increase " + + "maxSchemaVersionsPerType.", ); } - const acceptedTypeCount - = versionsForType === 0 - ? (versionsByType?.size ?? 0) + 1 - : versionsByType!.size; - const maxAverageSchemaVersionsPerType - = this.typeResolver.config.maxAverageSchemaVersionsPerType; + const acceptedTypeCount = + versionsForType === 0 ? (versionsByType?.size ?? 0) + 1 : versionsByType!.size; + const maxAverageSchemaVersionsPerType = + this.typeResolver.config.maxAverageSchemaVersionsPerType; const globalLimit = Math.max( ReadContext.MIN_REMOTE_TYPE_META_LIMIT, acceptedTypeCount * maxAverageSchemaVersionsPerType, ); if (this.totalAcceptedSchemaVersions >= globalLimit) { throw new Error( - `Remote schema version limit exceeded: ${this.totalAcceptedSchemaVersions} ` - + `metadata versions for ${acceptedTypeCount} accepted remote types ` - + `exceeds the average limit ${maxAverageSchemaVersionsPerType}. ` - + "The data may be malicious. If the data is not malicious, please " - + "increase maxAverageSchemaVersionsPerType.", + `Remote schema version limit exceeded: ${this.totalAcceptedSchemaVersions} ` + + `metadata versions for ${acceptedTypeCount} accepted remote types ` + + `exceeds the average limit ${maxAverageSchemaVersionsPerType}. ` + + "The data may be malicious. If the data is not malicious, please " + + "increase maxAverageSchemaVersionsPerType.", ); } return typeKey; @@ -972,24 +915,15 @@ export class ReadContext { private serializerByTypeMeta(typeMeta: TypeMeta) { const typeId = typeMeta.getTypeId(); if (TypeId.isNamedType(typeId)) { - return this.typeResolver.getSerializerByName( - `${typeMeta.getNs()}$${typeMeta.getTypeName()}`, - ); + return this.typeResolver.getSerializerByName(`${typeMeta.getNs()}$${typeMeta.getTypeName()}`); } if (TypeId.needsUserTypeId(typeId)) { - return this.typeResolver.getSerializerById( - typeId, - typeMeta.getUserTypeId(), - ); + return this.typeResolver.getSerializerById(typeId, typeMeta.getUserTypeId()); } return this.typeResolver.getSerializerById(typeId); } - private matchesExactLocalTypeMeta( - remoteTypeMeta: TypeMeta, - start: number, - end: number, - ): boolean { + private matchesExactLocalTypeMeta(remoteTypeMeta: TypeMeta, start: number, end: number): boolean { const serializer = this.serializerByTypeMeta(remoteTypeMeta); const localBytes = serializer?.getTypeMetaBytes?.(); if (localBytes === undefined) { @@ -1037,28 +971,23 @@ export class ReadContext { if (remote === undefined || local === undefined) { return false; } - if ( - this.canonicalTypeId(remote.typeId) !== this.canonicalFieldTypeId(local) - ) { + if (this.canonicalTypeId(remote.typeId) !== this.canonicalFieldTypeId(local)) { return false; } if ( - (remote.trackingRef === true) !== (local.trackingRef === true) - || (remote.nullable === true) !== (local.nullable === true) + (remote.trackingRef === true) !== (local.trackingRef === true) || + (remote.nullable === true) !== (local.nullable === true) ) { return false; } switch (remote.typeId) { case TypeId.MAP: return ( - this.fieldSchemasEqual(remote.options?.key, local.options?.key) - && this.fieldSchemasEqual(remote.options?.value, local.options?.value) + this.fieldSchemasEqual(remote.options?.key, local.options?.key) && + this.fieldSchemasEqual(remote.options?.value, local.options?.value) ); case TypeId.LIST: - return this.fieldSchemasEqual( - remote.options?.inner, - local.options?.inner, - ); + return this.fieldSchemasEqual(remote.options?.inner, local.options?.inner); case TypeId.SET: return this.fieldSchemasEqual(remote.options?.key, local.options?.key); default: @@ -1075,62 +1004,39 @@ export class ReadContext { if (this.fieldSchemasEqual(fieldInfo, fallbackTypeInfo)) { return fallbackTypeInfo.clone(); } - const compatible = this.compatibleFieldTypeInfo( - fieldInfo, - fallbackTypeInfo, - ); + const compatible = this.compatibleFieldTypeInfo(fieldInfo, fallbackTypeInfo); if (compatible) { return compatible; } if ( - isCompatibleScalarType(fieldInfo.typeId) - && isCompatibleScalarType(fallbackTypeInfo.typeId) - && ((fieldInfo.trackingRef === true) - !== (fallbackTypeInfo.trackingRef === true) - || ((fieldInfo.trackingRef === true - || fallbackTypeInfo.trackingRef === true) - && (fieldInfo.typeId !== fallbackTypeInfo.typeId - || fieldInfo.nullable !== fallbackTypeInfo.nullable))) + isCompatibleScalarType(fieldInfo.typeId) && + isCompatibleScalarType(fallbackTypeInfo.typeId) && + ((fieldInfo.trackingRef === true) !== (fallbackTypeInfo.trackingRef === true) || + ((fieldInfo.trackingRef === true || fallbackTypeInfo.trackingRef === true) && + (fieldInfo.typeId !== fallbackTypeInfo.typeId || + fieldInfo.nullable !== fallbackTypeInfo.nullable))) ) { - throw new Error( - "unsupported compatible scalar tracking-ref schema mismatch", - ); + throw new Error("unsupported compatible scalar tracking-ref schema mismatch"); } if ( - isCompatibleScalarPair(fieldInfo.typeId, fallbackTypeInfo.typeId) - && fieldInfo.typeId !== fallbackTypeInfo.typeId - && (fieldInfo.trackingRef === true - || fallbackTypeInfo.trackingRef === true) + isCompatibleScalarPair(fieldInfo.typeId, fallbackTypeInfo.typeId) && + fieldInfo.typeId !== fallbackTypeInfo.typeId && + (fieldInfo.trackingRef === true || fallbackTypeInfo.trackingRef === true) ) { - throw new Error( - "unsupported compatible scalar tracking-ref schema mismatch", - ); + throw new Error("unsupported compatible scalar tracking-ref schema mismatch"); } - if ( - this.hasUnsupportedListArrayMismatch( - fieldInfo, - fallbackTypeInfo, - topLevel, - ) - ) { + if (this.hasUnsupportedListArrayMismatch(fieldInfo, fallbackTypeInfo, topLevel)) { throw new Error("unsupported compatible list/array schema mismatch"); } if ( - fieldInfo.typeId !== TypeId.UNKNOWN - && this.canonicalFieldTypeId(fallbackTypeInfo) !== TypeId.UNKNOWN - && this.canonicalTypeId(fieldInfo.typeId) - !== this.canonicalFieldTypeId(fallbackTypeInfo) + fieldInfo.typeId !== TypeId.UNKNOWN && + this.canonicalFieldTypeId(fallbackTypeInfo) !== TypeId.UNKNOWN && + this.canonicalTypeId(fieldInfo.typeId) !== this.canonicalFieldTypeId(fallbackTypeInfo) ) { throw new Error("unsupported compatible field schema mismatch"); } } - if ( - this.hasUnsupportedListArrayMismatch( - fieldInfo, - fallbackTypeInfo, - topLevel, - ) - ) { + if (this.hasUnsupportedListArrayMismatch(fieldInfo, fallbackTypeInfo, topLevel)) { throw new Error("unsupported compatible list/array schema mismatch"); } if (this.hasNestedSchemaMismatch(fieldInfo, fallbackTypeInfo, topLevel)) { @@ -1139,11 +1045,7 @@ export class ReadContext { switch (fieldInfo.typeId) { case TypeId.MAP: return Type.map( - this.fieldInfoToTypeInfo( - fieldInfo.options!.key!, - fallbackTypeInfo?.options?.key, - false, - ), + this.fieldInfoToTypeInfo(fieldInfo.options!.key!, fallbackTypeInfo?.options?.key, false), this.fieldInfoToTypeInfo( fieldInfo.options!.value!, fallbackTypeInfo?.options?.value, @@ -1160,11 +1062,7 @@ export class ReadContext { ); case TypeId.SET: return Type.set( - this.fieldInfoToTypeInfo( - fieldInfo.options!.key!, - fallbackTypeInfo?.options?.key, - false, - ), + this.fieldInfoToTypeInfo(fieldInfo.options!.key!, fallbackTypeInfo?.options?.key, false), ); default: { // Remote TypeMeta only carries the nested user-defined type kind, not the @@ -1209,53 +1107,37 @@ export class ReadContext { return false; } if ( - this.schemaMatchTypeId(remote.typeId) - !== this.schemaMatchTypeId(this.typeResolver.computeTypeId(local)) + this.schemaMatchTypeId(remote.typeId) !== + this.schemaMatchTypeId(this.typeResolver.computeTypeId(local)) ) { return true; } const remoteTracksRef = remote.trackingRef === true; const localTracksRef = local.trackingRef === true; if ( - remoteTracksRef !== localTracksRef - || ((remoteTracksRef || localTracksRef) - && (remote.nullable === true) !== (local.nullable === true)) + remoteTracksRef !== localTracksRef || + ((remoteTracksRef || localTracksRef) && + (remote.nullable === true) !== (local.nullable === true)) ) { return true; } switch (remote.typeId) { case TypeId.MAP: return ( - local.options?.key === undefined - || local.options?.value === undefined - || this.hasNestedSchemaMismatch( - remote.options!.key!, - local.options.key, - false, - ) - || this.hasNestedSchemaMismatch( - remote.options!.value!, - local.options.value, - false, - ) + local.options?.key === undefined || + local.options?.value === undefined || + this.hasNestedSchemaMismatch(remote.options!.key!, local.options.key, false) || + this.hasNestedSchemaMismatch(remote.options!.value!, local.options.value, false) ); case TypeId.LIST: return ( - local.options?.inner === undefined - || this.hasNestedSchemaMismatch( - remote.options!.inner!, - local.options.inner, - false, - ) + local.options?.inner === undefined || + this.hasNestedSchemaMismatch(remote.options!.inner!, local.options.inner, false) ); case TypeId.SET: return ( - local.options?.key === undefined - || this.hasNestedSchemaMismatch( - remote.options!.key!, - local.options.key, - false, - ) + local.options?.key === undefined || + this.hasNestedSchemaMismatch(remote.options!.key!, local.options.key, false) ); default: return false; @@ -1266,25 +1148,22 @@ export class ReadContext { return this.canonicalTypeId(typeId); } - private compatibleFieldTypeInfo( - remote: InnerFieldInfo, - local: TypeInfo, - ): TypeInfo | undefined { + private compatibleFieldTypeInfo(remote: InnerFieldInfo, local: TypeInfo): TypeInfo | undefined { if (this.isByteSequenceRootPair(remote, local)) { if ( - (remote.nullable === true) !== (local.nullable === true) - || (remote.trackingRef === true) !== (local.trackingRef === true) + (remote.nullable === true) !== (local.nullable === true) || + (remote.trackingRef === true) !== (local.trackingRef === true) ) { return undefined; } return local.clone(); } if ( - this.isListArrayRootPair(remote, local) - && (remote.nullable === true - || local.nullable === true - || remote.trackingRef === true - || local.trackingRef === true) + this.isListArrayRootPair(remote, local) && + (remote.nullable === true || + local.nullable === true || + remote.trackingRef === true || + local.trackingRef === true) ) { return undefined; } @@ -1304,22 +1183,20 @@ export class ReadContext { } const remoteArrayElement = denseArrayElementTypeId(remote.typeId); if ( - remoteArrayElement !== undefined - && local.typeId === TypeId.LIST - && local.options?.inner - && compatibleArrayElementTypeId(local.options.inner.typeId) - === remoteArrayElement + remoteArrayElement !== undefined && + local.typeId === TypeId.LIST && + local.options?.inner && + compatibleArrayElementTypeId(local.options.inner.typeId) === remoteArrayElement ) { return compatibleArrayToListTypeInfo(remoteArrayElement); } if ( - remote.trackingRef !== true - && local.trackingRef !== true - && !( - remote.typeId === local.typeId - && (remote.nullable === true) === (local.nullable === true) - ) - && isCompatibleScalarPair(remote.typeId, local.typeId) + remote.trackingRef !== true && + local.trackingRef !== true && + !( + remote.typeId === local.typeId && (remote.nullable === true) === (local.nullable === true) + ) && + isCompatibleScalarPair(remote.typeId, local.typeId) ) { return markCompatibleScalarRead(local.clone(), { remoteTypeId: remote.typeId, @@ -1347,16 +1224,8 @@ export class ReadContext { switch (remote.typeId) { case TypeId.MAP: return ( - this.hasUnsupportedListArrayMismatch( - remote.options!.key!, - local.options?.key, - false, - ) - || this.hasUnsupportedListArrayMismatch( - remote.options!.value!, - local.options?.value, - false, - ) + this.hasUnsupportedListArrayMismatch(remote.options!.key!, local.options?.key, false) || + this.hasUnsupportedListArrayMismatch(remote.options!.value!, local.options?.value, false) ); case TypeId.LIST: return this.hasUnsupportedListArrayMismatch( @@ -1375,26 +1244,17 @@ export class ReadContext { } } - private isListArrayRootPair( - remote: InnerFieldInfo, - local: TypeInfo, - ): boolean { + private isListArrayRootPair(remote: InnerFieldInfo, local: TypeInfo): boolean { return ( - (remote.typeId === TypeId.LIST - && denseArrayElementTypeId(local.typeId) !== undefined) - || (denseArrayElementTypeId(remote.typeId) !== undefined - && local.typeId === TypeId.LIST) + (remote.typeId === TypeId.LIST && denseArrayElementTypeId(local.typeId) !== undefined) || + (denseArrayElementTypeId(remote.typeId) !== undefined && local.typeId === TypeId.LIST) ); } - private isByteSequenceRootPair( - remote: InnerFieldInfo, - local: TypeInfo, - ): boolean { + private isByteSequenceRootPair(remote: InnerFieldInfo, local: TypeInfo): boolean { return ( - (remote.typeId === TypeId.BINARY - && local.typeId === TypeId.UINT8_ARRAY) - || (remote.typeId === TypeId.UINT8_ARRAY && local.typeId === TypeId.BINARY) + (remote.typeId === TypeId.BINARY && local.typeId === TypeId.UINT8_ARRAY) || + (remote.typeId === TypeId.UINT8_ARRAY && local.typeId === TypeId.BINARY) ); } @@ -1408,10 +1268,7 @@ export class ReadContext { const named = `${typeMeta.getNs()}$${typeMeta.getTypeName()}`; original = this.typeResolver.getSerializerByName(named); } else { - original = this.typeResolver.getSerializerById( - typeId, - typeMeta.getUserTypeId(), - ); + original = this.typeResolver.getSerializerById(typeId, typeMeta.getUserTypeId()); } } let typeInfo: TypeInfo; @@ -1426,25 +1283,18 @@ export class ReadContext { }); } const localProps = original?.getTypeInfo().options?.props; - const fieldEntries = typeMeta - .remapFieldNames(localProps) - .map((fieldInfo) => { - const localFieldTypeInfo = localProps?.[fieldInfo.getFieldName()]; - let fieldTypeInfo = this.fieldInfoToTypeInfo( - fieldInfo, - localFieldTypeInfo, - ) - .setNullable(fieldInfo.nullable) - .setTrackingRef(fieldInfo.trackingRef) - .setId(fieldInfo.fieldId); - if (localFieldTypeInfo === undefined) { - fieldTypeInfo = markCompatibleSkipRead(fieldTypeInfo); - } - return { key: fieldInfo.getFieldName(), typeInfo: fieldTypeInfo }; - }); - const props = Object.fromEntries( - fieldEntries.map(({ key, typeInfo }) => [key, typeInfo]), - ); + const fieldEntries = typeMeta.remapFieldNames(localProps).map((fieldInfo) => { + const localFieldTypeInfo = localProps?.[fieldInfo.getFieldName()]; + let fieldTypeInfo = this.fieldInfoToTypeInfo(fieldInfo, localFieldTypeInfo) + .setNullable(fieldInfo.nullable) + .setTrackingRef(fieldInfo.trackingRef) + .setId(fieldInfo.fieldId); + if (localFieldTypeInfo === undefined) { + fieldTypeInfo = markCompatibleSkipRead(fieldTypeInfo); + } + return { key: fieldInfo.getFieldName(), typeInfo: fieldTypeInfo }; + }); + const props = Object.fromEntries(fieldEntries.map(({ key, typeInfo }) => [key, typeInfo])); typeInfo.options = { ...typeInfo.options, preserveFieldOrder: true, diff --git a/javascript/packages/core/lib/fory.ts b/javascript/packages/core/lib/fory.ts index a7d3944a45..c1ba3fdb12 100644 --- a/javascript/packages/core/lib/fory.ts +++ b/javascript/packages/core/lib/fory.ts @@ -45,23 +45,15 @@ export default class Fory { readonly config: Config; readonly writeContext: WriteContext; readonly readContext: ReadContext; - private readonly rootSerializers = new WeakMap< - Serializer, - (data: any) => PlatformBuffer - >(); + private readonly rootSerializers = new WeakMap PlatformBuffer>(); - private readonly rootDeserializers = new WeakMap< - Serializer, - (bytes: Uint8Array) => any - >(); + private readonly rootDeserializers = new WeakMap any>(); constructor(config?: Partial) { this.config = this.initConfig(config); const maxDepth = this.config.maxDepth ?? DEFAULT_DEPTH_LIMIT; if (!Number.isInteger(maxDepth) || maxDepth < MIN_DEPTH_LIMIT) { - throw new Error( - `maxDepth must be an integer >= ${MIN_DEPTH_LIMIT} but got ${maxDepth}`, - ); + throw new Error(`maxDepth must be an integer >= ${MIN_DEPTH_LIMIT} but got ${maxDepth}`); } this.typeResolver = new TypeResolver(this.config); this.writeContext = new WriteContext(this.typeResolver, this.config); @@ -74,44 +66,32 @@ export default class Fory { private initConfig(config: Partial | undefined) { const maxTypeFields = config?.maxTypeFields ?? DEFAULT_MAX_TYPE_FIELDS; if (!Number.isInteger(maxTypeFields) || maxTypeFields <= 0) { - throw new Error( - `maxTypeFields must be a positive integer but got ${maxTypeFields}`, - ); + throw new Error(`maxTypeFields must be a positive integer but got ${maxTypeFields}`); } - const maxTypeMetaBytes - = config?.maxTypeMetaBytes ?? DEFAULT_MAX_TYPE_META_BYTES; + const maxTypeMetaBytes = config?.maxTypeMetaBytes ?? DEFAULT_MAX_TYPE_META_BYTES; if (!Number.isInteger(maxTypeMetaBytes) || maxTypeMetaBytes <= 0) { - throw new Error( - `maxTypeMetaBytes must be a positive integer but got ${maxTypeMetaBytes}`, - ); + throw new Error(`maxTypeMetaBytes must be a positive integer but got ${maxTypeMetaBytes}`); } - const maxSchemaVersionsPerType - = config?.maxSchemaVersionsPerType ?? DEFAULT_MAX_SCHEMA_VERSIONS_PER_TYPE; - if ( - !Number.isInteger(maxSchemaVersionsPerType) - || maxSchemaVersionsPerType <= 0 - ) { + const maxSchemaVersionsPerType = + config?.maxSchemaVersionsPerType ?? DEFAULT_MAX_SCHEMA_VERSIONS_PER_TYPE; + if (!Number.isInteger(maxSchemaVersionsPerType) || maxSchemaVersionsPerType <= 0) { throw new Error( `maxSchemaVersionsPerType must be a positive integer but got ${maxSchemaVersionsPerType}`, ); } - const maxAverageSchemaVersionsPerType - = config?.maxAverageSchemaVersionsPerType - ?? DEFAULT_MAX_AVERAGE_SCHEMA_VERSIONS_PER_TYPE; + const maxAverageSchemaVersionsPerType = + config?.maxAverageSchemaVersionsPerType ?? DEFAULT_MAX_AVERAGE_SCHEMA_VERSIONS_PER_TYPE; if ( - !Number.isInteger(maxAverageSchemaVersionsPerType) - || maxAverageSchemaVersionsPerType <= 0 + !Number.isInteger(maxAverageSchemaVersionsPerType) || + maxAverageSchemaVersionsPerType <= 0 ) { throw new Error( `maxAverageSchemaVersionsPerType must be a positive integer but got ${maxAverageSchemaVersionsPerType}`, ); } - const maxGraphMemoryBytes - = config?.maxGraphMemoryBytes ?? DEFAULT_MAX_GRAPH_MEMORY_BYTES; + const maxGraphMemoryBytes = config?.maxGraphMemoryBytes ?? DEFAULT_MAX_GRAPH_MEMORY_BYTES; if (!Number.isSafeInteger(maxGraphMemoryBytes)) { - throw new Error( - `maxGraphMemoryBytes must be a safe integer but got ${maxGraphMemoryBytes}`, - ); + throw new Error(`maxGraphMemoryBytes must be a safe integer but got ${maxGraphMemoryBytes}`); } return { ref: Boolean(config?.ref), @@ -153,9 +133,8 @@ export default class Fory { register(constructor: any, customSerializer?: CustomSerializer) { let serializer: Serializer; if (constructor.prototype?.[ForyTypeInfoSymbol]) { - const typeInfo: TypeInfo = ( - constructor.prototype[ForyTypeInfoSymbol] as WithForyClsInfo - ).structTypeInfo; + const typeInfo: TypeInfo = (constructor.prototype[ForyTypeInfoSymbol] as WithForyClsInfo) + .structTypeInfo; typeInfo.freeze(); serializer = new Gen(this.typeResolver, { creator: constructor, @@ -177,10 +156,7 @@ export default class Fory { }; } - deserialize( - bytes: Uint8Array, - serializer: Serializer = this.anySerializer, - ): T | null { + deserialize(bytes: Uint8Array, serializer: Serializer = this.anySerializer): T | null { this.readContext.reset(bytes); const reader = this.readContext.reader; const bitmap = reader.readUint8(); @@ -191,12 +167,9 @@ export default class Fory { } private throwInvalidRootHeader(bitmap: number): never { - const knownFlags - = ConfigFlags.isCrossLanguageFlag | ConfigFlags.isOutOfBandFlag; + const knownFlags = ConfigFlags.isCrossLanguageFlag | ConfigFlags.isOutOfBandFlag; if ((bitmap & ~knownFlags) !== 0) { - throw new Error( - `unsupported root header bitmap 0x${bitmap.toString(16)}`, - ); + throw new Error(`unsupported root header bitmap 0x${bitmap.toString(16)}`); } if ((bitmap & ConfigFlags.isCrossLanguageFlag) === 0) { throw new Error("support crosslanguage mode only"); diff --git a/javascript/packages/core/lib/gen/collection.ts b/javascript/packages/core/lib/gen/collection.ts index d98a228ac0..04c3f293d8 100644 --- a/javascript/packages/core/lib/gen/collection.ts +++ b/javascript/packages/core/lib/gen/collection.ts @@ -67,10 +67,7 @@ export const CollectionFlags = { SAME_TYPE: 0b1000, }; -function compatibleArrayCollectionExpr( - elementTypeId: number, - len: string, -): string { +function compatibleArrayCollectionExpr(elementTypeId: number, len: string): string { switch (elementTypeId) { case TypeId.BOOL: return `new external.BoolArray(${len})`; @@ -153,11 +150,7 @@ class CollectionAnySerializer { trackingRef = current.needToWriteRef(); } if (isSame) { - if ( - serializer !== null - && serializer !== undefined - && current !== serializer - ) { + if (serializer !== null && serializer !== undefined && current !== serializer) { isSame = false; } else { serializer = current; @@ -189,8 +182,7 @@ class CollectionAnySerializer { if (size === 0) { return; } - const { serializer, isSame, includeNone, trackingRef } - = this.writeElementsHeader(value); + const { serializer, isSame, includeNone, trackingRef } = this.writeElementsHeader(value); if (isSame) { serializer!.writeTypeInfo(value); if (trackingRef) { @@ -216,8 +208,7 @@ class CollectionAnySerializer { } else { if (trackingRef) { for (const item of value) { - const serializer - = this.writeContext.typeResolver.getSerializerByData(item); + const serializer = this.writeContext.typeResolver.getSerializerByData(item); serializer?.writeRef(item); } } else if (includeNone) { @@ -225,16 +216,14 @@ class CollectionAnySerializer { if (item === null || item === undefined) { this.writeContext.writer.writeInt8(RefFlags.NullFlag); } else { - const serializer - = this.writeContext.typeResolver.getSerializerByData(item); + const serializer = this.writeContext.typeResolver.getSerializerByData(item); this.writeContext.writer.writeInt8(RefFlags.NotNullValueFlag); serializer!.writeNoRef(item); } } } else { for (const item of value) { - const serializer - = this.writeContext.typeResolver.getSerializerByData(item); + const serializer = this.writeContext.typeResolver.getSerializerByData(item); serializer!.writeNoRef(item); } } @@ -248,9 +237,7 @@ class CollectionAnySerializer { ): any { void fromRef; const len = this.readContext.reader.readVarUint32Small7(); - this.readContext.reserveGraphMemory( - COLLECTION_BYTES + len * REFERENCE_BYTES, - ); + this.readContext.reserveGraphMemory(COLLECTION_BYTES + len * REFERENCE_BYTES); if (len === 0) { return createCollection(len); } @@ -275,11 +262,7 @@ class CollectionAnySerializer { const refId = this.readContext.reader.readVarUInt32(); accessor(result, i, this.readContext.getReadRef(refId)); } else if (refFlag === RefFlags.RefValueFlag) { - accessor( - result, - i, - this.readSerializerWithDepth(serializer!, true), - ); + accessor(result, i, this.readSerializerWithDepth(serializer!, true)); } else { accessor(result, i, null); } @@ -290,11 +273,7 @@ class CollectionAnySerializer { if (flag === RefFlags.NullFlag) { accessor(result, i, null); } else { - accessor( - result, - i, - this.readSerializerWithDepth(serializer!, false), - ); + accessor(result, i, this.readSerializerWithDepth(serializer!, false)); } } } else { @@ -315,21 +294,13 @@ class CollectionAnySerializer { accessor(result, i, null); } else { const itemSerializer = AnyHelper.detectSerializer(this.readContext); - accessor( - result, - i, - this.readSerializerWithDepth(itemSerializer!, false), - ); + accessor(result, i, this.readSerializerWithDepth(itemSerializer!, false)); } } } else { for (let i = 0; i < len; i++) { const itemSerializer = AnyHelper.detectSerializer(this.readContext); - accessor( - result, - i, - this.readSerializerWithDepth(itemSerializer!, false), - ); + accessor(result, i, this.readSerializerWithDepth(itemSerializer!, false)); } } } @@ -345,11 +316,7 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera super(typeInfo, builder, scope); this.typeInfo = typeInfo; const inner = this.genericTypeDescriptin()!; - this.innerGenerator = CodegenRegistry.newGeneratorByTypeInfo( - inner, - this.builder, - this.scope, - ); + this.innerGenerator = CodegenRegistry.newGeneratorByTypeInfo(inner, this.builder, this.scope); } abstract genericTypeDescriptin(): TypeInfo | undefined; @@ -367,12 +334,12 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera private isDeclaredElementType() { const innerTypeId = this.innerGenerator.getTypeId(); return ( - innerTypeId !== TypeId.STRUCT - && innerTypeId !== TypeId.COMPATIBLE_STRUCT - && innerTypeId !== TypeId.NAMED_STRUCT - && innerTypeId !== TypeId.NAMED_COMPATIBLE_STRUCT - && innerTypeId !== TypeId.EXT - && innerTypeId !== TypeId.NAMED_EXT + innerTypeId !== TypeId.STRUCT && + innerTypeId !== TypeId.COMPATIBLE_STRUCT && + innerTypeId !== TypeId.NAMED_STRUCT && + innerTypeId !== TypeId.NAMED_COMPATIBLE_STRUCT && + innerTypeId !== TypeId.EXT && + innerTypeId !== TypeId.NAMED_EXT ); } @@ -441,10 +408,7 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera `; } - readSpecificType( - accessor: (expr: string) => string, - refState: string, - ): string { + readSpecificType(accessor: (expr: string) => string, refState: string): string { const result = this.scope.uniqueName("result"); const len = this.scope.uniqueName("len"); const flags = this.scope.uniqueName("flags"); @@ -453,12 +417,8 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera const elemSerializer = this.scope.uniqueName("elemSerializer"); const anyHelper = this.builder.getExternal(AnyHelper.name); const readContextName = this.builder.getReadContextName(); - const useDeclaredStructElementReader = TypeId.structType( - this.innerGenerator.getTypeId()!, - ); - const compatibleReadAction = getCompatibleCollectionArrayReadAction( - this.typeInfo, - ); + const useDeclaredStructElementReader = TypeId.structType(this.innerGenerator.getTypeId()!); + const compatibleReadAction = getCompatibleCollectionArrayReadAction(this.typeInfo); const compatibleListToArray = compatibleReadAction?.target === "array"; const newCollection = compatibleListToArray ? compatibleArrayCollectionExpr(compatibleReadAction!.elementTypeId, len) @@ -468,12 +428,7 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera : `${readContextName}.reserveGraphMemory(${COLLECTION_BYTES} + ${len} * ${REFERENCE_BYTES});`; const putAccessor = (item: string, index: string) => compatibleListToArray - ? compatibleArrayPutAccessor( - compatibleReadAction!.elementTypeId, - result, - item, - index, - ) + ? compatibleArrayPutAccessor(compatibleReadAction!.elementTypeId, result, item, index) : this.putAccessor(result, item, index); const rejectCompatiblePayload = compatibleListToArray ? ` @@ -490,18 +445,15 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera const innerReader = useDeclaredStructElementReader ? this.innerGenerator.readEmbed() : this.innerGenerator; - const readInnerElement = ( - assignStmt: (x: any) => string, - refState: string, - ) => { + const readInnerElement = (assignStmt: (x: any) => string, refState: string) => { return innerIsLeaf ? this.innerGenerator.read(assignStmt, refState) : innerReader.readWithDepth(assignStmt, refState); }; const readElementTypeInfo = useDeclaredStructElementReader ? this.innerGenerator - .readEmbed() - .readTypeInfo((expr: string) => `${elemSerializer} = ${expr};`) + .readEmbed() + .readTypeInfo((expr: string) => `${elemSerializer} = ${expr};`) : `${elemSerializer} = ${anyHelper}.detectSerializer(${readContextName});`; return ` const ${len} = ${this.builder.reader.readVarUint32Small7()}; diff --git a/javascript/packages/core/lib/gen/ext.ts b/javascript/packages/core/lib/gen/ext.ts index 11ec957d94..a45d3df4df 100644 --- a/javascript/packages/core/lib/gen/ext.ts +++ b/javascript/packages/core/lib/gen/ext.ts @@ -45,10 +45,7 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { } private objectGraphBytes(): number { - return ( - OBJECT_BYTES - + Object.keys(this.typeInfo.options?.props ?? {}).length * REFERENCE_BYTES - ); + return OBJECT_BYTES + Object.keys(this.typeInfo.options?.props ?? {}).length * REFERENCE_BYTES; } write(accessor: string): string { @@ -82,7 +79,7 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { ${this.readTypeInfo()} ${this.builder.getReadContextName()}.incReadDepth(); let ${result}; - ${this.read(v => `${result} = ${v}`, refState)}; + ${this.read((v) => `${result} = ${v}`, refState)}; ${this.builder.getReadContextName()}.decReadDepth(); ${assignStmt(result)}; `; @@ -132,12 +129,12 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { "ext_ser", TypeId.isNamedType(this.typeInfo.typeId) ? this.builder.typeResolver.getSerializerByName( - CodecBuilder.replaceBackslashAndQuote(this.typeInfo.named!), - ) + CodecBuilder.replaceBackslashAndQuote(this.typeInfo.named!), + ) : this.builder.typeResolver.getSerializerById( - this.typeInfo.typeId, - this.typeInfo.userTypeId, - ), + this.typeInfo.typeId, + this.typeInfo.userTypeId, + ), ); return accessor(`${name}.${prop}(${args.join(",")})`); }; @@ -156,12 +153,12 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { "ext_ser", TypeId.isNamedType(this.typeInfo.typeId) ? this.builder.typeResolver.getSerializerByName( - CodecBuilder.replaceBackslashAndQuote(this.typeInfo.named!), - ) + CodecBuilder.replaceBackslashAndQuote(this.typeInfo.named!), + ) : this.builder.typeResolver.getSerializerById( - this.typeInfo.typeId, - this.typeInfo.userTypeId, - ), + this.typeInfo.typeId, + this.typeInfo.userTypeId, + ), ); return `${name}.${prop}(${accessor})`; }; @@ -176,9 +173,7 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { let writeUserTypeIdStmt = ""; switch (internalTypeId) { case TypeId.EXT: - writeUserTypeIdStmt = this.builder.writer.writeVarUint32Small7( - this.typeInfo.userTypeId, - ); + writeUserTypeIdStmt = this.builder.writer.writeVarUint32Small7(this.typeInfo.userTypeId); break; case TypeId.NAMED_EXT: if (!this.builder.resolver.isCompatible()) { @@ -204,10 +199,7 @@ class ExtSerializerGenerator extends BaseSerializerGenerator { "typeInfoBytes", `new Uint8Array([${TypeMeta.fromTypeInfo(this.typeInfo, this.builder.resolver).toBytes().join(",")}])`, ); - typeMeta = this.builder.typeMetaResolver.writeTypeMeta( - this.builder.getTypeInfo(), - bytes, - ); + typeMeta = this.builder.typeMetaResolver.writeTypeMeta(this.builder.getTypeInfo(), bytes); } break; default: diff --git a/javascript/packages/core/lib/gen/map.ts b/javascript/packages/core/lib/gen/map.ts index ff16c969ca..824d59a37d 100644 --- a/javascript/packages/core/lib/gen/map.ts +++ b/javascript/packages/core/lib/gen/map.ts @@ -52,9 +52,9 @@ class ElementInfo { return false; } return ( - this.serializer === other.serializer - && this.isNull === other.isNull - && this.trackRef === other.trackRef + this.serializer === other.serializer && + this.isNull === other.isNull && + this.trackRef === other.trackRef ); } } @@ -97,11 +97,7 @@ class MapChunkWriter { return flag; } - private writeHead( - keyInfo: ElementInfo, - valueInfo: ElementInfo, - withOutSize = false, - ) { + private writeHead(keyInfo: ElementInfo, valueInfo: ElementInfo, withOutSize = false) { // KV header const header = this.getHead(keyInfo, valueInfo); // chunkSize default 0 | KV header @@ -131,10 +127,10 @@ class MapChunkWriter { } // max size of chunk is 255 if ( - this.chunkSize == 255 - || this.chunkOffset == 0 - || !keyInfo.equalTo(this.preKeyInfo) - || !valueInfo.equalTo(this.preValueInfo) + this.chunkSize == 255 || + this.chunkOffset == 0 || + !keyInfo.equalTo(this.preKeyInfo) || + !valueInfo.equalTo(this.preValueInfo) ) { // new chunk this.endChunk(); @@ -149,10 +145,7 @@ class MapChunkWriter { endChunk() { if (this.chunkOffset > 0) { - this.writeContext.writer.setUint8Position( - this.chunkOffset, - this.chunkSize, - ); + this.writeContext.writer.setUint8Position(this.chunkOffset, this.chunkSize); this.chunkSize = 0; } } @@ -199,21 +192,17 @@ class MapAnySerializer { ); this.writeContext.writer.writeVarUint32Small7(value.size); for (const [k, v] of value.entries()) { - const keySerializer - = this.keySerializer !== null + const keySerializer = + this.keySerializer !== null ? this.keySerializer : this.writeContext.typeResolver.getSerializerByData(k); - const valueSerializer - = this.valueSerializer !== null + const valueSerializer = + this.valueSerializer !== null ? this.valueSerializer : this.writeContext.typeResolver.getSerializerByData(v); const header = mapChunkWriter.next( - new ElementInfo( - keySerializer || null, - k == null, - keySerializer?.needToWriteRef() || false, - ), + new ElementInfo(keySerializer || null, k == null, keySerializer?.needToWriteRef() || false), new ElementInfo( valueSerializer || null, v == null, @@ -223,10 +212,7 @@ class MapAnySerializer { const keyHeader = header & 0b111; const valueHeader = header >> 3; if (mapChunkWriter.isFirst()) { - if ( - !(keyHeader & MapFlags.HAS_NULL) - && !(valueHeader & MapFlags.HAS_NULL) - ) { + if (!(keyHeader & MapFlags.HAS_NULL) && !(valueHeader & MapFlags.HAS_NULL)) { if (!(keyHeader & MapFlags.DECL_ELEMENT_TYPE)) { keySerializer?.writeTypeInfo(null); } @@ -236,8 +222,7 @@ class MapAnySerializer { } } - const includeNone - = keyHeader & MapFlags.HAS_NULL || valueHeader & MapFlags.HAS_NULL; + const includeNone = keyHeader & MapFlags.HAS_NULL || valueHeader & MapFlags.HAS_NULL; if (!this.writeFlag(keyHeader, k)) { if (!includeNone) { keySerializer!.write(k); @@ -269,41 +254,28 @@ class MapAnySerializer { return null; } if (!trackingRef) { - serializer - = serializer == null - ? AnyHelper.detectSerializer(this.readContext) - : serializer; + serializer = serializer == null ? AnyHelper.detectSerializer(this.readContext) : serializer; return this.readSerializerWithDepth(serializer!, false); } const flag = this.readContext.reader.readInt8(); switch (flag) { case RefFlags.RefValueFlag: - serializer - = serializer == null - ? AnyHelper.detectSerializer(this.readContext) - : serializer; + serializer = serializer == null ? AnyHelper.detectSerializer(this.readContext) : serializer; return this.readSerializerWithDepth(serializer!, true); case RefFlags.RefFlag: - return this.readContext.getReadRef( - this.readContext.reader.readVarUInt32(), - ); + return this.readContext.getReadRef(this.readContext.reader.readVarUInt32()); case RefFlags.NullFlag: return null; case RefFlags.NotNullValueFlag: - serializer - = serializer == null - ? AnyHelper.detectSerializer(this.readContext) - : serializer; + serializer = serializer == null ? AnyHelper.detectSerializer(this.readContext) : serializer; return this.readSerializerWithDepth(serializer!, false); } } read(fromRef: boolean): any { let count = this.readContext.reader.readVarUint32Small7(); - this.readContext.reserveGraphMemory( - MAP_BYTES + count * 2 * REFERENCE_BYTES, - ); + this.readContext.reserveGraphMemory(MAP_BYTES + count * 2 * REFERENCE_BYTES); const result = new Map(); if (fromRef) { this.readContext.reference(result); @@ -321,10 +293,7 @@ class MapAnySerializer { let keySerializer = this.keySerializer; let valueSerializer = this.valueSerializer; - if ( - !(keyHeader & MapFlags.HAS_NULL) - && !(valueHeader & MapFlags.HAS_NULL) - ) { + if (!(keyHeader & MapFlags.HAS_NULL) && !(valueHeader & MapFlags.HAS_NULL)) { if (!(keyHeader & MapFlags.DECL_ELEMENT_TYPE)) { keySerializer = AnyHelper.detectSerializer(this.readContext); } @@ -369,23 +338,19 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { const keyTypeId = this.typeInfo.options?.key!.typeId; const valueTypeId = this.typeInfo.options?.value!.typeId; return ( - keyTypeId === TypeId.UNKNOWN - || valueTypeId === TypeId.UNKNOWN - || !TypeId.isBuiltin(keyTypeId!) - || !TypeId.isBuiltin(valueTypeId!) + keyTypeId === TypeId.UNKNOWN || + valueTypeId === TypeId.UNKNOWN || + !TypeId.isBuiltin(keyTypeId!) || + !TypeId.isBuiltin(valueTypeId!) ); } private writeSpecificType(accessor: string) { const k = this.scope.uniqueName("k"); const v = this.scope.uniqueName("v"); - let keyHeader = this.keyGenerator.needToWriteRef() - ? MapFlags.TRACKING_REF - : 0; + let keyHeader = this.keyGenerator.needToWriteRef() ? MapFlags.TRACKING_REF : 0; keyHeader |= MapFlags.DECL_ELEMENT_TYPE; - let valueHeader = this.valueGenerator.needToWriteRef() - ? MapFlags.TRACKING_REF - : 0; + let valueHeader = this.valueGenerator.needToWriteRef() ? MapFlags.TRACKING_REF : 0; valueHeader |= MapFlags.DECL_ELEMENT_TYPE; const lastKeyIsNull = this.scope.uniqueName("lastKeyIsNull"); const lastValueIsNull = this.scope.uniqueName("lastValueIsNull"); @@ -477,12 +442,12 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { "map_inner_ser", TypeId.isNamedType(innerTypeInfo.typeId) ? this.builder.typeResolver.getSerializerByName( - CodecBuilder.replaceBackslashAndQuote(innerTypeInfo.named!), - ) + CodecBuilder.replaceBackslashAndQuote(innerTypeInfo.named!), + ) : this.builder.typeResolver.getSerializerById( - innerTypeInfo.typeId, - innerTypeInfo.userTypeId, - ), + innerTypeInfo.typeId, + innerTypeInfo.userTypeId, + ), ); }; return `new (${anySerializer})(${this.builder.getWriteContextName()}, ${this.builder.getReadContextName()}, ${ @@ -496,10 +461,7 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { }).write(${accessor})`; } - private readSpecificType( - accessor: (expr: string) => string, - refState: string, - ) { + private readSpecificType(accessor: (expr: string) => string, refState: string) { const count = this.scope.uniqueName("count"); const result = this.scope.uniqueName("result"); // Skip depth tracking for leaf key/value types. @@ -572,12 +534,12 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { switch (flag) { case ${RefFlags.RefValueFlag}: if (${keyDeclaredType}) { - ${readKey(x => `key = ${x}`, "true")} + ${readKey((x) => `key = ${x}`, "true")} } else { if (!${keySerializer}) { ${keySerializer} = ${anyHelper}.detectSerializer(${readContextName}); } - ${readDynamic(keySerializer, x => `key = ${x}`, "true")} + ${readDynamic(keySerializer, (x) => `key = ${x}`, "true")} } break; case ${RefFlags.RefFlag}: @@ -588,23 +550,23 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { break; case ${RefFlags.NotNullValueFlag}: if (${keyDeclaredType}) { - ${readKey(x => `key = ${x}`, "false")} + ${readKey((x) => `key = ${x}`, "false")} } else { if (!${keySerializer}) { ${keySerializer} = ${anyHelper}.detectSerializer(${readContextName}); } - ${readDynamic(keySerializer, x => `key = ${x}`, "false")} + ${readDynamic(keySerializer, (x) => `key = ${x}`, "false")} } break; } } else { if (${keyDeclaredType}) { - ${readKey(x => `key = ${x}`, "false")} + ${readKey((x) => `key = ${x}`, "false")} } else { if (!${keySerializer}) { ${keySerializer} = ${anyHelper}.detectSerializer(${readContextName}); } - ${readDynamic(keySerializer, x => `key = ${x}`, "false")} + ${readDynamic(keySerializer, (x) => `key = ${x}`, "false")} } } @@ -615,12 +577,12 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { switch (flag) { case ${RefFlags.RefValueFlag}: if (${valueDeclaredType}) { - ${readValue(x => `value = ${x}`, "true")} + ${readValue((x) => `value = ${x}`, "true")} } else { if (!${valueSerializer}) { ${valueSerializer} = ${anyHelper}.detectSerializer(${readContextName}); } - ${readDynamic(valueSerializer, x => `value = ${x}`, "true")} + ${readDynamic(valueSerializer, (x) => `value = ${x}`, "true")} } break; case ${RefFlags.RefFlag}: @@ -631,23 +593,23 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { break; case ${RefFlags.NotNullValueFlag}: if (${valueDeclaredType}) { - ${readValue(x => `value = ${x}`, "false")} + ${readValue((x) => `value = ${x}`, "false")} } else { if (!${valueSerializer}) { ${valueSerializer} = ${anyHelper}.detectSerializer(${readContextName}); } - ${readDynamic(valueSerializer, x => `value = ${x}`, "false")} + ${readDynamic(valueSerializer, (x) => `value = ${x}`, "false")} } break; } } else { if (${valueDeclaredType}) { - ${readValue(x => `value = ${x}`, "false")} + ${readValue((x) => `value = ${x}`, "false")} } else { if (!${valueSerializer}) { ${valueSerializer} = ${anyHelper}.detectSerializer(${readContextName}); } - ${readDynamic(valueSerializer, x => `value = ${x}`, "false")} + ${readDynamic(valueSerializer, (x) => `value = ${x}`, "false")} } } @@ -672,12 +634,12 @@ export class MapSerializerGenerator extends BaseSerializerGenerator { "map_inner_ser", TypeId.isNamedType(innerTypeInfo.typeId) ? this.builder.typeResolver.getSerializerByName( - CodecBuilder.replaceBackslashAndQuote(innerTypeInfo.named!), - ) + CodecBuilder.replaceBackslashAndQuote(innerTypeInfo.named!), + ) : this.builder.typeResolver.getSerializerById( - innerTypeInfo.typeId, - innerTypeInfo.userTypeId, - ), + innerTypeInfo.typeId, + innerTypeInfo.userTypeId, + ), ); }; return accessor( diff --git a/javascript/packages/core/lib/gen/struct.ts b/javascript/packages/core/lib/gen/struct.ts index 801a34f4e5..d4824fd086 100644 --- a/javascript/packages/core/lib/gen/struct.ts +++ b/javascript/packages/core/lib/gen/struct.ts @@ -25,10 +25,7 @@ import { CodegenRegistry } from "./router"; import { BaseSerializerGenerator, SerializerGenerator } from "./serializer"; import { TypeMeta } from "../meta/TypeMeta"; import { getCompatibleCollectionArrayReadAction } from "./collection"; -import { - CompatibleScalarConverter, - getCompatibleScalarReadAction, -} from "../compatible/scalar"; +import { CompatibleScalarConverter, getCompatibleScalarReadAction } from "../compatible/scalar"; import { shouldSkipCompatibleRead } from "../compatible/field"; const OBJECT_BYTES = 1; @@ -50,12 +47,7 @@ function isDepthFreeField(typeInfo: TypeInfo): boolean { if (id === TypeId.MAP) { const key = typeInfo.options?.key; const value = typeInfo.options?.value; - return ( - !!key - && !!value - && TypeId.isLeafTypeId(key.typeId) - && TypeId.isLeafTypeId(value.typeId) - ); + return !!key && !!value && TypeId.isLeafTypeId(key.typeId) && TypeId.isLeafTypeId(value.typeId); } return false; } @@ -72,15 +64,12 @@ function compatibleReadTargetExpr(typeInfo: TypeInfo, expr: string): string { } } -const sortProps = ( - typeInfo: TypeInfo, - typeResolver: CodecBuilder["resolver"], -) => { +const sortProps = (typeInfo: TypeInfo, typeResolver: CodecBuilder["resolver"]) => { const props = typeInfo.options!.props; if (typeInfo.options!.preserveFieldOrder) { return ( - typeInfo.options!.fieldEntries - ?? Object.entries(props!).map(([key, fieldTypeInfo]) => ({ + typeInfo.options!.fieldEntries ?? + Object.entries(props!).map(([key, fieldTypeInfo]) => ({ key, typeInfo: fieldTypeInfo, })) @@ -116,10 +105,7 @@ function toRefMode(trackingRef?: boolean, nullable?: boolean) { } } -function isDirectVarInt32Field( - typeInfo: TypeInfo, - typeResolver: CodecBuilder["resolver"], -) { +function isDirectVarInt32Field(typeInfo: TypeInfo, typeResolver: CodecBuilder["resolver"]) { return varInt32ObjectReadKind(typeInfo, typeResolver) === "number"; } @@ -128,32 +114,29 @@ function varInt32ObjectReadKind( typeResolver: CodecBuilder["resolver"], ): "number" | "bigint" | null { if ( - toRefMode(typeInfo.trackingRef, typeInfo.nullable) !== RefMode.NONE - || !typeResolver.isMonomorphic(typeInfo, typeInfo.dynamic) + toRefMode(typeInfo.trackingRef, typeInfo.nullable) !== RefMode.NONE || + !typeResolver.isMonomorphic(typeInfo, typeInfo.dynamic) ) { return null; } const scalarAction = getCompatibleScalarReadAction(typeInfo); if (scalarAction !== undefined) { - return scalarAction.remoteNullable !== true - && scalarAction.remoteTypeId === TypeId.VARINT32 - && (scalarAction.localTypeId === TypeId.INT64 - || scalarAction.localTypeId === TypeId.VARINT64 - || scalarAction.localTypeId === TypeId.TAGGED_INT64) + return scalarAction.remoteNullable !== true && + scalarAction.remoteTypeId === TypeId.VARINT32 && + (scalarAction.localTypeId === TypeId.INT64 || + scalarAction.localTypeId === TypeId.VARINT64 || + scalarAction.localTypeId === TypeId.TAGGED_INT64) ? "bigint" : null; } return typeInfo.typeId === TypeId.VARINT32 ? "number" : null; } -function directNumericFieldReadExpr( - typeInfo: TypeInfo, - builder: CodecBuilder, -): string | null { +function directNumericFieldReadExpr(typeInfo: TypeInfo, builder: CodecBuilder): string | null { if ( - toRefMode(typeInfo.trackingRef, typeInfo.nullable) !== RefMode.NONE - || !builder.resolver.isMonomorphic(typeInfo, typeInfo.dynamic) - || getCompatibleScalarReadAction(typeInfo) !== undefined + toRefMode(typeInfo.trackingRef, typeInfo.nullable) !== RefMode.NONE || + !builder.resolver.isMonomorphic(typeInfo, typeInfo.dynamic) || + getCompatibleScalarReadAction(typeInfo) !== undefined ) { return null; } @@ -205,11 +188,7 @@ function compatibleScalarFieldReadExpr( builder: CodecBuilder, ): string | null { const converter = builder.getExternal(CompatibleScalarConverter.name); - const remoteRead = compatibleScalarRemoteReadExpr( - remoteTypeId, - builder, - converter, - ); + const remoteRead = compatibleScalarRemoteReadExpr(remoteTypeId, builder, converter); if (remoteRead === null) { return null; } @@ -233,22 +212,12 @@ function compatibleScalarFieldReadExpr( case TypeId.UINT16: case TypeId.UINT32: case TypeId.UINT64: - return scalarToIntegerExpr( - remoteCanonical, - localCanonical, - remoteRead, - converter, - ); + return scalarToIntegerExpr(remoteCanonical, localCanonical, remoteRead, converter); case TypeId.FLOAT16: case TypeId.BFLOAT16: case TypeId.FLOAT32: case TypeId.FLOAT64: - return scalarToFloatExpr( - remoteCanonical, - localCanonical, - remoteRead, - converter, - ); + return scalarToFloatExpr(remoteCanonical, localCanonical, remoteRead, converter); default: return null; } @@ -325,11 +294,7 @@ function compatibleScalarRemoteReadExpr( } } -function scalarToBoolExpr( - remoteTypeId: number, - value: string, - converter: string, -): string | null { +function scalarToBoolExpr(remoteTypeId: number, value: string, converter: string): string | null { switch (remoteTypeId) { case TypeId.BOOL: return value; @@ -347,11 +312,7 @@ function scalarToBoolExpr( } } -function scalarToStringExpr( - remoteTypeId: number, - value: string, - converter: string, -): string | null { +function scalarToStringExpr(remoteTypeId: number, value: string, converter: string): string | null { switch (remoteTypeId) { case TypeId.BOOL: return `(${value} ? "true" : "false")`; @@ -468,19 +429,19 @@ function integerRangeFits(remoteTypeId: number, localTypeId: number): boolean { return remoteTypeId === TypeId.INT8 || remoteTypeId === TypeId.UINT8; case TypeId.INT32: return ( - remoteTypeId === TypeId.INT8 - || remoteTypeId === TypeId.INT16 - || remoteTypeId === TypeId.UINT8 - || remoteTypeId === TypeId.UINT16 + remoteTypeId === TypeId.INT8 || + remoteTypeId === TypeId.INT16 || + remoteTypeId === TypeId.UINT8 || + remoteTypeId === TypeId.UINT16 ); case TypeId.INT64: return ( - remoteTypeId === TypeId.INT8 - || remoteTypeId === TypeId.INT16 - || remoteTypeId === TypeId.INT32 - || remoteTypeId === TypeId.UINT8 - || remoteTypeId === TypeId.UINT16 - || remoteTypeId === TypeId.UINT32 + remoteTypeId === TypeId.INT8 || + remoteTypeId === TypeId.INT16 || + remoteTypeId === TypeId.INT32 || + remoteTypeId === TypeId.UINT8 || + remoteTypeId === TypeId.UINT16 || + remoteTypeId === TypeId.UINT32 ); case TypeId.UINT16: return remoteTypeId === TypeId.UINT8; @@ -488,9 +449,9 @@ function integerRangeFits(remoteTypeId: number, localTypeId: number): boolean { return remoteTypeId === TypeId.UINT8 || remoteTypeId === TypeId.UINT16; case TypeId.UINT64: return ( - remoteTypeId === TypeId.UINT8 - || remoteTypeId === TypeId.UINT16 - || remoteTypeId === TypeId.UINT32 + remoteTypeId === TypeId.UINT8 || + remoteTypeId === TypeId.UINT16 || + remoteTypeId === TypeId.UINT32 ); default: return false; @@ -545,29 +506,26 @@ function floatMethod(prefix: string, localTypeId: number): string | null { } } -function integerRangeFitsFloat( - remoteTypeId: number, - localTypeId: number, -): boolean { +function integerRangeFitsFloat(remoteTypeId: number, localTypeId: number): boolean { switch (localTypeId) { case TypeId.FLOAT16: case TypeId.BFLOAT16: return remoteTypeId === TypeId.INT8 || remoteTypeId === TypeId.UINT8; case TypeId.FLOAT32: return ( - remoteTypeId === TypeId.INT8 - || remoteTypeId === TypeId.INT16 - || remoteTypeId === TypeId.UINT8 - || remoteTypeId === TypeId.UINT16 + remoteTypeId === TypeId.INT8 || + remoteTypeId === TypeId.INT16 || + remoteTypeId === TypeId.UINT8 || + remoteTypeId === TypeId.UINT16 ); case TypeId.FLOAT64: return ( - remoteTypeId === TypeId.INT8 - || remoteTypeId === TypeId.INT16 - || remoteTypeId === TypeId.INT32 - || remoteTypeId === TypeId.UINT8 - || remoteTypeId === TypeId.UINT16 - || remoteTypeId === TypeId.UINT32 + remoteTypeId === TypeId.INT8 || + remoteTypeId === TypeId.INT16 || + remoteTypeId === TypeId.INT32 || + remoteTypeId === TypeId.UINT8 || + remoteTypeId === TypeId.UINT16 || + remoteTypeId === TypeId.UINT32 ); default: return false; @@ -601,8 +559,8 @@ class StructSerializerGenerator extends BaseSerializerGenerator { private isDepthFreeStruct(): boolean { return ( - this.sortedProps.length > 0 - && this.sortedProps.every(({ typeInfo }) => isDepthFreeField(typeInfo)) + this.sortedProps.length > 0 && + this.sortedProps.every(({ typeInfo }) => isDepthFreeField(typeInfo)) ); } @@ -704,7 +662,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { const noneedWrite = this.scope.uniqueName("noneedWrite"); stmt = ` let ${noneedWrite} = false; - ${embedGenerator.writeRefOrNull(fieldAccessor, expr => `${noneedWrite} = ${expr}`)} + ${embedGenerator.writeRefOrNull(fieldAccessor, (expr) => `${noneedWrite} = ${expr}`)} if (!${noneedWrite}) { ${embedGenerator.write(fieldAccessor)} } @@ -753,10 +711,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { } write(accessor: string): string { - if ( - !this.typeInfo.options?.props - || Object.keys(this.typeInfo.options.props).length === 0 - ) { + if (!this.typeInfo.options?.props || Object.keys(this.typeInfo.options.props).length === 0) { const hash = this.typeMeta.computeStructHash(); return `${!this.builder.resolver.isCompatible() ? this.builder.writer.writeInt32(hash) : ""}`; } @@ -767,18 +722,13 @@ class StructSerializerGenerator extends BaseSerializerGenerator { if (isDirectVarInt32Field(current.typeInfo, this.builder.resolver)) { let end = i + 1; while ( - end < this.sortedProps.length - && isDirectVarInt32Field( - this.sortedProps[end].typeInfo, - this.builder.resolver, - ) + end < this.sortedProps.length && + isDirectVarInt32Field(this.sortedProps[end].typeInfo, this.builder.resolver) ) { end++; } if (end - i > 1) { - fieldWrites.push( - this.writeVarInt32Run(accessor, this.sortedProps.slice(i, end)), - ); + fieldWrites.push(this.writeVarInt32Run(accessor, this.sortedProps.slice(i, end))); i = end; continue; } @@ -787,19 +737,10 @@ class StructSerializerGenerator extends BaseSerializerGenerator { if (!InnerGeneratorClass) { throw new Error(`${current.typeInfo.typeId} generator not exists`); } - const innerGenerator = new InnerGeneratorClass( - current.typeInfo, - this.builder, - this.scope, - ); + const innerGenerator = new InnerGeneratorClass(current.typeInfo, this.builder, this.scope); const fieldAccessor = `${accessor}${CodecBuilder.safePropAccessor(current.key)}`; fieldWrites.push( - this.writeField( - current.key, - current.typeInfo, - fieldAccessor, - innerGenerator.writeEmbed(), - ), + this.writeField(current.key, current.typeInfo, fieldAccessor, innerGenerator.writeEmbed()), ); i++; } @@ -809,10 +750,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { `; } - private writeVarInt32Run( - accessor: string, - fields: { key: string; typeInfo: TypeInfo }[], - ) { + private writeVarInt32Run(accessor: string, fields: { key: string; typeInfo: TypeInfo }[]) { const cursor = this.scope.uniqueName("cursor"); const buffer = this.scope.uniqueName("buffer"); const dataView = this.scope.uniqueName("dataView"); @@ -879,10 +817,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { read(accessor: (expr: string) => string, refState: string): string { const result = this.scope.uniqueName("result"); const hash = this.typeMeta.computeStructHash(); - if ( - !this.typeInfo.options?.props - || Object.keys(this.typeInfo.options.props).length === 0 - ) { + if (!this.typeInfo.options?.props || Object.keys(this.typeInfo.options.props).length === 0) { return ` ${ !this.builder.resolver.isCompatible() @@ -903,10 +838,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { ${accessor(result)}; `; } - const directNumericObjectRead = this.readDirectNumericObject( - accessor, - refState, - ); + const directNumericObjectRead = this.readDirectNumericObject(accessor, refState); if (directNumericObjectRead !== null) { return ` ${ @@ -941,9 +873,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { const ${result} = { ${this.sortedProps .map(({ key }) => { - if ( - shouldSkipCompatibleRead(this.typeInfo.options!.props![key]) - ) { + if (shouldSkipCompatibleRead(this.typeInfo.options!.props![key])) { return ""; } return `${CodecBuilder.safePropName(key)}: null`; @@ -960,13 +890,9 @@ class StructSerializerGenerator extends BaseSerializerGenerator { if (!InnerGeneratorClass) { throw new Error(`${typeInfo.typeId} generator not exists`); } - const innerGenerator = new InnerGeneratorClass( - typeInfo, - this.builder, - this.scope, - ); + const innerGenerator = new InnerGeneratorClass(typeInfo, this.builder, this.scope); return ` - ${this.readField(key, typeInfo, expr => `${result}${CodecBuilder.safePropAccessor(key)} = ${expr}`, innerGenerator.readEmbed())} + ${this.readField(key, typeInfo, (expr) => `${result}${CodecBuilder.safePropAccessor(key)} = ${expr}`, innerGenerator.readEmbed())} `; }) .join(";\n")} @@ -978,17 +904,11 @@ class StructSerializerGenerator extends BaseSerializerGenerator { accessor: (expr: string) => string, refState: string, ): string | null { - const varInt32ObjectRead = this.readDirectVarInt32Object( - accessor, - refState, - ); + const varInt32ObjectRead = this.readDirectVarInt32Object(accessor, refState); if (varInt32ObjectRead !== null) { return varInt32ObjectRead; } - if ( - this.typeInfo.options!.withConstructor - || this.sortedProps.length === 0 - ) { + if (this.typeInfo.options!.withConstructor || this.sortedProps.length === 0) { return null; } const fields: Array<{ key: string; expr: string }> = []; @@ -997,15 +917,15 @@ class StructSerializerGenerator extends BaseSerializerGenerator { return null; } const scalarAction = getCompatibleScalarReadAction(typeInfo); - const expr - = scalarAction?.remoteNullable === true + const expr = + scalarAction?.remoteNullable === true ? null : scalarAction ? compatibleScalarFieldReadExpr( - scalarAction.remoteTypeId, - scalarAction.localTypeId, - this.builder, - ) + scalarAction.remoteTypeId, + scalarAction.localTypeId, + this.builder, + ) : directNumericFieldReadExpr(typeInfo, this.builder); if (expr === null) { return null; @@ -1027,10 +947,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { accessor: (expr: string) => string, refState: string, ): string | null { - if ( - this.typeInfo.options!.withConstructor - || this.sortedProps.length === 0 - ) { + if (this.typeInfo.options!.withConstructor || this.sortedProps.length === 0) { return null; } const fields = []; @@ -1120,10 +1037,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { } readWithDepth(assignStmt: (v: string) => string, refState: string): string { - if ( - !this.typeInfo.options?.props - || Object.keys(this.typeInfo.options.props).length === 0 - ) { + if (!this.typeInfo.options?.props || Object.keys(this.typeInfo.options.props).length === 0) { const result = this.scope.uniqueName("result"); return ` ${this.builder.getReadContextName()}.incReadDepth(); @@ -1137,12 +1051,9 @@ class StructSerializerGenerator extends BaseSerializerGenerator { readNoRef(assignStmt: (v: string) => string, refState: string): string { const result = this.scope.uniqueName("result"); - if ( - !this.typeInfo.options?.props - || Object.keys(this.typeInfo.options.props).length === 0 - ) { + if (!this.typeInfo.options?.props || Object.keys(this.typeInfo.options.props).length === 0) { return this.readTypeInfoThen( - changedSerializer => ` + (changedSerializer) => ` ${assignStmt(`${changedSerializer}.read(${refState})`)}; `, () => ` @@ -1156,25 +1067,25 @@ class StructSerializerGenerator extends BaseSerializerGenerator { } if (this.isDepthFreeStruct()) { return this.readTypeInfoThen( - changedSerializer => ` + (changedSerializer) => ` ${assignStmt(`${changedSerializer}.read(${refState})`)}; `, () => ` let ${result}; - ${this.read(v => `${result} = ${v}`, refState)}; + ${this.read((v) => `${result} = ${v}`, refState)}; ${assignStmt(result)}; `, true, ); } return this.readTypeInfoThen( - changedSerializer => ` + (changedSerializer) => ` ${assignStmt(`${changedSerializer}.read(${refState})`)}; `, () => ` ${this.builder.getReadContextName()}.incReadDepth(); let ${result}; - ${this.read(v => `${result} = ${v}`, refState)}; + ${this.read((v) => `${result} = ${v}`, refState)}; ${this.builder.getReadContextName()}.decReadDepth(); ${assignStmt(result)}; `, @@ -1254,13 +1165,11 @@ class StructSerializerGenerator extends BaseSerializerGenerator { const builder = this.builder; const internalTypeId = this.getInternalTypeId(); const serializer = builder.resolver.getSerializerByTypeInfo(this.typeInfo); - const canInlineCompatibleTypeInfo - = internalTypeId === TypeId.COMPATIBLE_STRUCT - || internalTypeId === TypeId.NAMED_COMPATIBLE_STRUCT - || (internalTypeId === TypeId.NAMED_STRUCT - && builder.resolver.isCompatible()); - const canUseHeaderCacheFastPath - = canInlineCompatibleTypeInfo && serializer?._initialized; + const canInlineCompatibleTypeInfo = + internalTypeId === TypeId.COMPATIBLE_STRUCT || + internalTypeId === TypeId.NAMED_COMPATIBLE_STRUCT || + (internalTypeId === TypeId.NAMED_STRUCT && builder.resolver.isCompatible()); + const canUseHeaderCacheFastPath = canInlineCompatibleTypeInfo && serializer?._initialized; const inlineCompatibleTypeInfo = ( onMetaChanged: (changedSerializer: string) => string, onMetaUnchanged: () => string, @@ -1297,8 +1206,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { const result = scope.uniqueName("result"); return ` ${inlineCompatibleTypeInfo( - changedSerializer => - `${accessor(`${changedSerializer}.read(${refState})`)};`, + (changedSerializer) => `${accessor(`${changedSerializer}.read(${refState})`)};`, () => ` ${builder.getReadContextName()}.incReadDepth(); let ${result} = ${hoisted}.read(${refState}); @@ -1322,7 +1230,7 @@ class StructSerializerGenerator extends BaseSerializerGenerator { ${result} = ${builder.referenceResolver.getReadRef(builder.reader.readVarUInt32())}; } else { ${inlineCompatibleTypeInfo( - changedSerializer => + (changedSerializer) => `${result} = ${changedSerializer}.read(${refFlag} === ${RefFlags.RefValueFlag});`, () => ` ${builder.getReadContextName()}.incReadDepth(); @@ -1400,18 +1308,13 @@ class StructSerializerGenerator extends BaseSerializerGenerator { let writeUserTypeIdStmt = ""; switch (internalTypeId) { case TypeId.STRUCT: - writeUserTypeIdStmt = this.builder.writer.writeVarUint32Small7( - this.typeInfo.userTypeId, - ); + writeUserTypeIdStmt = this.builder.writer.writeVarUint32Small7(this.typeInfo.userTypeId); break; case TypeId.NAMED_COMPATIBLE_STRUCT: case TypeId.COMPATIBLE_STRUCT: { const bytes = this.typeMetaBytesExpr(); - typeMeta = this.builder.typeMetaResolver.writeTypeMeta( - this.builder.getTypeInfo(), - bytes, - ); + typeMeta = this.builder.typeMetaResolver.writeTypeMeta(this.builder.getTypeInfo(), bytes); } break; case TypeId.NAMED_STRUCT: @@ -1459,25 +1362,17 @@ class StructSerializerGenerator extends BaseSerializerGenerator { let fixedSize = 8; if (options!.props) { Object.values(options!.props).forEach((x) => { - const propGenerator = new (CodegenRegistry.get(x.typeId)!)( - x, - this.builder, - this.scope, - ); + const propGenerator = new (CodegenRegistry.get(x.typeId)!)(x, this.builder, this.scope); fixedSize += propGenerator.getFixedSize(); }); } else { - fixedSize += this.builder.resolver.getSerializerByName( - typeInfo.named!, - )!.fixedSize; + fixedSize += this.builder.resolver.getSerializerByName(typeInfo.named!)!.fixedSize; } return fixedSize; } getHash(): string { - return TypeMeta.fromTypeInfo(this.typeInfo, this.builder.resolver) - .getHash() - .toString(); + return TypeMeta.fromTypeInfo(this.typeInfo, this.builder.resolver).getHash().toString(); } getTypeMetaBytes(): string { @@ -1495,7 +1390,4 @@ class StructSerializerGenerator extends BaseSerializerGenerator { CodegenRegistry.register(TypeId.STRUCT, StructSerializerGenerator); CodegenRegistry.register(TypeId.NAMED_STRUCT, StructSerializerGenerator); CodegenRegistry.register(TypeId.COMPATIBLE_STRUCT, StructSerializerGenerator); -CodegenRegistry.register( - TypeId.NAMED_COMPATIBLE_STRUCT, - StructSerializerGenerator, -); +CodegenRegistry.register(TypeId.NAMED_COMPATIBLE_STRUCT, StructSerializerGenerator); diff --git a/javascript/test/graphMemoryBudget.test.ts b/javascript/test/graphMemoryBudget.test.ts index b86bc6fa34..15d271b34a 100644 --- a/javascript/test/graphMemoryBudget.test.ts +++ b/javascript/test/graphMemoryBudget.test.ts @@ -45,12 +45,8 @@ describe("graph memory budget", () => { const fory = new Fory({ compatible: false }); fory.readContext.reset(new Uint8Array(17)); - expect(() => - fory.readContext.reserveGraphMemory(DEFAULT_GRAPH_MEMORY_BYTES), - ).not.toThrow(); - expect(() => fory.readContext.reserveGraphMemory(1)).toThrow( - /maxGraphMemoryBytes/, - ); + expect(() => fory.readContext.reserveGraphMemory(DEFAULT_GRAPH_MEMORY_BYTES)).not.toThrow(); + expect(() => fory.readContext.reserveGraphMemory(1)).toThrow(/maxGraphMemoryBytes/); }); test("handles explicit config and disable", () => { @@ -58,15 +54,11 @@ describe("graph memory budget", () => { fory.readContext.reset(new Uint8Array(1)); expect(() => fory.readContext.reserveGraphMemory(0)).not.toThrow(); expect(() => fory.readContext.reserveGraphMemory(24)).not.toThrow(); - expect(() => fory.readContext.reserveGraphMemory(1)).toThrow( - /maxGraphMemoryBytes/, - ); + expect(() => fory.readContext.reserveGraphMemory(1)).toThrow(/maxGraphMemoryBytes/); const disabled = new Fory({ maxGraphMemoryBytes: 0 }); disabled.readContext.reset(new Uint8Array(1)); - expect(() => - disabled.readContext.reserveGraphMemory(Number.MAX_SAFE_INTEGER), - ).not.toThrow(); + expect(() => disabled.readContext.reserveGraphMemory(Number.MAX_SAFE_INTEGER)).not.toThrow(); expect(() => new Fory({ maxGraphMemoryBytes: -2 })).not.toThrow(); }); @@ -87,9 +79,7 @@ describe("graph memory budget", () => { maxGraphMemoryBytes: objectBytes(1) + listBytes(1) + listBytes(0) - 1, }).register(typeInfo); - expect(() => failingReader.deserialize(bytes)).toThrow( - /maxGraphMemoryBytes/, - ); + expect(() => failingReader.deserialize(bytes)).toThrow(/maxGraphMemoryBytes/); expect(passingReader.deserialize(bytes)).toEqual({ values: [[]] }); }); @@ -112,9 +102,7 @@ describe("graph memory budget", () => { maxGraphMemoryBytes: objectBytes(1) + listBytes(3) + 3 * listBytes(0) - 1, }).register(typeInfo); - expect(() => failingReader.deserialize(bytes)).toThrow( - /maxGraphMemoryBytes/, - ); + expect(() => failingReader.deserialize(bytes)).toThrow(/maxGraphMemoryBytes/); expect(passingReader.deserialize(bytes)).toEqual({ values: [[], [], []], }); @@ -148,9 +136,7 @@ describe("graph memory budget", () => { failingReader.register(childType); failingReader.register(typeInfo); - expect(() => failingReader.deserialize(bytes)).toThrow( - /maxGraphMemoryBytes/, - ); + expect(() => failingReader.deserialize(bytes)).toThrow(/maxGraphMemoryBytes/); expect(passingReader.deserialize(bytes)).toEqual({ first: {}, second: {}, @@ -186,9 +172,7 @@ describe("graph memory budget", () => { failingReader.register(EmptyChild); failingReader.register(EmptyParent); - expect(() => failingReader.deserialize(bytes)).toThrow( - /maxGraphMemoryBytes/, - ); + expect(() => failingReader.deserialize(bytes)).toThrow(/maxGraphMemoryBytes/); const decoded = passingReader.deserialize(bytes); expect(decoded).toBeInstanceOf(EmptyParent); expect(decoded.child).toBeInstanceOf(EmptyChild); @@ -197,9 +181,7 @@ describe("graph memory budget", () => { test("reserves map entries", () => { const bytes = serializeAny(new Map([[1, 2]])); - expect(() => deserializeAny(bytes, mapBytes(1) - 1)).toThrow( - /maxGraphMemoryBytes/, - ); + expect(() => deserializeAny(bytes, mapBytes(1) - 1)).toThrow(/maxGraphMemoryBytes/); expect(deserializeAny(bytes, mapBytes(1))).toEqual(new Map([[1, 2]])); }); @@ -218,19 +200,15 @@ describe("graph memory budget", () => { const passingReader = new Fory({ compatible: false, ref: true, - maxGraphMemoryBytes: - objectBytes(3) + listBytes(1) + listBytes(1) + mapBytes(1), + maxGraphMemoryBytes: objectBytes(3) + listBytes(1) + listBytes(1) + mapBytes(1), }).register(typeInfo); const failingReader = new Fory({ compatible: false, ref: true, - maxGraphMemoryBytes: - objectBytes(3) + listBytes(1) + listBytes(1) + mapBytes(1) - 1, + maxGraphMemoryBytes: objectBytes(3) + listBytes(1) + listBytes(1) + mapBytes(1) - 1, }).register(typeInfo); - expect(() => failingReader.deserialize(bytes)).toThrow( - /maxGraphMemoryBytes/, - ); + expect(() => failingReader.deserialize(bytes)).toThrow(/maxGraphMemoryBytes/); expect(passingReader.deserialize(bytes)).toEqual({ list: [1], set: new Set(["a"]), @@ -256,12 +234,8 @@ describe("graph memory budget", () => { maxGraphMemoryBytes: objectBytes(1) - 1, }).register(readerType); - expect(() => failingReader.deserialize(bytes)).toThrow( - /maxGraphMemoryBytes/, - ); - expect(Array.from(passingReader.deserialize(bytes).values)).toEqual([ - 1, 2, 3, - ]); + expect(() => failingReader.deserialize(bytes)).toThrow(/maxGraphMemoryBytes/); + expect(Array.from(passingReader.deserialize(bytes).values)).toEqual([1, 2, 3]); }); test("skips scalar dense owners", () => { @@ -303,8 +277,6 @@ describe("graph memory budget", () => { maxGraphMemoryBytes: 1024 * 1024, }).register(typeInfo); - expect(() => - reader.deserialize(bytes.slice(0, bytes.length - 1)), - ).toThrow(); + expect(() => reader.deserialize(bytes.slice(0, bytes.length - 1))).toThrow(); }); }); From 065bbc22cc7baa355e2057686e6ca54ac4b98284 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 00:41:42 +0800 Subject: [PATCH 30/54] refactor: clean graph memory budget owner paths --- .agents/docs-and-formatting.md | 7 +- csharp/src/Fory/ReadContext.cs | 26 ----- .../Sources/Fory/CollectionSerializers.swift | 62 +++++----- swift/Sources/Fory/FieldCodecs.swift | 67 ++++++++--- swift/Sources/Fory/Fory.swift | 14 --- .../ForyTests/GraphMemoryBudgetTests.swift | 107 ++++++++++++++++-- 6 files changed, 184 insertions(+), 99 deletions(-) diff --git a/.agents/docs-and-formatting.md b/.agents/docs-and-formatting.md index 753c2ac815..a8b7e54250 100644 --- a/.agents/docs-and-formatting.md +++ b/.agents/docs-and-formatting.md @@ -45,10 +45,9 @@ Load this file when changing documentation, public APIs, protocol specs, benchma - Python code, including `compiler/`, `benchmarks/`, `integration_tests/`, and `python/`: `python -m ruff format ` and `python -m ruff check --fix ` -- JavaScript/TypeScript under `javascript/`: use the package's ESLint-owned formatting path - (`npm run lint -- --fix` when fixing style, `npm run lint -- --quiet` when checking). Do not run - Prettier on JavaScript or TypeScript files unless that package has an explicit Prettier config or - script; otherwise it creates unrelated formatting churn. +- JavaScript/TypeScript under `javascript/`: use the package formatter scripts + (`npm run format` when fixing style, `npm run format-check` when checking). They run Prettier + before ESLint; do not use raw ESLint as the formatting gate. - Repo-wide format and lint sweep: `bash ci/format.sh --all` When code changes touch `compiler/` or `benchmarks/`, format those changed source files with the diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index f82d329632..56a90df503 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -117,32 +117,6 @@ public void ReserveGraphMemory(long bytes) _remainingGraphMemoryBytes = remaining - bytes; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void ReserveGraphMemory(int bytes) - { - long remaining = _remainingGraphMemoryBytes; - if (bytes < 0 || bytes > remaining) - { - ReserveGraphMemorySlow(bytes, remaining); - return; - } - - _remainingGraphMemoryBytes = remaining - bytes; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void ReserveGraphMemory(uint bytes) - { - long remaining = _remainingGraphMemoryBytes - bytes; - if (remaining < 0) - { - ReserveGraphMemorySlow(bytes, _remainingGraphMemoryBytes); - return; - } - - _remainingGraphMemoryBytes = remaining; - } - [MethodImpl(MethodImplOptions.NoInlining)] private void ReserveGraphMemorySlow(long bytes, long remaining) { diff --git a/swift/Sources/Fory/CollectionSerializers.swift b/swift/Sources/Fory/CollectionSerializers.swift index 137b9fb37b..73afdfec31 100644 --- a/swift/Sources/Fory/CollectionSerializers.swift +++ b/swift/Sources/Fory/CollectionSerializers.swift @@ -42,18 +42,28 @@ private func storedElementBytes(_ type: Element.Type) -> In } @inline(__always) -private func reserveGraphStorage( +private func storedOwnerBytes(_ type: T.Type) -> Int { + max(1, MemoryLayout.stride) +} + +@inline(__always) +private func reserveGraphElements( _ context: ReadContext, + ownerBytes: Int, count: Int, elementBytes: Int ) throws { - if count < 0 || elementBytes < 0 { + if ownerBytes < 0 || count < 0 || elementBytes < 0 { throw ForyError.invalidData("graph memory estimate overflows") } - let (bytes, overflow) = count.multipliedReportingOverflow(by: elementBytes) + let (storageBytes, overflow) = count.multipliedReportingOverflow(by: elementBytes) if overflow { throw ForyError.invalidData("graph memory estimate overflows") } + let (bytes, addOverflow) = ownerBytes.addingReportingOverflow(storageBytes) + if addOverflow { + throw ForyError.invalidData("graph memory estimate overflows") + } try context.reserveGraphMemory(bytes) } @@ -61,9 +71,11 @@ private func reserveGraphStorage( private func reserveGraphArrayMemory( _ context: ReadContext, _ type: Element.Type, + ownerBytes: Int, count: Int ) throws { - try reserveGraphStorage(context, count: count, elementBytes: storedElementBytes(type)) + try reserveGraphElements( + context, ownerBytes: ownerBytes, count: count, elementBytes: storedElementBytes(type)) } @inline(__always) @@ -71,6 +83,7 @@ private func reserveGraphMapMemory( _ context: ReadContext, key: Key.Type, value: Value.Type, + ownerBytes: Int, count: Int ) throws { let keyBytes = storedElementBytes(key) @@ -79,24 +92,7 @@ private func reserveGraphMapMemory( if overflow { throw ForyError.invalidData("graph memory estimate overflows") } - try reserveGraphStorage(context, count: count, elementBytes: elementBytes) -} - -private func primitiveArrayTypeID(for _: Element.Type) -> TypeId? { - if Element.self == UInt8.self { return .uint8Array } - if Element.self == Bool.self { return .boolArray } - if Element.self == Int8.self { return .int8Array } - if Element.self == Int16.self { return .int16Array } - if Element.self == Int32.self { return .int32Array } - if Element.self == Int64.self { return .int64Array } - if Element.self == UInt16.self { return .uint16Array } - if Element.self == UInt32.self { return .uint32Array } - if Element.self == UInt64.self { return .uint64Array } - if Element.self == Float16.self { return .float16Array } - if Element.self == BFloat16.self { return .bfloat16Array } - if Element.self == Float.self { return .float32Array } - if Element.self == Double.self { return .float64Array } - return nil + try reserveGraphElements(context, ownerBytes: ownerBytes, count: count, elementBytes: elementBytes) } private let hostIsLittleEndian = Int(littleEndian: 1) == 1 @@ -292,7 +288,8 @@ private func preparePrimitiveArray( ) throws { try context.ensureCollectionLength(count, label: label) if reserveGraphStorage { - try reserveGraphArrayMemory(context, type, count: count) + try reserveGraphArrayMemory( + context, type, ownerBytes: storedOwnerBytes([Element].self), count: count) } } @@ -633,9 +630,11 @@ extension Array: Serializer where Element: Serializer { let buffer = context.buffer let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") + let ownerBytes = reserveGraphStorage ? storedOwnerBytes([Element].self) : 0 if length == 0 { if reserveGraphStorage { - try reserveGraphArrayMemory(context, Element.self, count: length) + try reserveGraphArrayMemory( + context, Element.self, ownerBytes: ownerBytes, count: length) } return [] } @@ -647,7 +646,8 @@ extension Array: Serializer where Element: Serializer { let sameType = (header & CollectionHeader.sameType) != 0 if !sameType { if reserveGraphStorage { - try reserveGraphArrayMemory(context, Element.self, count: length) + try reserveGraphArrayMemory( + context, Element.self, ownerBytes: ownerBytes, count: length) } try context.ensureRemainingBytes(length, label: "array") if trackRef { @@ -688,7 +688,7 @@ extension Array: Serializer where Element: Serializer { let elementTypeInfo = declared ? nil : try Element.foryReadTypeInfo(context) if reserveGraphStorage { - try reserveGraphArrayMemory(context, Element.self, count: length) + try reserveGraphArrayMemory(context, Element.self, ownerBytes: ownerBytes, count: length) } try context.ensureRemainingBytes(length, label: "array") return try context.withTypeInfo(elementTypeInfo, for: Element.self) { @@ -749,7 +749,8 @@ extension Set: Serializer where Element: Serializer & Hashable { public static func foryReadData(_ context: ReadContext) throws -> Set { let values = try [Element].readData(context, reserveGraphStorage: false) - try reserveGraphArrayMemory(context, Element.self, count: values.count) + try reserveGraphArrayMemory( + context, Element.self, ownerBytes: storedOwnerBytes(Set.self), count: values.count) return Set(values) } } @@ -980,12 +981,15 @@ extension Dictionary: Serializer where Key: Serializer & Hashable, Value: Serial public static func foryReadData(_ context: ReadContext) throws -> [Key: Value] { let totalLength = Int(try context.buffer.readVarUInt32()) try context.ensureCollectionLength(totalLength, label: "map") + let ownerBytes = storedOwnerBytes(Dictionary.self) if totalLength == 0 { - try reserveGraphMapMemory(context, key: Key.self, value: Value.self, count: totalLength) + try reserveGraphMapMemory( + context, key: Key.self, value: Value.self, ownerBytes: ownerBytes, count: totalLength) return [:] } - try reserveGraphMapMemory(context, key: Key.self, value: Value.self, count: totalLength) + try reserveGraphMapMemory( + context, key: Key.self, value: Value.self, ownerBytes: ownerBytes, count: totalLength) try context.ensureRemainingBytes(totalLength, label: "map") var map: [Key: Value] = [:] map.reserveCapacity(totalLength) diff --git a/swift/Sources/Fory/FieldCodecs.swift b/swift/Sources/Fory/FieldCodecs.swift index 6aa375fec7..5285cfe1dc 100644 --- a/swift/Sources/Fory/FieldCodecs.swift +++ b/swift/Sources/Fory/FieldCodecs.swift @@ -29,19 +29,29 @@ private func serializerElementBytes(_ type: Element.Type) - type.isRefType ? fieldReferenceBytes : max(1, MemoryLayout.stride) } +@inline(__always) +private func fieldOwnerBytes(_ type: T.Type) -> Int { + max(1, MemoryLayout.stride) +} + @inline(__always) private func reserveFieldStorage( _ context: ReadContext, + ownerBytes: Int, count: Int, elementBytes: Int ) throws { - if count < 0 || elementBytes < 0 { + if ownerBytes < 0 || count < 0 || elementBytes < 0 { throw ForyError.invalidData("graph memory estimate overflows") } - let (bytes, overflow) = count.multipliedReportingOverflow(by: elementBytes) + let (storageBytes, overflow) = count.multipliedReportingOverflow(by: elementBytes) if overflow { throw ForyError.invalidData("graph memory estimate overflows") } + let (bytes, addOverflow) = ownerBytes.addingReportingOverflow(storageBytes) + if addOverflow { + throw ForyError.invalidData("graph memory estimate overflows") + } try context.reserveGraphMemory(bytes) } @@ -49,18 +59,22 @@ private func reserveFieldStorage( private func reserveFieldArrayStorage( _ context: ReadContext, _ codec: ElementCodec.Type, + ownerBytes: Int, count: Int ) throws { - try reserveFieldStorage(context, count: count, elementBytes: fieldElementBytes(codec)) + try reserveFieldStorage( + context, ownerBytes: ownerBytes, count: count, elementBytes: fieldElementBytes(codec)) } @inline(__always) private func reserveSerializerArrayMemory( _ context: ReadContext, _ type: Element.Type, + ownerBytes: Int, count: Int ) throws { - try reserveFieldStorage(context, count: count, elementBytes: serializerElementBytes(type)) + try reserveFieldStorage( + context, ownerBytes: ownerBytes, count: count, elementBytes: serializerElementBytes(type)) } @inline(__always) @@ -68,6 +82,7 @@ private func reserveFieldMapStorage: FieldCodec { } public static func readPayload(_ context: ReadContext) throws -> Value { - return try readCollectionPayload(context, elementCodec: ElementCodec.self) + return try readCollectionPayload( + context, + elementCodec: ElementCodec.self, + ownerBytes: fieldOwnerBytes([ElementCodec.Value].self) + ) } public static func readCompatibleField( @@ -910,8 +929,17 @@ public enum SetFieldCodec: FieldCodec where ElementCod } public static func readPayload(_ context: ReadContext) throws -> Value { - let values = try readCollectionPayload(context, elementCodec: ElementCodec.self) - try reserveFieldArrayStorage(context, ElementCodec.self, count: values.count) + let values = try readCollectionPayload( + context, + elementCodec: ElementCodec.self, + ownerBytes: 0 + ) + try reserveFieldArrayStorage( + context, + ElementCodec.self, + ownerBytes: fieldOwnerBytes(Set.self), + count: values.count + ) return Set(values) } } @@ -1030,14 +1058,17 @@ where KeyCodec.Value: Hashable { public static func readPayload(_ context: ReadContext) throws -> Value { let totalLength = Int(try context.buffer.readVarUInt32()) try context.ensureCollectionLength(totalLength, label: "map") + let ownerBytes = fieldOwnerBytes(Dictionary.self) if totalLength == 0 { try reserveFieldMapStorage( - context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) + context, key: KeyCodec.self, value: ValueCodec.self, ownerBytes: ownerBytes, + count: totalLength) return [:] } try reserveFieldMapStorage( - context, key: KeyCodec.self, value: ValueCodec.self, count: totalLength) + context, key: KeyCodec.self, value: ValueCodec.self, ownerBytes: ownerBytes, + count: totalLength) try context.ensureRemainingBytes(totalLength, label: "map") var map: Value = [:] map.reserveCapacity(totalLength) @@ -1419,7 +1450,8 @@ private func readIntArrayPayload( { let count = try readPackedArrayElementCount(context, width: 8, label: "int64_array") if reserveGraphStorage { - try reserveSerializerArrayMemory(context, Int.self, count: count) + try reserveSerializerArrayMemory( + context, Int.self, ownerBytes: fieldOwnerBytes([Int].self), count: count) } var values: [Int] = [] values.reserveCapacity(count) @@ -1436,7 +1468,8 @@ private func readUIntArrayPayload( { let count = try readPackedArrayElementCount(context, width: 8, label: "uint64_array") if reserveGraphStorage { - try reserveSerializerArrayMemory(context, UInt.self, count: count) + try reserveSerializerArrayMemory( + context, UInt.self, ownerBytes: fieldOwnerBytes([UInt].self), count: count) } var values: [UInt] = [] values.reserveCapacity(count) @@ -1747,13 +1780,15 @@ private func writeCollectionPayload( private func readCollectionPayload( _ context: ReadContext, - elementCodec _: ElementCodec.Type + elementCodec _: ElementCodec.Type, + ownerBytes: Int ) throws -> [ElementCodec.Value] { let buffer = context.buffer let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { - try reserveFieldArrayStorage(context, ElementCodec.self, count: length) + try reserveFieldArrayStorage( + context, ElementCodec.self, ownerBytes: ownerBytes, count: length) return [] } @@ -1768,7 +1803,7 @@ private func readCollectionPayload( let sameType = (header & CollectionHeader.sameType) != 0 var result: [ElementCodec.Value] = [] - try reserveFieldArrayStorage(context, ElementCodec.self, count: length) + try reserveFieldArrayStorage(context, ElementCodec.self, ownerBytes: ownerBytes, count: length) try context.ensureRemainingBytes(length, label: "array") result.reserveCapacity(length) @@ -1854,7 +1889,6 @@ private func readListPayloadAsArrayPayload( let length = Int(try buffer.readVarUInt32()) try context.ensureCollectionLength(length, label: "array") if length == 0 { - try reserveFieldArrayStorage(context, ElementCodec.self, count: length) return [] } @@ -1882,7 +1916,6 @@ private func readListPayloadAsArrayPayload( } try context.ensureRemainingBytes(length, label: "array") var result: [ElementCodec.Value] = [] - try reserveFieldArrayStorage(context, ElementCodec.self, count: length) result.reserveCapacity(length) return try ElementCodec.withTypeInfo(elementTypeInfo, context) { for _ in 0..( context: ReadContext ) throws -> T { - try reserveRootGraphOwner(T.self, context: context) return try T.foryRead( context, refMode: refMode, @@ -488,19 +487,6 @@ public final class Fory { ) } - @inline(__always) - private func reserveRootGraphOwner( - _: T.Type, - context: ReadContext - ) throws { - switch T.staticTypeId { - case .list, .set, .map: - try context.reserveGraphMemory(max(1, MemoryLayout.stride)) - default: - break - } - } - @inline(__always) func withReusableReadContext( data: Data, diff --git a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift index ed5c885d83..0bb40f985c 100644 --- a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift +++ b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift @@ -45,6 +45,17 @@ private struct BudgetDenseHolder: Equatable { var dense: [Int32] = [] } +@ForyStruct +private struct BudgetListDenseWriter { + var dense: [Int32] = [] +} + +@ForyStruct +private struct BudgetListDenseReader: Equatable { + @ArrayField(element: .int32()) + var dense: [Int32] = [] +} + private let defaultGraphMemoryBytes: Int64 = 128 * 1024 * 1024 private func makeBudgetFory(maxGraphMemoryBytes: Int64 = defaultGraphMemoryBytes) -> Fory { @@ -60,6 +71,15 @@ private func makeBudgetFory(maxGraphMemoryBytes: Int64 = defaultGraphMemoryBytes return fory } +private func makeCompatibleBudgetFory(maxGraphMemoryBytes: Int64 = defaultGraphMemoryBytes) -> Fory { + Fory( + config: .init( + trackRef: false, + compatible: true, + maxGraphMemoryBytes: maxGraphMemoryBytes + )) +} + private let testReferenceBytes = 4 private let budgetNodeGraphBytes = 1 + 4 @@ -75,7 +95,7 @@ private func arrayBudget(_ type: Element.Type, count: Int) count * elementBytes(type) } -private func rootArrayBudget( +private func listBudget( _ type: Element.Type, count: Int, elementOwnerBytes: Int = 0 @@ -83,6 +103,14 @@ private func rootArrayBudget( ownerBytes([Element].self) + arrayBudget(type, count: count) + count * elementOwnerBytes } +private func rootArrayBudget( + _ type: Element.Type, + count: Int, + elementOwnerBytes: Int = 0 +) -> Int { + listBudget(type, count: count, elementOwnerBytes: elementOwnerBytes) +} + private func mapBudget( key: Key.Type, value: Value.Type, @@ -91,7 +119,7 @@ private func mapBudget( count * (elementBytes(key) + elementBytes(value)) } -private func rootMapBudget( +private func dictionaryBudget( key: Key.Type, value: Value.Type, count: Int @@ -99,6 +127,14 @@ private func rootMapBudget( ownerBytes(Dictionary.self) + mapBudget(key: key, value: value, count: count) } +private func rootMapBudget( + key: Key.Type, + value: Value.Type, + count: Int +) -> Int { + dictionaryBudget(key: key, value: value, count: count) +} + private func expectInvalidData(_ body: () throws -> Void) { do { try body() @@ -167,7 +203,7 @@ func siblingContainersShareOneBudget() throws { right: (16..<32).map { BudgetNode(id: Int32($0)) } ) let bytes = try makeBudgetFory().serialize(value) - let oneList = arrayBudget(BudgetNode.self, count: 16) + 16 * budgetNodeGraphBytes + let oneList = listBudget(BudgetNode.self, count: 16, elementOwnerBytes: budgetNodeGraphBytes) let required = ownerBytes(BudgetSiblings.self) + oneList * 2 expectInvalidData { @@ -180,6 +216,22 @@ func siblingContainersShareOneBudget() throws { #expect(decoded.right.count == 16) } +@Test +func nestedEmptyArraysChargeOwner() throws { + let count = 3 + let value = Array(repeating: [String](), count: count) + let bytes = try makeBudgetFory().serialize(value) + let required = listBudget([String].self, count: count) + count * ownerBytes([String].self) + + expectInvalidData { + let _: [[String]] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: [[String]] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded == value) +} + @Test func mapBudgetIsCharged() throws { let value: [String: Int32] = ["a": 1, "b": 2, "c": 3] @@ -195,6 +247,21 @@ func mapBudgetIsCharged() throws { #expect(decoded == value) } +@Test +func emptyTypedMapOwnerIsCharged() throws { + let value: [String: Int32] = [:] + let bytes = try makeBudgetFory().serialize(value) + let required = rootMapBudget(key: String.self, value: Int32.self, count: value.count) + + expectInvalidData { + let _: [String: Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required - 1)) + .deserialize(bytes) + } + let decoded: [String: Int32] = try makeBudgetFory(maxGraphMemoryBytes: Int64(required)) + .deserialize(bytes) + #expect(decoded == value) +} + @Test func referenceAndInlineValueArraysAreCharged() throws { let nodes = (0..<4).map { BudgetNode(id: Int32($0)) } @@ -261,7 +328,8 @@ func dynamicAnyEmptyMapOwnerSelf() throws { let value = [:] as [AnyHashable: Any] let bytes = try makeBudgetFory().serialize(value as Any) let required = - ownerBytes(Dictionary.self) + dictionaryBudget(key: AnyHashable.self, value: SerializableAny.self, count: value.count) + + ownerBytes(Dictionary.self) + ownerBytes(Dictionary.self) expectInvalidData { @@ -277,7 +345,7 @@ func dynamicAnyEmptyMapOwnerSelf() throws { func publicAnyArrayBudget() throws { let value: [Any] = [Int32(1), Int32(2), Int32(3)] let bytes = try makeBudgetFory().serialize(value) - let wrappedBudget = arrayBudget(SerializableAny.self, count: value.count) + let wrappedBudget = listBudget(SerializableAny.self, count: value.count) let finalBudget = ownerBytes([Any].self) + value.count * testReferenceBytes expectInvalidData { @@ -293,7 +361,7 @@ func publicAnyArrayBudget() throws { func publicAnyMapBudget() throws { let stringMap: [String: Any] = ["a": Int32(1), "b": Int32(2), "c": Int32(3)] let stringBytes = try makeBudgetFory().serialize(stringMap) - let stringWrapped = mapBudget( + let stringWrapped = dictionaryBudget( key: String.self, value: SerializableAny.self, count: stringMap.count @@ -310,7 +378,7 @@ func publicAnyMapBudget() throws { let intMap: [Int32: Any] = [1: Int32(10), 2: Int32(20), 3: Int32(30)] let intBytes = try makeBudgetFory().serialize(intMap) - let intWrapped = mapBudget( + let intWrapped = dictionaryBudget( key: Int32.self, value: SerializableAny.self, count: intMap.count @@ -330,7 +398,7 @@ func publicAnyMapBudget() throws { AnyHashable(true): Int32(3) ] let anyHashableBytes = try makeBudgetFory().serialize(anyHashableMap) - let anyHashableWrapped = mapBudget( + let anyHashableWrapped = dictionaryBudget( key: AnyHashable.self, value: SerializableAny.self, count: anyHashableMap.count @@ -354,7 +422,7 @@ func dynamicAnyArrayBudget() throws { let value: Any = list let bytes = try makeBudgetFory().serialize(value) let count = list.count - let wrappedBudget = arrayBudget(SerializableAny.self, count: count) + let wrappedBudget = listBudget(SerializableAny.self, count: count) let finalBudget = ownerBytes([Any].self) + count * testReferenceBytes expectInvalidData { @@ -366,6 +434,27 @@ func dynamicAnyArrayBudget() throws { #expect((decoded as? [Any])?.count == count) } +@Test +func compatibleListToDenseArraySkipsLeafOwner() throws { + let writer = makeCompatibleBudgetFory() + writer.register(BudgetListDenseWriter.self, id: 9804) + let reader = makeCompatibleBudgetFory( + maxGraphMemoryBytes: Int64(ownerBytes(BudgetListDenseReader.self)) + ) + reader.register(BudgetListDenseReader.self, id: 9804) + let bytes = try writer.serialize(BudgetListDenseWriter(dense: [1, 2, 3])) + + expectInvalidData { + let failingReader = makeCompatibleBudgetFory( + maxGraphMemoryBytes: Int64(ownerBytes(BudgetListDenseReader.self) - 1) + ) + failingReader.register(BudgetListDenseReader.self, id: 9804) + let _: BudgetListDenseReader = try failingReader.deserialize(bytes) + } + let decoded: BudgetListDenseReader = try reader.deserialize(bytes) + #expect(decoded.dense == [1, 2, 3]) +} + @Test func byteAvailabilityCheckStillRejectsLargeLength() throws { let buffer = ByteBuffer() From c1857ec71f2e706baeb1c8c54660d4d53d62ee6b Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 00:44:23 +0800 Subject: [PATCH 31/54] refactor: rename graph memory limit error helper --- cpp/fory/serialization/context.cc | 2 +- cpp/fory/serialization/context.h | 2 +- go/fory/array.go | 2 +- go/fory/map.go | 6 +++--- go/fory/map_primitive.go | 4 ++-- go/fory/reader.go | 8 ++++---- go/fory/set.go | 8 ++++---- go/fory/slice.go | 4 ++-- go/fory/slice_dyn.go | 4 ++-- go/fory/slice_primitive.go | 4 ++-- go/fory/slice_primitive_list.go | 8 ++++---- 11 files changed, 26 insertions(+), 26 deletions(-) diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index 0cec51704a..3888882b5b 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -743,7 +743,7 @@ const TypeInfo *ReadContext::read_any_type_info(Error &error) { return result.value(); } -bool ReadContext::set_graph_memory_error(const std::string &message) { +bool ReadContext::set_graph_memory_limit_error(const std::string &message) { set_error(Error::invalid_data(message)); return false; } diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index cc09617458..36469d6f2a 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -693,7 +693,7 @@ class ReadContext { FORY_NOINLINE Result check_remote_type_meta_limit(const TypeMeta &type_meta); void record_remote_type_meta(const std::string &type_key); - FORY_NOINLINE bool set_graph_memory_error(const std::string &message); + FORY_NOINLINE bool set_graph_memory_limit_error(const std::string &message); FORY_NOINLINE bool set_graph_memory_exceeded(size_t bytes, size_t remaining); // Error state - accumulated during deserialization, checked at the end diff --git a/go/fory/array.go b/go/fory/array.go index 4a6b3d2588..7c05463320 100644 --- a/go/fory/array.go +++ b/go/fory/array.go @@ -320,7 +320,7 @@ func (s arrayDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { sliceType := reflect.SliceOf(value.Type().Elem()) elemBytes := int64(value.Type().Elem().Size()) if int64(value.Len()) > maxGraphCount(elemBytes) { - ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", value.Len(), elemBytes) + ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", value.Len(), elemBytes) return } if !ctx.ReserveGraphMemory(int64(value.Len()) * elemBytes) { diff --git a/go/fory/map.go b/go/fory/map.go index a5f2acaccc..e2cab2ad71 100644 --- a/go/fory/map.go +++ b/go/fory/map.go @@ -307,15 +307,15 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { valueBytes := int64(mapType.Elem().Size()) elemBytes := keyBytes + valueBytes if elemBytes < keyBytes { - ctx.setGraphMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + ctx.setGraphMemoryLimitError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) return } if size < 0 { - ctx.setGraphMemoryError("negative graph element count: %d", size) + ctx.setGraphMemoryLimitError("negative graph element count: %d", size) return } if int64(size) > maxGraphCount(elemBytes) { - ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", size, elemBytes) + ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", size, elemBytes) return } if !ctx.ReserveGraphMemory(int64(size) * elemBytes) { diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go index 0a658b2755..2ba6b55b4d 100644 --- a/go/fory/map_primitive.go +++ b/go/fory/map_primitive.go @@ -95,11 +95,11 @@ func readTypedMapSize(ctx *ReadContext, elemBytes int64, maxLength int64) (int, return 0, false } if size < 0 { - ctx.setGraphMemoryError("negative graph element count: %d", size) + ctx.setGraphMemoryLimitError("negative graph element count: %d", size) return 0, false } if int64(size) > maxLength { - ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", size, elemBytes) + ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", size, elemBytes) return 0, false } if !ctx.ReserveGraphMemory(int64(size) * elemBytes) { diff --git a/go/fory/reader.go b/go/fory/reader.go index de46143fc2..c13a39b291 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -185,12 +185,12 @@ func (c *ReadContext) ReserveGraphMemory(bytes int64) bool { //go:noinline func (c *ReadContext) rejectGraphMemoryBytes(bytes int64) bool { - c.setGraphMemoryError("estimated graph memory must be non-negative, got %d bytes", bytes) + c.setGraphMemoryLimitError("estimated graph memory must be non-negative, got %d bytes", bytes) return false } //go:noinline -func (c *ReadContext) setGraphMemoryError(format string, args ...any) { +func (c *ReadContext) setGraphMemoryLimitError(format string, args ...any) { c.SetError(DeserializationErrorf(format, args...)) } @@ -674,11 +674,11 @@ func (c *ReadContext) readStringSliceData() []string { return nil } if length < 0 { - c.setGraphMemoryError("negative graph element count: %d", length) + c.setGraphMemoryLimitError("negative graph element count: %d", length) return nil } if int64(length) > stringMaxLength { - c.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, stringElementBytes) + c.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, stringElementBytes) return nil } if !c.ReserveGraphMemory(int64(length) * stringElementBytes) { diff --git a/go/fory/set.go b/go/fory/set.go index a014d91f5d..0bd22f68fc 100644 --- a/go/fory/set.go +++ b/go/fory/set.go @@ -322,7 +322,7 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { valueBytes := int64(type_.Elem().Size()) elemBytes := keyBytes + valueBytes if elemBytes < keyBytes { - ctx.setGraphMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + ctx.setGraphMemoryLimitError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) return } if !ctx.ReserveGraphMemory(0) { @@ -370,15 +370,15 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { valueBytes := int64(type_.Elem().Size()) elemBytes := keyBytes + valueBytes if elemBytes < keyBytes { - ctx.setGraphMemoryError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + ctx.setGraphMemoryLimitError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) return } if length < 0 { - ctx.setGraphMemoryError("negative graph element count: %d", length) + ctx.setGraphMemoryLimitError("negative graph element count: %d", length) return } if int64(length) > maxGraphCount(elemBytes) { - ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) + ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) return } if !ctx.ReserveGraphMemory(int64(length) * elemBytes) { diff --git a/go/fory/slice.go b/go/fory/slice.go index 67dbc708e2..95cf93e1e7 100644 --- a/go/fory/slice.go +++ b/go/fory/slice.go @@ -321,11 +321,11 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } if !isArrayType { if length < 0 { - ctx.setGraphMemoryError("negative graph element count: %d", length) + ctx.setGraphMemoryLimitError("negative graph element count: %d", length) return } if int64(length) > s.maxLength { - ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes) + ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes) return } if !ctx.ReserveGraphMemory(int64(length) * s.elemBytes) { diff --git a/go/fory/slice_dyn.go b/go/fory/slice_dyn.go index 35594be458..80df03f0e8 100644 --- a/go/fory/slice_dyn.go +++ b/go/fory/slice_dyn.go @@ -289,11 +289,11 @@ func (s *sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, exp } if !allocatedByCaller { if length < 0 { - ctx.setGraphMemoryError("negative graph element count: %d", length) + ctx.setGraphMemoryLimitError("negative graph element count: %d", length) return } if int64(length) > s.maxLength { - ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes) + ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes) return } if !ctx.ReserveGraphMemory(int64(length) * s.elemBytes) { diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index ec5c99d723..c7089776d5 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -653,11 +653,11 @@ func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } ptr := (*[]string)(value.Addr().UnsafePointer()) if length < 0 { - ctx.setGraphMemoryError("negative graph element count: %d", length) + ctx.setGraphMemoryLimitError("negative graph element count: %d", length) return } if int64(length) > stringMaxLength { - ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, stringElementBytes) + ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, stringElementBytes) return } if !ctx.ReserveGraphMemory(int64(length) * stringElementBytes) { diff --git a/go/fory/slice_primitive_list.go b/go/fory/slice_primitive_list.go index 3f920ad17e..c327815aa6 100644 --- a/go/fory/slice_primitive_list.go +++ b/go/fory/slice_primitive_list.go @@ -180,11 +180,11 @@ func (s primitiveListSerializer) ReadData(ctx *ReadContext, value reflect.Value) return } if length < 0 { - ctx.setGraphMemoryError("negative graph element count: %d", length) + ctx.setGraphMemoryLimitError("negative graph element count: %d", length) return } if int64(length) > s.maxLength { - ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes) + ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes) return } if !ctx.ReserveGraphMemory(int64(length) * s.elemBytes) { @@ -293,11 +293,11 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val } if value.Kind() == reflect.Slice { if length < 0 { - ctx.setGraphMemoryError("negative graph element count: %d", length) + ctx.setGraphMemoryLimitError("negative graph element count: %d", length) return } if int64(length) > s.listReader.maxLength { - ctx.setGraphMemoryError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.listReader.elemBytes) + ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.listReader.elemBytes) return } if !ctx.ReserveGraphMemory(int64(length) * s.listReader.elemBytes) { From e0d568a754b8de6002d7258fb0de0374d11aca0e Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 01:21:24 +0800 Subject: [PATCH 32/54] refactor: clean graph memory collection readers --- .agents/languages/csharp.md | 3 + AGENTS.md | 3 + .../src/Fory.Generator/ForyModelGenerator.cs | 3 - csharp/src/Fory/CollectionSerializers.cs | 961 +++++------------- csharp/src/Fory/ReadContext.cs | 68 +- csharp/src/Fory/UnionSerializer.cs | 2 - .../serializer/collection_serializers.dart | 59 +- .../fory/test/collection_serializer_test.dart | 13 + 8 files changed, 323 insertions(+), 789 deletions(-) diff --git a/.agents/languages/csharp.md b/.agents/languages/csharp.md index 882f912930..a3685e9560 100644 --- a/.agents/languages/csharp.md +++ b/.agents/languages/csharp.md @@ -16,6 +16,9 @@ Load this file when changing `csharp/` or C# xlang behavior. memory-backed today, but the graph budget uses the same fixed default for every root shape. `ReadContext` may expose only raw byte reservation; concrete serializers and generated serializers must compute list, array, map, struct, and object byte formulas before calling it. +- `ReadContext` must not expose ref-publication pause/resume APIs. Keep nested reference + publication stack-aware inside read-state internals so immutable, generated, and conversion + serializers do not publish temporary owners. - For C# graph budget formulas, distinguish inline value storage from reference storage: use cheap value-type size for `List`/`T[]` value paths and the 4-byte reference fallback for reference paths. Class/reference serializers reserve their own shallow self cost plus field storage when diff --git a/AGENTS.md b/AGENTS.md index 81ac87c2dd..18a29a4041 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -34,6 +34,9 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Do not make assumptions about runtime behavior, ownership, registration, metadata construction, protocol semantics, or test coverage. Read the current code, owning docs/specs, and relevant tests before making a design judgment or implementation decision. If the evidence is incomplete, inspect more or state the uncertainty explicitly instead of filling gaps from memory or analogy with another runtime. - For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve backing/output capacity from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. Root graph memory reservation is accounting only and may happen before that byte check, but it must not replace the byte check. - Root deserialization graph memory budgets estimate lower-bound shallow memory for materialized graph owners, not exact heap accounting, input byte accounting, or raw element counts. `maxGraphMemoryBytes` defaults to fixed `128 MiB`; positive values override the default; explicit non-positive values intentionally disable graph-memory enforcement and must be documented as deserialization DoS risk for compact inputs that materialize large graphs. Do not derive this budget from root input size, and do not split known-length and stream root behavior. Read context/read state owns only raw byte reservation with `reserveGraphMemory(bytes)`; it must not expose counted arithmetic helpers or collection, map, array, struct, or object semantic reservation APIs. Concrete serializers and generated serializers own counted formulas and overflow checks for their allocation path. Reserve self storage exactly once at the owner that stores or allocates the value: reference/object runtimes reserve parent owner self cost plus reference storage and every referenced heap owner reserves its own shallow self cost when materialized; inline/value runtimes reserve value storage in the root, struct/product, collection, map, set, array, smart-pointer, or box owner that actually owns that storage, and nested value serializers do not unconditionally reserve their own `sizeof(T)` again. Collection, map, set, and reference-array owners reserve nonzero shallow self cost when independently materialized, plus backing/reference/inline storage. Struct, record, POJO, tuple/product, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Reference fields use 4 bytes when reference size is not cheap or reliable to query; primitive/value fields use their storage width. Parents do not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive array, and primitive dense-array leaf owners. Do not guess allocator, bucket-table, node, debug, or map-entry overhead unless it is a documented lower-bound owner allocation. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. +- Read context/read state must not expose ref-publication pause/resume APIs. Nested reference + publication should be stack-aware inside the read state so immutable, generated, and conversion + owners can materialize final values without publishing temporary owners. - For remote TypeDef/TypeMeta reads, the checked metadata cache is the only owner of remote "already validated" state. Cache hit means the header was previously parsed, body/hash-validated, policy-checked, and published by that cache, so the hot path must skip the body and use cached metadata without extra validation, hashing, limit checks, exact-local checks, allocation, or policy work. A known expected local TypeDef/TypeMeta header/hash match is a local-schema hit, not a remote cache miss: it may skip the body and use the local TypeInfo/TypeMeta without schema-version counting or cache publish. Cache miss is the only path that parses and validates non-local metadata, enforces limits, performs exact-local byte comparison when needed, and publishes remote metadata to the cache. Do not add nullable accepted-header fields, sentinel headers, per-TypeInfo markers, pending metadata state, parallel header-low/header-high slots, or parallel acceptance state for this decision. If a runtime needs a metadata hit hint, cache the concrete checked metadata owner object, such as the TypeInfo, TypeDef, or TypeMeta used by that runtime, and compare its validated header identity directly. - When a user corrects a non-obvious invariant, encode it in the nearest source comment before continuing, and also update `AGENTS.md`, `.agents/**`, docs, or specs when the rule is reusable beyond one file. Do not rely only on chat history, task notes, commit messages, or benchmark logs for corrections that protect security, protocol behavior, ownership, naming, or hot-path performance. - Reject semantic hacks. Do not bypass broken semantics by deleting cases, simplifying callers, adding coercion hooks, or using workaround fallbacks; fix the underlying bug and prove it with focused tests. diff --git a/csharp/src/Fory.Generator/ForyModelGenerator.cs b/csharp/src/Fory.Generator/ForyModelGenerator.cs index 4bbb832afe..8d7a6c2286 100644 --- a/csharp/src/Fory.Generator/ForyModelGenerator.cs +++ b/csharp/src/Fory.Generator/ForyModelGenerator.cs @@ -744,7 +744,6 @@ private static void EmitUnionSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(); sb.AppendLine($" public override {model.TypeName} ReadData(global::Apache.Fory.ReadContext context)"); sb.AppendLine(" {"); - sb.AppendLine(" uint __foryPausedRef = context.PauseRefPublication();"); sb.AppendLine(" uint rawCaseId = context.Reader.ReadVarUInt32();"); sb.AppendLine(" if (rawCaseId > int.MaxValue)"); sb.AppendLine(" {"); @@ -762,7 +761,6 @@ private static void EmitUnionSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" {"); EmitReadUnionCasePayload(sb, unionCase, valueVar, 4); sb.AppendLine($" {model.TypeName} __foryUnion = new {unionCase.TypeName}({valueVar});"); - sb.AppendLine(" context.ResumeRefPublication(__foryPausedRef);"); sb.AppendLine(" return __foryUnion;"); sb.AppendLine(" }"); } @@ -776,7 +774,6 @@ private static void EmitUnionSerializer(StringBuilder sb, TypeModel model) else { sb.AppendLine($" {model.TypeName} __foryUnion = new {unknownCase.TypeName}(global::Apache.Fory.UnknownCaseSerializer.ReadPayload(context, caseId));"); - sb.AppendLine(" context.ResumeRefPublication(__foryPausedRef);"); sb.AppendLine(" return __foryUnion;"); } diff --git a/csharp/src/Fory/CollectionSerializers.cs b/csharp/src/Fory/CollectionSerializers.cs index b0a08b96fc..94d97fbe75 100644 --- a/csharp/src/Fory/CollectionSerializers.cs +++ b/csharp/src/Fory/CollectionSerializers.cs @@ -213,57 +213,14 @@ private static class ElementStorage internal static readonly int Bytes = typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; } - internal readonly struct CollectionFrame - { - internal CollectionFrame(bool trackRef, bool hasNull, bool declared, bool sameType) - { - TrackRef = trackRef; - HasNull = hasNull; - Declared = declared; - SameType = sameType; - } - - internal bool TrackRef { get; } - internal bool HasNull { get; } - internal bool Declared { get; } - internal bool SameType { get; } - } - - internal static int ReadLengthAndReserve(ReadContext context) - { - int length = checked((int)context.Reader.ReadVarUInt32()); - ReserveElementStorage(context, length); - return length; - } - - internal static CollectionFrame ReadFrame(ReadContext context, int length) - { - byte header = context.Reader.ReadUInt8(); - context.Reader.CheckBound(length); - return new CollectionFrame( - (header & CollectionBits.TrackingRef) != 0, - (header & CollectionBits.HasNull) != 0, - (header & CollectionBits.DeclaredElementType) != 0, - (header & CollectionBits.SameType) != 0); - } - - public static List ReadCollectionData( - Serializer elementSerializer, - ReadContext context, - bool reserveOwner = true, - bool storeOwnerRef = true) + public static List ReadCollectionData(Serializer elementSerializer, ReadContext context) { int length = checked((int)context.Reader.ReadVarUInt32()); - bool storeRef = storeOwnerRef && context.ShouldStoreRef; if (length == 0) { - if (reserveOwner) - { - ReserveElementStorage(context, length); - } - + ReserveElementStorage(context, length); List empty = []; - if (storeRef) + if (context.ShouldStoreRef) { context.StoreRef(empty); } @@ -280,14 +237,10 @@ public static List ReadCollectionData( bool hasNull = (header & CollectionBits.HasNull) != 0; bool declared = (header & CollectionBits.DeclaredElementType) != 0; bool sameType = (header & CollectionBits.SameType) != 0; - if (reserveOwner) - { - ReserveElementStorage(context, length); - } - + ReserveElementStorage(context, length); context.Reader.CheckBound(length); List values = new(length); - if (storeRef) + if (context.ShouldStoreRef) { context.StoreRef(values); } @@ -385,6 +338,242 @@ public static List ReadCollectionData( return values; } + private interface ICollectionBuilder + { + bool StoresOwnerRef { get; } + + TCollection Create(int length); + + void Add(TCollection values, T value); + } + + private readonly struct HashSetBuilder : ICollectionBuilder, T> where T : notnull + { + public bool StoresOwnerRef => true; + + public HashSet Create(int length) => new(length); + + public void Add(HashSet values, T value) => values.Add(value); + } + + private readonly struct SortedSetBuilder : ICollectionBuilder, T> where T : notnull + { + public bool StoresOwnerRef => true; + + public SortedSet Create(int length) => new(); + + public void Add(SortedSet values, T value) => values.Add(value); + } + + private readonly struct ImmutableHashSetBuilder : ICollectionBuilder.Builder, T> + { + public bool StoresOwnerRef => false; + + public ImmutableHashSet.Builder Create(int length) => ImmutableHashSet.CreateBuilder(); + + public void Add(ImmutableHashSet.Builder values, T value) => values.Add(value); + } + + private readonly struct LinkedListBuilder : ICollectionBuilder, T> + { + public bool StoresOwnerRef => true; + + public LinkedList Create(int length) => new(); + + public void Add(LinkedList values, T value) => values.AddLast(value); + } + + private readonly struct QueueBuilder : ICollectionBuilder, T> + { + public bool StoresOwnerRef => true; + + public Queue Create(int length) => new(length); + + public void Add(Queue values, T value) => values.Enqueue(value); + } + + private readonly struct StackBuilder : ICollectionBuilder, T> + { + public bool StoresOwnerRef => true; + + public Stack Create(int length) => new(length); + + public void Add(Stack values, T value) => values.Push(value); + } + + private static TCollection ReadCollectionOwner( + Serializer elementSerializer, + ReadContext context, + TBuilder builder) + where TBuilder : struct, ICollectionBuilder + { + int length = checked((int)context.Reader.ReadVarUInt32()); + ReserveElementStorage(context, length); + TCollection values = builder.Create(length); + if (builder.StoresOwnerRef && context.ShouldStoreRef) + { + context.StoreRef(values); + } + + if (length == 0) + { + return values; + } + + byte header = context.Reader.ReadUInt8(); + bool trackRef = (header & CollectionBits.TrackingRef) != 0; + bool hasNull = (header & CollectionBits.HasNull) != 0; + bool declared = (header & CollectionBits.DeclaredElementType) != 0; + bool sameType = (header & CollectionBits.SameType) != 0; + context.Reader.CheckBound(length); + if (!sameType) + { + if (trackRef) + { + for (int i = 0; i < length; i++) + { + builder.Add(values, elementSerializer.Read(context, RefMode.Tracking, true)); + } + + return values; + } + + if (hasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + builder.Add(values, (T)elementSerializer.DefaultObject!); + } + else if (refFlag == (sbyte)RefFlag.NotNullValue) + { + builder.Add(values, elementSerializer.Read(context, RefMode.None, true)); + } + else + { + throw new RefException($"invalid nullability flag {refFlag}"); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + builder.Add(values, elementSerializer.Read(context, RefMode.None, true)); + } + } + + return values; + } + + if (!declared) + { + context.TypeResolver.ReadTypeInfo(elementSerializer, context); + } + + if (trackRef) + { + for (int i = 0; i < length; i++) + { + builder.Add(values, elementSerializer.Read(context, RefMode.Tracking, false)); + } + + if (!declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } + + if (hasNull) + { + for (int i = 0; i < length; i++) + { + sbyte refFlag = context.Reader.ReadInt8(); + if (refFlag == (sbyte)RefFlag.Null) + { + builder.Add(values, (T)elementSerializer.DefaultObject!); + } + else + { + builder.Add(values, elementSerializer.ReadData(context)); + } + } + } + else + { + for (int i = 0; i < length; i++) + { + builder.Add(values, elementSerializer.ReadData(context)); + } + } + + if (!declared) + { + context.ClearReadTypeInfo(typeof(T)); + } + + return values; + } + + internal static HashSet ReadHashSetData(Serializer elementSerializer, ReadContext context) + where T : notnull + { + return ReadCollectionOwner, HashSetBuilder>( + elementSerializer, + context, + new HashSetBuilder()); + } + + internal static SortedSet ReadSortedSetData(Serializer elementSerializer, ReadContext context) + where T : notnull + { + return ReadCollectionOwner, SortedSetBuilder>( + elementSerializer, + context, + new SortedSetBuilder()); + } + + internal static ImmutableHashSet ReadImmutableHashSetData( + Serializer elementSerializer, + ReadContext context) + where T : notnull + { + ImmutableHashSet.Builder values = + ReadCollectionOwner.Builder, ImmutableHashSetBuilder>( + elementSerializer, + context, + new ImmutableHashSetBuilder()); + return values.ToImmutable(); + } + + internal static LinkedList ReadLinkedListData(Serializer elementSerializer, ReadContext context) + { + return ReadCollectionOwner, LinkedListBuilder>( + elementSerializer, + context, + new LinkedListBuilder()); + } + + internal static Queue ReadQueueData(Serializer elementSerializer, ReadContext context) + { + return ReadCollectionOwner, QueueBuilder>( + elementSerializer, + context, + new QueueBuilder()); + } + + internal static Stack ReadStackData(Serializer elementSerializer, ReadContext context) + { + return ReadCollectionOwner, StackBuilder>( + elementSerializer, + context, + new StackBuilder()); + } + public static T[] ReadArrayData(Serializer elementSerializer, ReadContext context) { int length = checked((int)context.Reader.ReadVarUInt32()); @@ -761,287 +950,39 @@ public override void WriteData(WriteContext context, in HashSet value, bool h public override HashSet ReadData(ReadContext context) { - if (context.ShouldStoreRef) - { - return ReadStoredSetData(context); - } + return CollectionCodec.ReadHashSetData(context.TypeResolver.GetSerializer(), context); + } +} - List values = CollectionCodec.ReadCollectionData( - context.TypeResolver.GetSerializer(), - context, - reserveOwner: false, - storeOwnerRef: false); - CollectionCodec.ReserveElementStorage(context, values.Count); - return [.. values]; +public sealed class SortedSetSerializer : Serializer> where T : notnull +{ + public override SortedSet DefaultValue => null!; + + public override void WriteData(WriteContext context, in SortedSet value, bool hasGenerics) + { + SortedSet safe = value ?? new SortedSet(); + CollectionCodec.WriteCollectionData(safe, context.TypeResolver.GetSerializer(), context, hasGenerics); } - private static HashSet ReadStoredSetData(ReadContext context) + public override SortedSet ReadData(ReadContext context) { - Serializer elementSerializer = context.TypeResolver.GetSerializer(); - int length = checked((int)context.Reader.ReadVarUInt32()); - CollectionCodec.ReserveElementStorage(context, length); - HashSet values = new(length); - context.StoreRef(values); - if (length == 0) - { - return values; - } + return CollectionCodec.ReadSortedSetData(context.TypeResolver.GetSerializer(), context); + } +} - byte header = context.Reader.ReadUInt8(); - bool trackRef = (header & CollectionBits.TrackingRef) != 0; - bool hasNull = (header & CollectionBits.HasNull) != 0; - bool declared = (header & CollectionBits.DeclaredElementType) != 0; - bool sameType = (header & CollectionBits.SameType) != 0; - context.Reader.CheckBound(length); - if (!sameType) - { - if (trackRef) - { - for (int i = 0; i < length; i++) - { - values.Add(elementSerializer.Read(context, RefMode.Tracking, true)); - } +public sealed class ImmutableHashSetSerializer : Serializer> where T : notnull +{ + public override ImmutableHashSet DefaultValue => null!; - return values; - } - - if (hasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values.Add((T)elementSerializer.DefaultObject!); - } - else if (refFlag == (sbyte)RefFlag.NotNullValue) - { - values.Add(elementSerializer.Read(context, RefMode.None, true)); - } - else - { - throw new RefException($"invalid nullability flag {refFlag}"); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values.Add(elementSerializer.Read(context, RefMode.None, true)); - } - } - - return values; - } - - if (!declared) - { - context.TypeResolver.ReadTypeInfo(elementSerializer, context); - } - - if (trackRef) - { - for (int i = 0; i < length; i++) - { - values.Add(elementSerializer.Read(context, RefMode.Tracking, false)); - } - - if (!declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - - return values; - } - - if (hasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values.Add((T)elementSerializer.DefaultObject!); - } - else - { - values.Add(elementSerializer.ReadData(context)); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values.Add(elementSerializer.ReadData(context)); - } - } - - if (!declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - - return values; - } -} - -public sealed class SortedSetSerializer : Serializer> where T : notnull -{ - public override SortedSet DefaultValue => null!; - - public override void WriteData(WriteContext context, in SortedSet value, bool hasGenerics) - { - SortedSet safe = value ?? new SortedSet(); - CollectionCodec.WriteCollectionData(safe, context.TypeResolver.GetSerializer(), context, hasGenerics); - } - - public override SortedSet ReadData(ReadContext context) - { - if (context.ShouldStoreRef) - { - return ReadStoredSortedSetData(context); - } - - uint refId = context.PauseRefPublication(); - List values = CollectionCodec.ReadCollectionData( - context.TypeResolver.GetSerializer(), - context, - reserveOwner: false, - storeOwnerRef: false); - context.ResumeRefPublication(refId); - CollectionCodec.ReserveElementStorage(context, values.Count); - return [.. values]; - } - - private static SortedSet ReadStoredSortedSetData(ReadContext context) - { - Serializer elementSerializer = context.TypeResolver.GetSerializer(); - int length = CollectionCodec.ReadLengthAndReserve(context); - SortedSet values = new(); - context.StoreRef(values); - if (length == 0) - { - return values; - } - - CollectionCodec.CollectionFrame frame = CollectionCodec.ReadFrame(context, length); - if (!frame.SameType) - { - if (frame.TrackRef) - { - for (int i = 0; i < length; i++) - { - values.Add(elementSerializer.Read(context, RefMode.Tracking, true)); - } - - return values; - } - - if (frame.HasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values.Add((T)elementSerializer.DefaultObject!); - } - else if (refFlag == (sbyte)RefFlag.NotNullValue) - { - values.Add(elementSerializer.Read(context, RefMode.None, true)); - } - else - { - throw new RefException($"invalid nullability flag {refFlag}"); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values.Add(elementSerializer.Read(context, RefMode.None, true)); - } - } - - return values; - } - - if (!frame.Declared) - { - context.TypeResolver.ReadTypeInfo(elementSerializer, context); - } - - if (frame.TrackRef) - { - for (int i = 0; i < length; i++) - { - values.Add(elementSerializer.Read(context, RefMode.Tracking, false)); - } - - if (!frame.Declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - - return values; - } - - if (frame.HasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values.Add((T)elementSerializer.DefaultObject!); - } - else - { - values.Add(elementSerializer.ReadData(context)); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values.Add(elementSerializer.ReadData(context)); - } - } - - if (!frame.Declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - - return values; - } -} - -public sealed class ImmutableHashSetSerializer : Serializer> where T : notnull -{ - public override ImmutableHashSet DefaultValue => null!; - - public override void WriteData(WriteContext context, in ImmutableHashSet value, bool hasGenerics) - { - ImmutableHashSet safe = value ?? ImmutableHashSet.Empty; - CollectionCodec.WriteCollectionData(safe, context.TypeResolver.GetSerializer(), context, hasGenerics); - } + public override void WriteData(WriteContext context, in ImmutableHashSet value, bool hasGenerics) + { + ImmutableHashSet safe = value ?? ImmutableHashSet.Empty; + CollectionCodec.WriteCollectionData(safe, context.TypeResolver.GetSerializer(), context, hasGenerics); + } public override ImmutableHashSet ReadData(ReadContext context) { - uint refId = context.PauseRefPublication(); - List values = CollectionCodec.ReadCollectionData( - context.TypeResolver.GetSerializer(), - context, - reserveOwner: false, - storeOwnerRef: false); - context.ResumeRefPublication(refId); - CollectionCodec.ReserveElementStorage(context, values.Count); - return ImmutableHashSet.CreateRange(values); + return CollectionCodec.ReadImmutableHashSetData(context.TypeResolver.GetSerializer(), context); } } @@ -1057,125 +998,7 @@ public override void WriteData(WriteContext context, in LinkedList value, boo public override LinkedList ReadData(ReadContext context) { - if (context.ShouldStoreRef) - { - return ReadStoredLinkedListData(context); - } - - uint refId = context.PauseRefPublication(); - List values = CollectionCodec.ReadCollectionData( - context.TypeResolver.GetSerializer(), - context, - reserveOwner: false, - storeOwnerRef: false); - context.ResumeRefPublication(refId); - CollectionCodec.ReserveElementStorage(context, values.Count); - return new LinkedList(values); - } - - private static LinkedList ReadStoredLinkedListData(ReadContext context) - { - Serializer elementSerializer = context.TypeResolver.GetSerializer(); - int length = CollectionCodec.ReadLengthAndReserve(context); - LinkedList values = new(); - context.StoreRef(values); - if (length == 0) - { - return values; - } - - CollectionCodec.CollectionFrame frame = CollectionCodec.ReadFrame(context, length); - if (!frame.SameType) - { - if (frame.TrackRef) - { - for (int i = 0; i < length; i++) - { - values.AddLast(elementSerializer.Read(context, RefMode.Tracking, true)); - } - - return values; - } - - if (frame.HasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values.AddLast((T)elementSerializer.DefaultObject!); - } - else if (refFlag == (sbyte)RefFlag.NotNullValue) - { - values.AddLast(elementSerializer.Read(context, RefMode.None, true)); - } - else - { - throw new RefException($"invalid nullability flag {refFlag}"); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values.AddLast(elementSerializer.Read(context, RefMode.None, true)); - } - } - - return values; - } - - if (!frame.Declared) - { - context.TypeResolver.ReadTypeInfo(elementSerializer, context); - } - - if (frame.TrackRef) - { - for (int i = 0; i < length; i++) - { - values.AddLast(elementSerializer.Read(context, RefMode.Tracking, false)); - } - - if (!frame.Declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - - return values; - } - - if (frame.HasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values.AddLast((T)elementSerializer.DefaultObject!); - } - else - { - values.AddLast(elementSerializer.ReadData(context)); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values.AddLast(elementSerializer.ReadData(context)); - } - } - - if (!frame.Declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - - return values; + return CollectionCodec.ReadLinkedListData(context.TypeResolver.GetSerializer(), context); } } @@ -1191,131 +1014,7 @@ public override void WriteData(WriteContext context, in Queue value, bool has public override Queue ReadData(ReadContext context) { - if (context.ShouldStoreRef) - { - return ReadStoredQueueData(context); - } - - uint refId = context.PauseRefPublication(); - List values = CollectionCodec.ReadCollectionData( - context.TypeResolver.GetSerializer(), - context, - reserveOwner: false, - storeOwnerRef: false); - context.ResumeRefPublication(refId); - CollectionCodec.ReserveElementStorage(context, values.Count); - Queue queue = new(values.Count); - for (int i = 0; i < values.Count; i++) - { - queue.Enqueue(values[i]); - } - - return queue; - } - - private static Queue ReadStoredQueueData(ReadContext context) - { - Serializer elementSerializer = context.TypeResolver.GetSerializer(); - int length = CollectionCodec.ReadLengthAndReserve(context); - Queue values = new(length); - context.StoreRef(values); - if (length == 0) - { - return values; - } - - CollectionCodec.CollectionFrame frame = CollectionCodec.ReadFrame(context, length); - if (!frame.SameType) - { - if (frame.TrackRef) - { - for (int i = 0; i < length; i++) - { - values.Enqueue(elementSerializer.Read(context, RefMode.Tracking, true)); - } - - return values; - } - - if (frame.HasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values.Enqueue((T)elementSerializer.DefaultObject!); - } - else if (refFlag == (sbyte)RefFlag.NotNullValue) - { - values.Enqueue(elementSerializer.Read(context, RefMode.None, true)); - } - else - { - throw new RefException($"invalid nullability flag {refFlag}"); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values.Enqueue(elementSerializer.Read(context, RefMode.None, true)); - } - } - - return values; - } - - if (!frame.Declared) - { - context.TypeResolver.ReadTypeInfo(elementSerializer, context); - } - - if (frame.TrackRef) - { - for (int i = 0; i < length; i++) - { - values.Enqueue(elementSerializer.Read(context, RefMode.Tracking, false)); - } - - if (!frame.Declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - - return values; - } - - if (frame.HasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values.Enqueue((T)elementSerializer.DefaultObject!); - } - else - { - values.Enqueue(elementSerializer.ReadData(context)); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values.Enqueue(elementSerializer.ReadData(context)); - } - } - - if (!frame.Declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - - return values; + return CollectionCodec.ReadQueueData(context.TypeResolver.GetSerializer(), context); } } @@ -1344,130 +1043,6 @@ public override void WriteData(WriteContext context, in Stack value, bool has public override Stack ReadData(ReadContext context) { - if (context.ShouldStoreRef) - { - return ReadStoredStackData(context); - } - - uint refId = context.PauseRefPublication(); - List values = CollectionCodec.ReadCollectionData( - context.TypeResolver.GetSerializer(), - context, - reserveOwner: false, - storeOwnerRef: false); - context.ResumeRefPublication(refId); - CollectionCodec.ReserveElementStorage(context, values.Count); - Stack stack = new(values.Count); - for (int i = 0; i < values.Count; i++) - { - stack.Push(values[i]); - } - - return stack; - } - - private static Stack ReadStoredStackData(ReadContext context) - { - Serializer elementSerializer = context.TypeResolver.GetSerializer(); - int length = CollectionCodec.ReadLengthAndReserve(context); - Stack values = new(length); - context.StoreRef(values); - if (length == 0) - { - return values; - } - - CollectionCodec.CollectionFrame frame = CollectionCodec.ReadFrame(context, length); - if (!frame.SameType) - { - if (frame.TrackRef) - { - for (int i = 0; i < length; i++) - { - values.Push(elementSerializer.Read(context, RefMode.Tracking, true)); - } - - return values; - } - - if (frame.HasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values.Push((T)elementSerializer.DefaultObject!); - } - else if (refFlag == (sbyte)RefFlag.NotNullValue) - { - values.Push(elementSerializer.Read(context, RefMode.None, true)); - } - else - { - throw new RefException($"invalid nullability flag {refFlag}"); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values.Push(elementSerializer.Read(context, RefMode.None, true)); - } - } - - return values; - } - - if (!frame.Declared) - { - context.TypeResolver.ReadTypeInfo(elementSerializer, context); - } - - if (frame.TrackRef) - { - for (int i = 0; i < length; i++) - { - values.Push(elementSerializer.Read(context, RefMode.Tracking, false)); - } - - if (!frame.Declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - - return values; - } - - if (frame.HasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values.Push((T)elementSerializer.DefaultObject!); - } - else - { - values.Push(elementSerializer.ReadData(context)); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values.Push(elementSerializer.ReadData(context)); - } - } - - if (!frame.Declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - - return values; + return CollectionCodec.ReadStackData(context.TypeResolver.GetSerializer(), context); } } diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index 56a90df503..9023fc5dff 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -32,6 +32,8 @@ public sealed class ReadContext private readonly List _readMetaStrings = []; internal readonly UInt64Map _readTypeInfoByType = new(); + // Consumed slots stay on the stack until their matching reader scope clears them. That lets + // nested child reads restore an outer owner that has not been materialized yet. internal readonly List _reservedRefIds = []; private readonly int _maxDynamicReadDepth; internal Type? _typeMetaType; @@ -40,7 +42,6 @@ public sealed class ReadContext internal Type? _cachedTypeMetaType; internal TypeMeta? _cachedTypeMeta; internal int _currentDynamicReadDepth; - private bool _hasReservedRefId; private readonly Dictionary _remoteSchemaVersionsByType = []; private readonly Config _config; private int _totalAcceptedSchemaVersions; @@ -77,7 +78,11 @@ public ReadContext( public bool ShouldStoreRef { [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => _hasReservedRefId; + get + { + int index = _reservedRefIds.Count - 1; + return index >= 0 && _reservedRefIds[index] != NoReservedRefId; + } } internal RefReader RefReader { get; } @@ -471,69 +476,25 @@ internal void ClearReadTypeInfo(Type type) [MethodImpl(MethodImplOptions.AggressiveInlining)] public void StoreRef(object? value) { - if (!_hasReservedRefId) + int index = _reservedRefIds.Count - 1; + if (index < 0) { return; } - int index = _reservedRefIds.Count - 1; - if (index < 0) + uint refId = _reservedRefIds[index]; + if (refId == NoReservedRefId) { - _hasReservedRefId = false; return; } - RefReader.StoreRefAt(_reservedRefIds[index], value); - _hasReservedRefId = false; + RefReader.StoreRefAt(refId, value); + _reservedRefIds[index] = NoReservedRefId; } internal void SetReservedRefId(uint refId) { _reservedRefIds.Add(refId); - _hasReservedRefId = true; - } - - /// - /// Hides the current publishable ref id while a serializer reads a child or temporary owner. - /// - /// - /// The reserved slot stays on the stack so the outer owner can publish it after materialization. - /// This prevents immutable wrappers and conversion serializers from letting children consume the - /// parent ref id before the parent object exists. - /// - public uint PauseRefPublication() - { - if (!_hasReservedRefId) - { - return NoReservedRefId; - } - - int index = _reservedRefIds.Count - 1; - if (index < 0) - { - _hasReservedRefId = false; - return NoReservedRefId; - } - - _hasReservedRefId = false; - return _reservedRefIds[index]; - } - - /// Restores a ref id hidden by . - public void ResumeRefPublication(uint refId) - { - if (refId == NoReservedRefId) - { - return; - } - - int index = _reservedRefIds.Count - 1; - if (index < 0) - { - throw new RefException($"cannot resume ref publication for ref id {refId}"); - } - - _hasReservedRefId = true; } internal void ClearReservedRefId() @@ -543,8 +504,6 @@ internal void ClearReservedRefId() { _reservedRefIds.RemoveAt(count - 1); } - - _hasReservedRefId = false; } internal void IncreaseReadDepth() @@ -573,7 +532,6 @@ internal void Reset() _typeMetaByType?.ClearKeys(); _readTypeInfoByType.ClearKeys(); _reservedRefIds.Clear(); - _hasReservedRefId = false; _cachedTypeMetaType = null; _cachedTypeMeta = null; _currentDynamicReadDepth = 0; diff --git a/csharp/src/Fory/UnionSerializer.cs b/csharp/src/Fory/UnionSerializer.cs index 47197c9271..71a4834d84 100644 --- a/csharp/src/Fory/UnionSerializer.cs +++ b/csharp/src/Fory/UnionSerializer.cs @@ -50,7 +50,6 @@ public override void WriteData(WriteContext context, in TUnion value, bool hasGe public override TUnion ReadData(ReadContext context) { - uint refId = context.PauseRefPublication(); uint rawCaseId = context.Reader.ReadVarUInt32(); if (rawCaseId > int.MaxValue) { @@ -69,7 +68,6 @@ public override TUnion ReadData(ReadContext context) } TUnion value = Factory(caseId, caseValue); - context.ResumeRefPublication(refId); return value; } diff --git a/dart/packages/fory/lib/src/serializer/collection_serializers.dart b/dart/packages/fory/lib/src/serializer/collection_serializers.dart index 71b2d7dd27..8245a3492e 100644 --- a/dart/packages/fory/lib/src/serializer/collection_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/collection_serializers.dart @@ -335,13 +335,8 @@ final class ListSerializer extends Serializer { ReadContext context, FieldType? elementFieldType, { bool hasPreservedRef = false, - bool reserveOwner = true, }) { - final state = _prepareListRead( - context, - elementFieldType, - reserveOwner: reserveOwner, - ); + final state = _prepareListRead(context, elementFieldType); context.buffer.checkReadableBytes(state.size); final result = List.filled(state.size, null, growable: false); if (hasPreservedRef) { @@ -386,17 +381,24 @@ final class SetSerializer extends Serializer { FieldType? elementFieldType, { bool hasPreservedRef = false, }) { - final values = ListSerializer.readPayload( - context, - elementFieldType, - hasPreservedRef: false, - reserveOwner: false, - ); - context.reserveGraphMemory(_ownerBytes + values.length * _referenceBytes); - final result = Set.of(values); + final state = _prepareListRead(context, elementFieldType); + context.buffer.checkReadableBytes(state.size); + final result = {}; if (hasPreservedRef) { context.reference(result); } + if (state.size == 0) { + return result; + } + if (state.tracksDepth) { + context.increaseDepth(); + } + for (var index = 0; index < state.size; index += 1) { + result.add(_readPreparedListItem(context, state)); + } + if (state.tracksDepth) { + context.decreaseDepth(); + } return result; } } @@ -654,14 +656,9 @@ Object _arrayToListValue(ReadContext context, Object? raw) { List readTypedListPayload( ReadContext context, FieldType? elementFieldType, - T Function(Object? value) convert, { - bool reserveOwner = true, -}) { - final state = _prepareListRead( - context, - elementFieldType, - reserveOwner: reserveOwner, - ); + T Function(Object? value) convert, +) { + final state = _prepareListRead(context, elementFieldType); if (state.size == 0) { return List.empty(growable: false); } @@ -739,14 +736,7 @@ Set readTypedSetPayload( FieldType? elementFieldType, T Function(Object? value) convert, ) { - final values = readTypedListPayload( - context, - elementFieldType, - convert, - reserveOwner: false, - ); - context.reserveGraphMemory(_ownerBytes + values.length * _referenceBytes); - return Set.of(values); + return Set.of(readTypedListPayload(context, elementFieldType, convert)); } void writeTypedListPayload( @@ -934,13 +924,10 @@ final class _PreparedListRead { @pragma('vm:prefer-inline') _PreparedListRead _prepareListRead( ReadContext context, - FieldType? elementFieldType, { - bool reserveOwner = true, -}) { + FieldType? elementFieldType, +) { final size = context.buffer.readVarUint32(); - if (reserveOwner) { - context.reserveGraphMemory(_ownerBytes + size * _referenceBytes); - } + context.reserveGraphMemory(_ownerBytes + size * _referenceBytes); if (size == 0) { return _PreparedListRead( size: 0, diff --git a/dart/packages/fory/test/collection_serializer_test.dart b/dart/packages/fory/test/collection_serializer_test.dart index a42e5d701c..9704462a99 100644 --- a/dart/packages/fory/test/collection_serializer_test.dart +++ b/dart/packages/fory/test/collection_serializer_test.dart @@ -257,6 +257,19 @@ void main() { expect(identical(roundTrip, roundTrip[0]), isTrue); }); + test('trackRef true preserves self-referential root sets', () { + final fory = Fory(); + final value = {}; + value.add(value); + + final roundTrip = + fory.deserialize(fory.serialize(value, trackRef: true)) + as Set; + + expect(roundTrip, hasLength(1)); + expect(identical(roundTrip, roundTrip.single), isTrue); + }); + test('trackRef true preserves self-referential root maps', () { final fory = Fory(); final value = {}; From 3a10256a3f2ce4610c7fc56ee95c77aba542fe80 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 02:08:39 +0800 Subject: [PATCH 33/54] refactor: simplify graph memory read state --- .agents/languages/csharp.md | 6 +- .agents/repo-reference.md | 9 + AGENTS.md | 5 +- cpp/fory/serialization/context.h | 22 +-- cpp/fory/serialization/fory.h | 10 +- .../serialization/graph_memory_budget_test.cc | 22 +-- .../src/Fory.Generator/ForyModelGenerator.cs | 28 +-- csharp/src/Fory/CollectionSerializers.cs | 44 +++-- csharp/src/Fory/DictionarySerializers.cs | 10 +- csharp/src/Fory/Fory.cs | 12 +- csharp/src/Fory/NullableKeyDictionary.cs | 11 +- csharp/src/Fory/ReadContext.cs | 29 +-- .../Fory.Tests/GraphMemoryBudgetTests.cs | 6 +- .../fory/lib/src/context/read_context.dart | 6 - .../fory/test/graph_memory_budget_test.dart | 6 - .../xlang_implementation_guide.md | 4 + go/fory/fory.go | 170 ++++++------------ go/fory/graph_memory_budget_test.go | 15 +- go/fory/reader.go | 56 ------ go/fory/stream.go | 33 +++- .../org/apache/fory/context/ReadContext.java | 4 - python/pyfory/collection.pxi | 36 +--- python/pyfory/context.pxi | 13 -- python/pyfory/struct.pxi | 2 +- rust/fory-core/src/context.rs | 25 +-- rust/fory-core/src/fory.rs | 28 ++- swift/Sources/Fory/Fory.swift | 6 +- swift/Sources/Fory/ReadContext.swift | 9 +- .../ForyTests/GraphMemoryBudgetTests.swift | 4 +- 29 files changed, 214 insertions(+), 417 deletions(-) diff --git a/.agents/languages/csharp.md b/.agents/languages/csharp.md index a3685e9560..4d82f2518b 100644 --- a/.agents/languages/csharp.md +++ b/.agents/languages/csharp.md @@ -16,9 +16,9 @@ Load this file when changing `csharp/` or C# xlang behavior. memory-backed today, but the graph budget uses the same fixed default for every root shape. `ReadContext` may expose only raw byte reservation; concrete serializers and generated serializers must compute list, array, map, struct, and object byte formulas before calling it. -- `ReadContext` must not expose ref-publication pause/resume APIs. Keep nested reference - publication stack-aware inside read-state internals so immutable, generated, and conversion - serializers do not publish temporary owners. +- `ReadContext` must not expose ref-publication pause/resume APIs or any non-budget owner + controls. Concrete serializers and generated serializers own ref publication timing directly, + and must not publish temporary owners. - For C# graph budget formulas, distinguish inline value storage from reference storage: use cheap value-type size for `List`/`T[]` value paths and the 4-byte reference fallback for reference paths. Class/reference serializers reserve their own shallow self cost plus field storage when diff --git a/.agents/repo-reference.md b/.agents/repo-reference.md index 4d420b1f05..3934449a9d 100644 --- a/.agents/repo-reference.md +++ b/.agents/repo-reference.md @@ -80,6 +80,15 @@ Apache Fory is a multi-language serialization framework with multiple wire forma copy/decompression and before field-list allocation, and never add cache-hit or generated-reader hot-path work for them. +## Root Graph Memory Budget Ownership + +Root graph memory budgeting is a read-state accounting feature only. Read context or equivalent +read state may expose raw byte reservation and, when a runtime cannot reasonably avoid it, +root-operation budget setup/reset. It must not grow semantic APIs for collection, map, array, +struct, object, temporary-owner, serializer-owner, conversion, counted-allocation, or +ref-publication control. Concrete serializers and generated serializers own allocation formulas, +overflow checks, and reference publication timing. + ## Runtime Map ### Java diff --git a/AGENTS.md b/AGENTS.md index 18a29a4041..a275acdbe7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -33,10 +33,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Check the spec before implementation. For wire behavior and xlang mapping, use the specs as the source of truth and never copy one runtime's bug into another runtime just to make tests pass. - Do not make assumptions about runtime behavior, ownership, registration, metadata construction, protocol semantics, or test coverage. Read the current code, owning docs/specs, and relevant tests before making a design judgment or implementation decision. If the evidence is incomplete, inspect more or state the uncertainty explicitly instead of filling gaps from memory or analogy with another runtime. - For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve backing/output capacity from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. Root graph memory reservation is accounting only and may happen before that byte check, but it must not replace the byte check. -- Root deserialization graph memory budgets estimate lower-bound shallow memory for materialized graph owners, not exact heap accounting, input byte accounting, or raw element counts. `maxGraphMemoryBytes` defaults to fixed `128 MiB`; positive values override the default; explicit non-positive values intentionally disable graph-memory enforcement and must be documented as deserialization DoS risk for compact inputs that materialize large graphs. Do not derive this budget from root input size, and do not split known-length and stream root behavior. Read context/read state owns only raw byte reservation with `reserveGraphMemory(bytes)`; it must not expose counted arithmetic helpers or collection, map, array, struct, or object semantic reservation APIs. Concrete serializers and generated serializers own counted formulas and overflow checks for their allocation path. Reserve self storage exactly once at the owner that stores or allocates the value: reference/object runtimes reserve parent owner self cost plus reference storage and every referenced heap owner reserves its own shallow self cost when materialized; inline/value runtimes reserve value storage in the root, struct/product, collection, map, set, array, smart-pointer, or box owner that actually owns that storage, and nested value serializers do not unconditionally reserve their own `sizeof(T)` again. Collection, map, set, and reference-array owners reserve nonzero shallow self cost when independently materialized, plus backing/reference/inline storage. Struct, record, POJO, tuple/product, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Reference fields use 4 bytes when reference size is not cheap or reliable to query; primitive/value fields use their storage width. Parents do not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive array, and primitive dense-array leaf owners. Do not guess allocator, bucket-table, node, debug, or map-entry overhead unless it is a documented lower-bound owner allocation. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. -- Read context/read state must not expose ref-publication pause/resume APIs. Nested reference - publication should be stack-aware inside the read state so immutable, generated, and conversion - owners can materialize final values without publishing temporary owners. +- Root deserialization graph memory budgets estimate lower-bound shallow memory for materialized graph owners, not exact heap accounting, input byte accounting, or raw element counts. `maxGraphMemoryBytes` defaults to fixed `128 MiB`; positive values override the default; explicit non-positive values intentionally disable graph-memory enforcement and must be documented as deserialization DoS risk for compact inputs that materialize large graphs. Do not derive this budget from root input size, and do not split known-length and stream root behavior. Read context/read state owns only raw byte reservation with `reserveGraphMemory(bytes)`; it must not expose counted arithmetic helpers or collection, map, array, struct, or object semantic reservation APIs. Do not add any non-memory-budget read-context/read-state API for this feature, including ref-publication controls, temporary-owner controls, serializer-owner controls, conversion helpers, or APIs that encode what kind of value is being materialized. Concrete serializers and generated serializers own counted formulas, overflow checks, allocation-owner decisions, and reference publication timing for their allocation path. Reserve self storage exactly once at the owner that stores or allocates the value: reference/object runtimes reserve parent owner self cost plus reference storage and every referenced heap owner reserves its own shallow self cost when materialized; inline/value runtimes reserve value storage in the root, struct/product, collection, map, set, array, smart-pointer, or box owner that actually owns that storage, and nested value serializers do not unconditionally reserve their own `sizeof(T)` again. Collection, map, set, and reference-array owners reserve nonzero shallow self cost when independently materialized, plus backing/reference/inline storage. Struct, record, POJO, tuple/product, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Reference fields use 4 bytes when reference size is not cheap or reliable to query; primitive/value fields use their storage width. Parents do not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive array, and primitive dense-array leaf owners. Do not guess allocator, bucket-table, node, debug, or map-entry overhead unless it is a documented lower-bound owner allocation. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. - For remote TypeDef/TypeMeta reads, the checked metadata cache is the only owner of remote "already validated" state. Cache hit means the header was previously parsed, body/hash-validated, policy-checked, and published by that cache, so the hot path must skip the body and use cached metadata without extra validation, hashing, limit checks, exact-local checks, allocation, or policy work. A known expected local TypeDef/TypeMeta header/hash match is a local-schema hit, not a remote cache miss: it may skip the body and use the local TypeInfo/TypeMeta without schema-version counting or cache publish. Cache miss is the only path that parses and validates non-local metadata, enforces limits, performs exact-local byte comparison when needed, and publishes remote metadata to the cache. Do not add nullable accepted-header fields, sentinel headers, per-TypeInfo markers, pending metadata state, parallel header-low/header-high slots, or parallel acceptance state for this decision. If a runtime needs a metadata hit hint, cache the concrete checked metadata owner object, such as the TypeInfo, TypeDef, or TypeMeta used by that runtime, and compare its validated header identity directly. - When a user corrects a non-obvious invariant, encode it in the nearest source comment before continuing, and also update `AGENTS.md`, `.agents/**`, docs, or specs when the rule is reusable beyond one file. Do not rely only on chat history, task notes, commit messages, or benchmark logs for corrections that protect security, protocol behavior, ownership, naming, or hot-path performance. - Reject semantic hacks. Do not bypass broken semantics by deleting cases, simplifying callers, adding coercion hooks, or using workaround fallbacks; fix the underlying bug and prove it with focused tests. diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index 36469d6f2a..117d2ed27a 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -505,24 +505,6 @@ class ReadContext { } } - template - FORY_ALWAYS_INLINE bool init_graph_budget() { - const size_t limit = graph_memory_limit_bytes_; - if (FORY_PREDICT_TRUE(limit != 0)) { - if constexpr (ReserveBytes != 0) { - if (FORY_PREDICT_FALSE(ReserveBytes > limit)) { - return set_graph_memory_exceeded(ReserveBytes, limit); - } - remaining_graph_memory_bytes_ = limit - ReserveBytes; - } else { - remaining_graph_memory_bytes_ = limit; - } - return true; - } - remaining_graph_memory_bytes_ = std::numeric_limits::max(); - return true; - } - FORY_ALWAYS_INLINE bool reserve_graph_memory(size_t bytes) { const size_t remaining = remaining_graph_memory_bytes_; if (FORY_PREDICT_FALSE(remaining == std::numeric_limits::max())) { @@ -690,6 +672,8 @@ class ReadContext { inline const Config &config() const { return *config_; } private: + friend class Fory; + FORY_NOINLINE Result check_remote_type_meta_limit(const TypeMeta &type_meta); void record_remote_type_meta(const std::string &type_key); @@ -724,8 +708,6 @@ class ReadContext { meta::MetaStringTable meta_string_table_; fory::flat_hash_map remote_schema_versions_by_type_; size_t total_accepted_schema_versions_ = 0; - - friend class Fory; }; /// Implementation of DynDepthGuard destructor diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index 75f1334685..aa6a976059 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -887,15 +887,15 @@ class Fory : public BaseFory { } read_ctx_->attach(buffer); + read_ctx_->remaining_graph_memory_bytes_ = + read_ctx_->graph_memory_limit_bytes_ != 0 + ? read_ctx_->graph_memory_limit_bytes_ + : std::numeric_limits::max(); if constexpr (needs_graph_budget_v) { constexpr size_t root_owner_bytes = graph_value_owner_self_bytes(); if constexpr (root_owner_bytes != 0) { if (FORY_PREDICT_FALSE( - !read_ctx_->template init_graph_budget())) { - return Unexpected(read_ctx_->take_error()); - } - } else if constexpr (has_graph_budget_children_v) { - if (FORY_PREDICT_FALSE(!read_ctx_->template init_graph_budget<>())) { + !read_ctx_->reserve_graph_memory(root_owner_bytes))) { return Unexpected(read_ctx_->take_error()); } } diff --git a/cpp/fory/serialization/graph_memory_budget_test.cc b/cpp/fory/serialization/graph_memory_budget_test.cc index f6c69d6b0b..89937b0f74 100644 --- a/cpp/fory/serialization/graph_memory_budget_test.cc +++ b/cpp/fory/serialization/graph_memory_budget_test.cc @@ -128,20 +128,20 @@ void expect_budget_boundary(const T &value, size_t required) { TEST(GraphMemoryBudgetTest, FixedDefaultBudgetAndDisable) { Config config; - ReadContext context(config, std::make_unique()); - - ASSERT_TRUE(context.init_graph_budget()); - ASSERT_TRUE(context.reserve_graph_memory( - static_cast(kDefaultGraphMemoryBytes))); - ASSERT_FALSE(context.reserve_graph_memory(1)); - EXPECT_EQ(context.take_error().code(), ErrorCode::InvalidData); + EXPECT_EQ(config.max_graph_memory_bytes, kDefaultGraphMemoryBytes); Config disabled_config; disabled_config.max_graph_memory_bytes = 0; - ReadContext disabled(disabled_config, std::make_unique()); - ASSERT_TRUE(disabled.init_graph_budget()); - ASSERT_TRUE( - disabled.reserve_graph_memory(std::numeric_limits::max())); + EXPECT_EQ(disabled_config.max_graph_memory_bytes, 0); + + constexpr size_t count = 3; + std::vector> value(count); + auto bytes = serialize_value(value); + auto disabled_result = with_fory(0, [&](Fory &fory) { + return fory.deserialize>>(bytes); + }); + ASSERT_TRUE(disabled_result.ok()) << disabled_result.error().to_string(); + EXPECT_EQ(disabled_result.value(), value); } TEST(GraphMemoryBudgetTest, RootKindsShareConfiguredBudget) { diff --git a/csharp/src/Fory.Generator/ForyModelGenerator.cs b/csharp/src/Fory.Generator/ForyModelGenerator.cs index 8d7a6c2286..91c7f7fa81 100644 --- a/csharp/src/Fory.Generator/ForyModelGenerator.cs +++ b/csharp/src/Fory.Generator/ForyModelGenerator.cs @@ -460,8 +460,8 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" }"); sb.AppendLine(); - EmitReadDataWithoutTypeMeta(sb, model, "ReadDataWithoutTypeMeta", "context.ShouldStoreRef"); - EmitReadDataMethod(sb, model, "ReadData", "ReadDataWithoutTypeMeta", "context.ShouldStoreRef", "public"); + EmitReadDataWithoutTypeMeta(sb, model, "ReadDataWithoutTypeMeta"); + EmitReadDataMethod(sb, model, "ReadData", "ReadDataWithoutTypeMeta", "public"); sb.AppendLine("}"); } @@ -469,8 +469,7 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) private static void EmitReadDataWithoutTypeMeta( StringBuilder sb, TypeModel model, - string methodName, - string? storeRefCondition) + string methodName) { sb.AppendLine($" private {model.TypeName} {methodName}(global::Apache.Fory.ReadContext context)"); sb.AppendLine(" {"); @@ -480,7 +479,7 @@ private static void EmitReadDataWithoutTypeMeta( } sb.AppendLine($" {model.TypeName} valueNoTypeMeta = new {model.TypeName}();"); - EmitStoreRef(sb, model, storeRefCondition, "valueNoTypeMeta", 2); + EmitStoreRef(sb, model, "valueNoTypeMeta", 2); foreach (MemberModel member in model.SortedMembers) { @@ -505,7 +504,6 @@ private static void EmitReadDataMethod( TypeModel model, string methodName, string noTypeMetaMethodName, - string? storeRefCondition, string accessibility) { sb.AppendLine($" {accessibility} override {model.TypeName} {methodName}(global::Apache.Fory.ReadContext context)"); @@ -526,7 +524,7 @@ private static void EmitReadDataMethod( } sb.AppendLine($" {model.TypeName} value = new {model.TypeName}();"); - EmitStoreRef(sb, model, storeRefCondition, "value", 3); + EmitStoreRef(sb, model, "value", 3); sb.AppendLine(" bool __ForyExactTypeMeta = __ForyMatchesCachedTypeMeta(typeMeta, context.TrackRef, context.TypeResolver);"); sb.AppendLine(" if (__ForyAllFieldsBuiltIn && __ForyExactTypeMeta)"); @@ -638,7 +636,7 @@ private static void EmitReadDataMethod( } sb.AppendLine($" {model.TypeName} valueSchema = new {model.TypeName}();"); - EmitStoreRef(sb, model, storeRefCondition, "valueSchema", 2); + EmitStoreRef(sb, model, "valueSchema", 2); foreach (MemberModel member in model.SortedMembers) { @@ -653,26 +651,16 @@ private static void EmitReadDataMethod( private static void EmitStoreRef( StringBuilder sb, TypeModel model, - string? condition, string valueName, int indentLevel) { - if (model.Kind != DeclKind.Class || condition is null) + if (model.Kind != DeclKind.Class) { return; } string indent = new(' ', indentLevel * 4); - if (condition == "true") - { - sb.AppendLine($"{indent}context.StoreRef({valueName});"); - return; - } - - sb.AppendLine($"{indent}if ({condition})"); - sb.AppendLine($"{indent}{{"); - sb.AppendLine($"{indent} context.StoreRef({valueName});"); - sb.AppendLine($"{indent}}}"); + sb.AppendLine($"{indent}context.StoreRef({valueName});"); } private static void EmitUnionSerializer(StringBuilder sb, TypeModel model) diff --git a/csharp/src/Fory/CollectionSerializers.cs b/csharp/src/Fory/CollectionSerializers.cs index 94d97fbe75..9054d91880 100644 --- a/csharp/src/Fory/CollectionSerializers.cs +++ b/csharp/src/Fory/CollectionSerializers.cs @@ -220,10 +220,7 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea { ReserveElementStorage(context, length); List empty = []; - if (context.ShouldStoreRef) - { - context.StoreRef(empty); - } + context.StoreRef(empty); return empty; } @@ -240,10 +237,7 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea ReserveElementStorage(context, length); context.Reader.CheckBound(length); List values = new(length); - if (context.ShouldStoreRef) - { - context.StoreRef(values); - } + context.StoreRef(values); if (!sameType) { @@ -409,15 +403,15 @@ private static TCollection ReadCollectionOwner( { int length = checked((int)context.Reader.ReadVarUInt32()); ReserveElementStorage(context, length); - TCollection values = builder.Create(length); - if (builder.StoresOwnerRef && context.ShouldStoreRef) - { - context.StoreRef(values); - } - if (length == 0) { - return values; + TCollection empty = builder.Create(length); + if (builder.StoresOwnerRef) + { + context.StoreRef(empty); + } + + return empty; } byte header = context.Reader.ReadUInt8(); @@ -425,7 +419,15 @@ private static TCollection ReadCollectionOwner( bool hasNull = (header & CollectionBits.HasNull) != 0; bool declared = (header & CollectionBits.DeclaredElementType) != 0; bool sameType = (header & CollectionBits.SameType) != 0; + // Some builders allocate backing capacity from length, so prove proportional payload bytes + // before materializing non-empty owners. context.Reader.CheckBound(length); + TCollection values = builder.Create(length); + if (builder.StoresOwnerRef) + { + context.StoreRef(values); + } + if (!sameType) { if (trackRef) @@ -581,10 +583,7 @@ public static T[] ReadArrayData(Serializer elementSerializer, ReadContext { ReserveElementStorage(context, length); T[] empty = []; - if (context.ShouldStoreRef) - { - context.StoreRef(empty); - } + context.StoreRef(empty); return empty; } @@ -597,10 +596,7 @@ public static T[] ReadArrayData(Serializer elementSerializer, ReadContext ReserveElementStorage(context, length); context.Reader.CheckBound(length); T[] values = new T[length]; - if (context.ShouldStoreRef) - { - context.StoreRef(values); - } + context.StoreRef(values); if (!sameType) { @@ -779,7 +775,7 @@ public static object ReadMapPayload(ReadContext context) { Serializer> serializer = context.TypeResolver.GetSerializer>(); - bool storeRef = context.ShouldStoreRef; + bool storeRef = context._reservedRefIds.Count != 0; NullableKeyDictionary map = serializer.ReadData(context); if (storeRef) { diff --git a/csharp/src/Fory/DictionarySerializers.cs b/csharp/src/Fory/DictionarySerializers.cs index a1b13f6c5e..8722be02bd 100644 --- a/csharp/src/Fory/DictionarySerializers.cs +++ b/csharp/src/Fory/DictionarySerializers.cs @@ -230,10 +230,7 @@ public override TDictionary ReadData(ReadContext context) { ReserveMapStorage(context, totalLength); TDictionary empty = CreateMap(0); - if (context.ShouldStoreRef) - { - context.StoreRef(empty); - } + context.StoreRef(empty); return empty; } @@ -241,10 +238,7 @@ public override TDictionary ReadData(ReadContext context) ReserveMapStorage(context, totalLength); context.Reader.CheckBound(totalLength); TDictionary map = CreateMap(totalLength); - if (context.ShouldStoreRef) - { - context.StoreRef(map); - } + context.StoreRef(map); bool keyDynamicType = keyTypeInfo.IsDynamicType; bool valueDynamicType = valueTypeInfo.IsDynamicType; diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs index a1169201ee..2143682e2f 100644 --- a/csharp/src/Fory/Fory.cs +++ b/csharp/src/Fory/Fory.cs @@ -192,7 +192,9 @@ public T Deserialize(ReadOnlySpan payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); - _readContext.InitGraphBudget(); + long graphLimit = Config.MaxGraphMemoryBytes; + _readContext._graphMemoryLimitBytes = graphLimit > 0 ? graphLimit : 0; + _readContext._remainingGraphMemoryBytes = graphLimit > 0 ? graphLimit : long.MaxValue; T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -213,7 +215,9 @@ public T Deserialize(byte[] payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); - _readContext.InitGraphBudget(); + long graphLimit = Config.MaxGraphMemoryBytes; + _readContext._graphMemoryLimitBytes = graphLimit > 0 ? graphLimit : 0; + _readContext._remainingGraphMemoryBytes = graphLimit > 0 ? graphLimit : long.MaxValue; T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -234,7 +238,9 @@ public T Deserialize(ref ReadOnlySequence payload) byte[] bytes = payload.ToArray(); ByteReader reader = _readContext.Reader; reader.Reset(bytes); - _readContext.InitGraphBudget(); + long graphLimit = Config.MaxGraphMemoryBytes; + _readContext._graphMemoryLimitBytes = graphLimit > 0 ? graphLimit : 0; + _readContext._remainingGraphMemoryBytes = graphLimit > 0 ? graphLimit : long.MaxValue; T value = DeserializeFromReader(reader); payload = payload.Slice(reader.Cursor); return value; diff --git a/csharp/src/Fory/NullableKeyDictionary.cs b/csharp/src/Fory/NullableKeyDictionary.cs index 00adde9aa4..a12bbf61d1 100644 --- a/csharp/src/Fory/NullableKeyDictionary.cs +++ b/csharp/src/Fory/NullableKeyDictionary.cs @@ -548,16 +548,12 @@ public override NullableKeyDictionary ReadData(ReadContext context Serializer valueSerializer = context.TypeResolver.GetSerializer(); TypeInfo keyTypeInfo = context.TypeResolver.GetTypeInfo(); TypeInfo valueTypeInfo = context.TypeResolver.GetTypeInfo(); - bool storeRef = context.ShouldStoreRef; int totalLength = checked((int)context.Reader.ReadVarUInt32()); if (totalLength == 0) { ReserveMapStorage(context, totalLength); NullableKeyDictionary empty = new(); - if (storeRef) - { - context.StoreRef(empty); - } + context.StoreRef(empty); return empty; } @@ -565,10 +561,7 @@ public override NullableKeyDictionary ReadData(ReadContext context ReserveMapStorage(context, totalLength); context.Reader.CheckBound(totalLength); NullableKeyDictionary map = new(totalLength); - if (storeRef) - { - context.StoreRef(map); - } + context.StoreRef(map); bool keyDynamicType = keyTypeInfo.IsDynamicType; bool valueDynamicType = valueTypeInfo.IsDynamicType; diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index 9023fc5dff..42da224a30 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -45,8 +45,8 @@ public sealed class ReadContext private readonly Dictionary _remoteSchemaVersionsByType = []; private readonly Config _config; private int _totalAcceptedSchemaVersions; - private long _graphMemoryLimitBytes = long.MaxValue; - private long _remainingGraphMemoryBytes = long.MaxValue; + internal long _graphMemoryLimitBytes = long.MaxValue; + internal long _remainingGraphMemoryBytes = long.MaxValue; public ReadContext( ByteReader reader, @@ -75,33 +75,8 @@ public ReadContext( public bool CheckStructVersion { get; } - public bool ShouldStoreRef - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get - { - int index = _reservedRefIds.Count - 1; - return index >= 0 && _reservedRefIds[index] != NoReservedRefId; - } - } - internal RefReader RefReader { get; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void InitGraphBudget() - { - long limit = _config.MaxGraphMemoryBytes; - if (limit <= 0) - { - _graphMemoryLimitBytes = 0; - _remainingGraphMemoryBytes = long.MaxValue; - return; - } - - _graphMemoryLimitBytes = limit; - _remainingGraphMemoryBytes = limit; - } - /// /// Reserves estimated graph memory for the current root deserialization. /// diff --git a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs index 26e5f12c9b..2c73e744d5 100644 --- a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs +++ b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs @@ -149,12 +149,14 @@ public void DefaultFixedBudgetAndDisable() Assert.Equal(-2, NewFory(-2).Config.MaxGraphMemoryBytes); ReadContext context = new(new ByteReader([]), new TypeResolver(), NewFory().Config); - context.InitGraphBudget(); + context._graphMemoryLimitBytes = DefaultGraphMemoryBytes; + context._remainingGraphMemoryBytes = DefaultGraphMemoryBytes; context.ReserveGraphMemory(DefaultGraphMemoryBytes); Assert.Throws(() => context.ReserveGraphMemory(ReferenceBytes)); ReadContext disabled = new(new ByteReader([]), new TypeResolver(), NewFory(0).Config); - disabled.InitGraphBudget(); + disabled._graphMemoryLimitBytes = 0; + disabled._remainingGraphMemoryBytes = long.MaxValue; disabled.ReserveGraphMemory(long.MaxValue); Assert.Throws(() => disabled.ReserveGraphMemory(-1)); } diff --git a/dart/packages/fory/lib/src/context/read_context.dart b/dart/packages/fory/lib/src/context/read_context.dart index a6b78262af..3a0cbc1b58 100644 --- a/dart/packages/fory/lib/src/context/read_context.dart +++ b/dart/packages/fory/lib/src/context/read_context.dart @@ -99,12 +99,6 @@ final class ReadContext { @internal RefReader get refReader => _refReader; - @internal - int get effectiveGraphMemoryBytes => _effectiveGraphMemoryBytes; - - @internal - int get remainingGraphMemoryBytes => _remainingGraphMemoryBytes; - @internal @pragma('vm:prefer-inline') void reserveGraphMemory(int bytes) { diff --git a/dart/packages/fory/test/graph_memory_budget_test.dart b/dart/packages/fory/test/graph_memory_budget_test.dart index 499e043168..eadddb5611 100644 --- a/dart/packages/fory/test/graph_memory_budget_test.dart +++ b/dart/packages/fory/test/graph_memory_budget_test.dart @@ -138,10 +138,6 @@ void main() { final buffer = Buffer.wrap(Uint8List(17)); final context = _readContext(buffer); - expect( - context.effectiveGraphMemoryBytes, - equals(_defaultGraphMemoryBytes), - ); expect( () => context.reserveGraphMemory(_defaultGraphMemoryBytes), returnsNormally, @@ -153,12 +149,10 @@ void main() { final buffer = Buffer.wrap(Uint8List(4096)); final context = _readContext(buffer, maxGraphMemoryBytes: 31); - expect(context.effectiveGraphMemoryBytes, equals(31)); expect(() => context.reserveGraphMemory(31), returnsNormally); expect(() => context.reserveGraphMemory(1), _throwsGraphBudget); final disabled = _readContext(buffer, maxGraphMemoryBytes: 0); - expect(disabled.effectiveGraphMemoryBytes, equals(0)); expect( () => disabled.reserveGraphMemory(_defaultGraphMemoryBytes + 1), returnsNormally, diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index 4d6c9e867c..a738bd8ace 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -412,6 +412,10 @@ not expose counted arithmetic helpers or collection, map, array, struct, or object semantic reservation APIs. Concrete serializers and generated serializer owners compute the storage constants and formulas for the owner path they allocate, including counted-byte overflow checks. +Read state must not grow non-memory-budget APIs for this feature, including +ref-publication controls, temporary-owner controls, serializer-owner controls, +conversion helpers, or APIs that encode the kind of value being materialized. +Concrete serializers and generated serializers own those decisions. The budget estimates lower-bound shallow memory for materialized graph owners, not exact heap bytes. Reserve self storage exactly once at the owner that stores diff --git a/go/fory/fory.go b/go/fory/fory.go index 561a66ebc7..874c7e3619 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -199,13 +199,10 @@ type Fory struct { typeResolver *TypeResolver refResolver *RefResolver - rootGraphType reflect.Type - rootGraphBytes int64 - rootGraphHasChildren bool - rootGraphSkipType reflect.Type - rootGraphSkipTypeID unsafe.Pointer - rootReadTypeID unsafe.Pointer - rootReadSerializer Serializer + rootGraphType reflect.Type + rootGraphBytes int64 + rootReadTypeID unsafe.Pointer + rootReadSerializer Serializer } // New creates a new Fory instance with the given options @@ -241,7 +238,6 @@ func New(opts ...Option) *Fory { f.writeCtx.xlang = f.config.IsXlang f.readCtx = NewReadContext(f.config.TrackRef) - f.readCtx.maxGraphMemoryBytes = f.config.MaxGraphMemoryBytes f.readCtx.typeResolver = f.typeResolver f.readCtx.refResolver = f.refResolver f.readCtx.compatible = f.config.Compatible @@ -581,15 +577,19 @@ func (f *Fory) Deserialize(data []byte, v any) error { defer f.resetReadState() f.readCtx.SetData(data) typeID := (*ifaceWords)(unsafe.Pointer(&v)).typ - var target reflect.Value - if typeID != f.rootGraphSkipTypeID { - target = reflect.ValueOf(v).Elem() - targetType := target.Type() - if err := f.initRootGraphBudgetType(targetType); err != nil { - return err - } - if targetType == f.rootGraphSkipType { - f.rootGraphSkipTypeID = typeID + target := reflect.ValueOf(v).Elem() + targetType := target.Type() + limit := f.config.MaxGraphMemoryBytes + if limit <= 0 { + f.readCtx.graphMemoryLimitBytes = 0 + f.readCtx.remainingGraphMemoryBytes = MaxInt64 + } else { + f.readCtx.graphMemoryLimitBytes = limit + f.readCtx.remainingGraphMemoryBytes = limit + } + if bytes, ok := f.rootGraphBytesFor(targetType); ok && bytes > 0 { + if !f.readCtx.ReserveGraphMemory(bytes) { + return f.readCtx.TakeError() } } @@ -599,9 +599,6 @@ func (f *Fory) Deserialize(data []byte, v any) error { } // Deserialize the value - TypeMeta is read inline using streaming protocol - if !target.IsValid() { - target = reflect.ValueOf(v).Elem() - } f.readRootValue(target, typeID) if f.readCtx.HasError() { return f.readCtx.TakeError() @@ -688,9 +685,19 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { f.readCtx.buffer = buf target := reflect.ValueOf(v).Elem() targetType := target.Type() - if err := f.initRootGraphBudgetType(targetType); err != nil { - f.readCtx.buffer = origBuffer - return err + limit := f.config.MaxGraphMemoryBytes + if limit <= 0 { + f.readCtx.graphMemoryLimitBytes = 0 + f.readCtx.remainingGraphMemoryBytes = MaxInt64 + } else { + f.readCtx.graphMemoryLimitBytes = limit + f.readCtx.remainingGraphMemoryBytes = limit + } + if bytes, ok := f.rootGraphBytesFor(targetType); ok && bytes > 0 { + if !f.readCtx.ReserveGraphMemory(bytes) { + f.readCtx.buffer = origBuffer + return f.readCtx.TakeError() + } } readHeader(f.readCtx) @@ -805,8 +812,18 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers target := rv.Elem() targetType := target.Type() - if err := f.initRootGraphBudgetType(targetType); err != nil { - return err + limit := f.config.MaxGraphMemoryBytes + if limit <= 0 { + f.readCtx.graphMemoryLimitBytes = 0 + f.readCtx.remainingGraphMemoryBytes = MaxInt64 + } else { + f.readCtx.graphMemoryLimitBytes = limit + f.readCtx.remainingGraphMemoryBytes = limit + } + if bytes, ok := f.rootGraphBytesFor(targetType); ok && bytes > 0 { + if !f.readCtx.ReserveGraphMemory(bytes) { + return f.readCtx.TakeError() + } } // ReadData and validate header @@ -1065,6 +1082,14 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { // Reuse context, reset and set new data f.readCtx.Reset() f.readCtx.SetData(data) + limit := f.config.MaxGraphMemoryBytes + if limit <= 0 { + f.readCtx.graphMemoryLimitBytes = 0 + f.readCtx.remainingGraphMemoryBytes = MaxInt64 + } else { + f.readCtx.graphMemoryLimitBytes = limit + f.readCtx.remainingGraphMemoryBytes = limit + } var targetVal reflect.Value var targetType reflect.Type @@ -1074,8 +1099,10 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { default: targetVal = reflect.ValueOf(target).Elem() targetType = targetVal.Type() - if err := f.initRootGraphBudgetType(targetType); err != nil { - return err + if bytes, ok := f.rootGraphBytesFor(targetType); ok && bytes > 0 { + if !f.readCtx.ReserveGraphMemory(bytes) { + return f.readCtx.TakeError() + } } } @@ -1233,99 +1260,20 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { } } -func (f *Fory) initRootGraphBudget(target reflect.Value) error { - if !target.IsValid() { - return f.initRootGraphBudgetType(nil) - } - return f.initRootGraphBudgetType(target.Type()) -} - -func (f *Fory) initRootGraphBudgetType(targetType reflect.Type) error { - if targetType == nil { - f.readCtx.initGraphMemoryBudget() - if f.readCtx.HasError() { - return f.readCtx.TakeError() - } - return nil - } - if targetType == f.rootGraphSkipType { - return nil - } - if targetType == f.rootGraphType && f.rootGraphHasChildren { - return f.initRootGraphBudgetWithSelf(f.rootGraphBytes) - } - return f.initRootGraphBudgetSlow(targetType) -} - -//go:noinline -func (f *Fory) initRootGraphBudgetSlow(targetType reflect.Type) error { - bytes, hasChildren, isStruct := f.rootGraphInfo(targetType) - if !isStruct { - f.readCtx.initGraphMemoryBudget() - if f.readCtx.HasError() { - return f.readCtx.TakeError() - } - return nil - } - if hasChildren { - return f.initRootGraphBudgetWithSelf(bytes) - } - if f.config.MaxGraphMemoryBytes <= 0 || bytes <= f.config.MaxGraphMemoryBytes { - f.rootGraphSkipType = targetType - return nil - } - return f.checkRootGraphSelf(bytes) -} - -func (f *Fory) initRootGraphBudgetWithSelf(bytes int64) error { - limit := f.config.MaxGraphMemoryBytes - if limit <= 0 { - f.readCtx.graphMemoryLimitBytes = 0 - f.readCtx.remainingGraphMemoryBytes = MaxInt64 - return nil - } - if bytes > limit { - return DeserializationErrorf( - "estimated graph memory request %d bytes exceeds maxGraphMemoryBytes remaining budget %d bytes out of effective limit %d bytes", - bytes, limit, limit) - } - f.readCtx.graphMemoryLimitBytes = limit - f.readCtx.remainingGraphMemoryBytes = limit - bytes - return nil -} - -func (f *Fory) rootGraphInfo(targetType reflect.Type) (int64, bool, bool) { +func (f *Fory) rootGraphBytesFor(targetType reflect.Type) (int64, bool) { if targetType == nil || targetType.Kind() != reflect.Struct { - return 0, false, false + return 0, false } if targetType == dateReflectType || targetType == timeReflectType { - return 0, false, true + return 0, true } if targetType == f.rootGraphType { - return f.rootGraphBytes, f.rootGraphHasChildren, true + return f.rootGraphBytes, true } bytes := structGraphBytes(targetType) - hasChildren := typeHasGraphChildren(targetType) f.rootGraphType = targetType f.rootGraphBytes = bytes - f.rootGraphHasChildren = hasChildren - return bytes, hasChildren, true -} - -func (f *Fory) checkRootGraphSelf(bytes int64) error { - if bytes <= 0 { - return nil - } - limit := f.config.MaxGraphMemoryBytes - if limit <= 0 { - return nil - } - if bytes > limit { - return DeserializationErrorf( - "estimated graph memory request %d bytes exceeds maxGraphMemoryBytes remaining budget %d bytes out of effective limit %d bytes", - bytes, limit, limit) - } - return nil + return bytes, true } func (f *Fory) readRootValue(target reflect.Value, typeID unsafe.Pointer) { diff --git a/go/fory/graph_memory_budget_test.go b/go/fory/graph_memory_budget_test.go index b3362aa4f7..b7e3c18520 100644 --- a/go/fory/graph_memory_budget_test.go +++ b/go/fory/graph_memory_budget_test.go @@ -52,25 +52,23 @@ func TestGraphMemoryBudgetConfig(t *testing.T) { func TestGraphMemoryBudgetFixedDefaultAndDisable(t *testing.T) { ctx := NewReadContext(false) - ctx.initGraphMemoryBudget() - require.False(t, ctx.HasError()) + ctx.graphMemoryLimitBytes = 128 * 1024 * 1024 + ctx.remainingGraphMemoryBytes = 128 * 1024 * 1024 require.Equal(t, int64(128*1024*1024), ctx.graphMemoryLimitBytes) require.True(t, ctx.ReserveGraphMemory(ctx.graphMemoryLimitBytes)) require.False(t, ctx.ReserveGraphMemory(1)) require.Contains(t, ctx.CheckError().Error(), "maxGraphMemoryBytes") ctx = NewReadContext(false) - ctx.maxGraphMemoryBytes = 0 - ctx.initGraphMemoryBudget() - require.False(t, ctx.HasError()) + ctx.graphMemoryLimitBytes = 0 + ctx.remainingGraphMemoryBytes = MaxInt64 require.Equal(t, int64(0), ctx.graphMemoryLimitBytes) require.True(t, ctx.ReserveGraphMemory(MaxInt64)) require.False(t, ctx.HasError()) ctx = NewReadContext(false) - ctx.maxGraphMemoryBytes = 77 - ctx.initGraphMemoryBudget() - require.False(t, ctx.HasError()) + ctx.graphMemoryLimitBytes = 77 + ctx.remainingGraphMemoryBytes = 77 require.Equal(t, int64(77), ctx.graphMemoryLimitBytes) } @@ -163,7 +161,6 @@ func TestGraphMemoryBudgetMapAndOverflow(t *testing.T) { require.Contains(t, err.Error(), "maxGraphMemoryBytes") ctx := NewReadContext(false) - ctx.initGraphMemoryBudget() require.False(t, ctx.ReserveGraphMemory(-1)) require.Contains(t, ctx.CheckError().Error(), "non-negative") } diff --git a/go/fory/reader.go b/go/fory/reader.go index c13a39b291..01eecca3bc 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -44,7 +44,6 @@ type ReadContext struct { err Error // Accumulated error state for deferred checking lastTypePtr uintptr lastTypeInfo *TypeInfo - maxGraphMemoryBytes int64 graphMemoryLimitBytes int64 remainingGraphMemoryBytes int64 } @@ -84,44 +83,6 @@ func reserveStructGraph(ctx *ReadContext, type_ reflect.Type) bool { return ctx.ReserveGraphMemory(bytes) } -func typeHasGraphChildren(type_ reflect.Type) bool { - for type_.Kind() == reflect.Ptr { - elem := type_.Elem() - if structGraphBytes(elem) != 0 { - return true - } - type_ = elem - } - switch type_.Kind() { - case reflect.Struct: - if type_ == dateReflectType || type_ == timeReflectType { - return false - } - for i := 0; i < type_.NumField(); i++ { - if typeHasGraphChildren(type_.Field(i).Type) { - return true - } - } - return false - case reflect.Slice: - elem := type_.Elem() - switch elem.Kind() { - case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, - reflect.Float32, reflect.Float64: - return false - default: - return true - } - case reflect.Array: - return typeHasGraphChildren(type_.Elem()) - case reflect.Map, reflect.Interface: - return true - default: - return false - } -} - // IsXlang returns whether cross-language serialization mode is enabled func (c *ReadContext) IsXlang() bool { return c.xlang @@ -134,7 +95,6 @@ func NewReadContext(trackRef bool) *ReadContext { refReader: NewRefReader(trackRef), trackRef: trackRef, maxDepth: 128, // Default maximum nesting depth - maxGraphMemoryBytes: 128 * 1024 * 1024, graphMemoryLimitBytes: 128 * 1024 * 1024, remainingGraphMemoryBytes: 128 * 1024 * 1024, } @@ -156,17 +116,6 @@ func (c *ReadContext) Reset() { } } -func (c *ReadContext) initGraphMemoryBudget() { - limit := c.maxGraphMemoryBytes - if limit <= 0 { - c.graphMemoryLimitBytes = 0 - c.remainingGraphMemoryBytes = MaxInt64 - return - } - c.graphMemoryLimitBytes = limit - c.remainingGraphMemoryBytes = limit -} - // ReserveGraphMemory reserves raw estimated graph-owner bytes. func (c *ReadContext) ReserveGraphMemory(bytes int64) bool { if bytes >= 0 { @@ -663,12 +612,7 @@ func (c *ReadContext) ReadStringSlice(refMode RefMode, readType bool) []string { if readType { _ = c.buffer.ReadUint8(err) } - return c.readStringSliceData() -} - -func (c *ReadContext) readStringSliceData() []string { buf := c.buffer - err := c.Err() length := buf.ReadLength(err) if c.HasError() { return nil diff --git a/go/fory/stream.go b/go/fory/stream.go index f8f1e63430..e7bb12bd8b 100644 --- a/go/fory/stream.go +++ b/go/fory/stream.go @@ -97,10 +97,21 @@ func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = is.buffer target := reflect.ValueOf(v).Elem() - if err := f.initRootGraphBudget(target); err != nil { - f.readCtx.buffer = origBuffer - f.resetReadState() - return err + limit := f.config.MaxGraphMemoryBytes + if limit <= 0 { + f.readCtx.graphMemoryLimitBytes = 0 + f.readCtx.remainingGraphMemoryBytes = MaxInt64 + } else { + f.readCtx.graphMemoryLimitBytes = limit + f.readCtx.remainingGraphMemoryBytes = limit + } + if bytes, ok := f.rootGraphBytesFor(target.Type()); ok && bytes > 0 { + if !f.readCtx.ReserveGraphMemory(bytes) { + err := f.readCtx.TakeError() + f.readCtx.buffer = origBuffer + f.resetReadState() + return err + } } defer func() { f.readCtx.buffer = origBuffer @@ -129,8 +140,18 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { // Always reset to enforce stateless semantics. f.readCtx.buffer.ResetWithReader(r, 0) target := reflect.ValueOf(v).Elem() - if err := f.initRootGraphBudget(target); err != nil { - return err + limit := f.config.MaxGraphMemoryBytes + if limit <= 0 { + f.readCtx.graphMemoryLimitBytes = 0 + f.readCtx.remainingGraphMemoryBytes = MaxInt64 + } else { + f.readCtx.graphMemoryLimitBytes = limit + f.readCtx.remainingGraphMemoryBytes = limit + } + if bytes, ok := f.rootGraphBytesFor(target.Type()); ok && bytes > 0 { + if !f.readCtx.ReserveGraphMemory(bytes) { + return f.readCtx.TakeError() + } } readHeader(f.readCtx) diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index 6ca56c1a0b..4fb8ff1a0d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -116,10 +116,6 @@ public void prepare( this.buffer = buffer; this.peerOutOfBandEnabled = peerOutOfBandEnabled; this.outOfBandBuffers = outOfBandBuffers == null ? null : outOfBandBuffers.iterator(); - initGraphMemoryBudget(); - } - - private void initGraphMemoryBudget() { long limit = maxGraphMemoryBytes; if (limit <= 0) { graphMemoryLimitBytes = 0; diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi index 807805dcbb..62ef71b6dc 100644 --- a/python/pyfory/collection.pxi +++ b/python/pyfory/collection.pxi @@ -469,17 +469,10 @@ cdef class ListSerializer(CollectionSerializer): cdef int32_t ref_id cdef int64_t i cdef int64_t graph_bytes - cdef int64_t remaining_graph_memory_bytes if len_ < 0: raise ValueError("Container element count is negative") graph_bytes = _OWNER_BYTES + len_ * _REFERENCE_BYTES - remaining_graph_memory_bytes = read_context.remaining_graph_memory_bytes - if graph_bytes > remaining_graph_memory_bytes: - read_context.reserve_graph_memory_fast(graph_bytes) - else: - read_context.remaining_graph_memory_bytes = ( - remaining_graph_memory_bytes - graph_bytes - ) + read_context.reserve_graph_memory_c(graph_bytes) if len_ == 0: list_ = PyList_New(0) return list_ @@ -597,17 +590,10 @@ cdef class TupleSerializer(CollectionSerializer): cdef int8_t head_flag cdef int64_t i cdef int64_t graph_bytes - cdef int64_t remaining_graph_memory_bytes if len_ < 0: raise ValueError("Container element count is negative") graph_bytes = _OWNER_BYTES + len_ * _REFERENCE_BYTES - remaining_graph_memory_bytes = read_context.remaining_graph_memory_bytes - if graph_bytes > remaining_graph_memory_bytes: - read_context.reserve_graph_memory_fast(graph_bytes) - else: - read_context.remaining_graph_memory_bytes = ( - remaining_graph_memory_bytes - graph_bytes - ) + read_context.reserve_graph_memory_c(graph_bytes) if len_ == 0: tuple_ = PyTuple_New(0) return tuple_ @@ -726,19 +712,12 @@ cdef class SetSerializer(CollectionSerializer): cdef int32_t ref_id cdef int64_t i cdef int64_t graph_bytes - cdef int64_t remaining_graph_memory_bytes len_ = buffer.read_var_uint32() if len_ < 0: raise ValueError("Container element count is negative") graph_bytes = _OWNER_BYTES + len_ * _REFERENCE_BYTES - remaining_graph_memory_bytes = read_context.remaining_graph_memory_bytes - if graph_bytes > remaining_graph_memory_bytes: - read_context.reserve_graph_memory_fast(graph_bytes) - else: - read_context.remaining_graph_memory_bytes = ( - remaining_graph_memory_bytes - graph_bytes - ) + read_context.reserve_graph_memory_c(graph_bytes) if len_ == 0: instance = set() read_context.reference(instance) @@ -1089,17 +1068,10 @@ cdef class MapSerializer(Serializer): cdef dict map_ cdef int8_t chunk_header = 0 cdef int64_t graph_bytes - cdef int64_t remaining_graph_memory_bytes if size < 0: raise ValueError("Map entry count is negative") graph_bytes = _OWNER_BYTES + size * (2 * _REFERENCE_BYTES) - remaining_graph_memory_bytes = read_context.remaining_graph_memory_bytes - if graph_bytes > remaining_graph_memory_bytes: - read_context.reserve_graph_memory_fast(graph_bytes) - else: - read_context.remaining_graph_memory_bytes = ( - remaining_graph_memory_bytes - graph_bytes - ) + read_context.reserve_graph_memory_c(graph_bytes) if size == 0: map_ = {} else: diff --git a/python/pyfory/context.pxi b/python/pyfory/context.pxi index 14405b46f5..0a0a0511bb 100644 --- a/python/pyfory/context.pxi +++ b/python/pyfory/context.pxi @@ -835,19 +835,6 @@ cdef class ReadContext: ) self.remaining_graph_memory_bytes -= num_bytes - cdef inline void reserve_graph_memory_fast(self, int64_t num_bytes): - cdef int64_t used - if self.graph_memory_limit_bytes <= 0: - return - if num_bytes > self.remaining_graph_memory_bytes: - used = self.graph_memory_limit_bytes - self.remaining_graph_memory_bytes - raise ValueError( - f"Estimated graph memory budget exceeded: requested {num_bytes} bytes, " - f"used {used} bytes, limit {self.graph_memory_limit_bytes} bytes. " - "Increase Fory(..., max_graph_memory_bytes=...) for trusted larger payloads." - ) - self.remaining_graph_memory_bytes -= num_bytes - cpdef inline reserve_graph_memory(self, num_bytes): if num_bytes < 0: raise ValueError("Estimated graph memory is negative") diff --git a/python/pyfory/struct.pxi b/python/pyfory/struct.pxi index bb813b0a57..7010a5ee8c 100644 --- a/python/pyfory/struct.pxi +++ b/python/pyfory/struct.pxi @@ -424,7 +424,7 @@ cdef class DataClassSerializer(Serializer): f"Hash {read_hash} is not consistent with {self._hash} for type {self.type_}" ) - read_context.reserve_graph_memory_fast( + read_context.reserve_graph_memory_c( _STRUCT_OWNER_BYTES + self._field_runtime_infos.size() * _STRUCT_REFERENCE_BYTES ) obj = self.type_.__new__(self.type_) diff --git a/rust/fory-core/src/context.rs b/rust/fory-core/src/context.rs index 0f57f06b23..0a69e601d8 100644 --- a/rust/fory-core/src/context.rs +++ b/rust/fory-core/src/context.rs @@ -359,9 +359,9 @@ pub struct ReadContext<'a> { max_dyn_depth: u32, check_struct_version: bool, check_string_read: bool, - max_graph_memory_bytes: i64, - graph_memory_limit_bytes: usize, - remaining_graph_memory_bytes: usize, + pub(crate) max_graph_memory_bytes: i64, + pub(crate) graph_memory_limit_bytes: usize, + pub(crate) remaining_graph_memory_bytes: usize, // Context-specific fields pub reader: Reader<'a>, @@ -449,19 +449,6 @@ impl<'a> ReadContext<'a> { self.reader = reader; } - #[inline(always)] - pub(crate) fn init_graph_memory_budget(&mut self) -> Result<(), Error> { - let limit = if self.max_graph_memory_bytes > 0 { - usize::try_from(self.max_graph_memory_bytes) - .map_err(|_| graph_memory_error("max_graph_memory_bytes does not fit usize"))? - } else { - 0 - }; - self.graph_memory_limit_bytes = limit; - self.remaining_graph_memory_bytes = if limit > 0 { limit } else { usize::MAX }; - Ok(()) - } - #[inline(always)] #[doc(hidden)] pub fn reserve_graph_memory(&mut self, bytes: usize) -> Result<(), Error> { @@ -590,12 +577,6 @@ impl<'a> ReadContext<'a> { } } -#[cold] -#[inline(never)] -fn graph_memory_error(message: &'static str) -> Error { - Error::invalid_data(message) -} - #[cold] #[inline(never)] fn graph_memory_exceeded(bytes: usize, remaining: usize, limit: usize) -> Error { diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index ea237b96fc..f73aa0c3aa 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -997,8 +997,18 @@ impl Fory { self.with_read_context(|context| { let outlive_buffer = unsafe { mem::transmute::<&[u8], &[u8]>(bf) }; context.attach_reader(Reader::new(outlive_buffer)); - let result = match context.init_graph_memory_budget() { - Ok(()) => self.deserialize_with_context(context), + let result = match if context.max_graph_memory_bytes > 0 { + usize::try_from(context.max_graph_memory_bytes) + .map_err(|_| Error::invalid_data("max_graph_memory_bytes does not fit usize")) + } else { + Ok(0) + } { + Ok(limit) => { + context.graph_memory_limit_bytes = limit; + context.remaining_graph_memory_bytes = + if limit > 0 { limit } else { usize::MAX }; + self.deserialize_with_context(context) + } Err(err) => { context.reset(); Err(err) @@ -1066,8 +1076,18 @@ impl Fory { let mut new_reader = Reader::new(outlive_buffer); new_reader.set_cursor(reader.cursor); context.attach_reader(new_reader); - let result = match context.init_graph_memory_budget() { - Ok(()) => self.deserialize_with_context(context), + let result = match if context.max_graph_memory_bytes > 0 { + usize::try_from(context.max_graph_memory_bytes) + .map_err(|_| Error::invalid_data("max_graph_memory_bytes does not fit usize")) + } else { + Ok(0) + } { + Ok(limit) => { + context.graph_memory_limit_bytes = limit; + context.remaining_graph_memory_bytes = + if limit > 0 { limit } else { usize::MAX }; + self.deserialize_with_context(context) + } Err(err) => { context.reset(); Err(err) diff --git a/swift/Sources/Fory/Fory.swift b/swift/Sources/Fory/Fory.swift index e121bd02e6..7e114a6a1a 100644 --- a/swift/Sources/Fory/Fory.swift +++ b/swift/Sources/Fory/Fory.swift @@ -493,7 +493,8 @@ public final class Fory { _ body: (ReadContext) throws -> R ) throws -> R { readContext.buffer.replace(with: data) - try readContext.initGraphMemoryBudget() + readContext.remainingGraphMemoryBytes = + readContext.maxGraphMemoryBytes > 0 ? readContext.maxGraphMemoryBytes : Int.max defer { readContext.reset() } @@ -555,7 +556,8 @@ public final class Fory { ) throws -> R { try typeResolver.finishRegistration() readContext.buffer.swapState(with: buffer) - try readContext.initGraphMemoryBudget() + readContext.remainingGraphMemoryBytes = + readContext.maxGraphMemoryBytes > 0 ? readContext.maxGraphMemoryBytes : Int.max defer { readContext.buffer.swapState(with: buffer) readContext.reset() diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index a8680932f3..7cd50e6abd 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -35,8 +35,8 @@ public final class ReadContext { private var typeInfoScopeStack: [(typeKey: UInt64, previousTypeInfo: TypeInfo?)] = [] private var lastTypeInfo = TypeInfo.uncached private let config: Config - private let maxGraphMemoryBytes: Int - private var remainingGraphMemoryBytes = Int.max + let maxGraphMemoryBytes: Int + var remainingGraphMemoryBytes = Int.max init( buffer: ByteBuffer, @@ -54,11 +54,6 @@ public final class ReadContext { self.refReader = RefReader() } - @inline(__always) - func initGraphMemoryBudget() throws { - remainingGraphMemoryBytes = maxGraphMemoryBytes > 0 ? maxGraphMemoryBytes : Int.max - } - @inline(__always) public func reserveGraphMemory(_ bytes: Int) throws { if bytes < 0 { diff --git a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift index 0bb40f985c..97b91aa7ab 100644 --- a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift +++ b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift @@ -154,7 +154,7 @@ func fixedDefaultBudgetAndDisable() throws { config: config ) - try context.initGraphMemoryBudget() + context.remainingGraphMemoryBytes = context.maxGraphMemoryBytes try context.reserveGraphMemory(Int(defaultGraphMemoryBytes)) expectInvalidData { try context.reserveGraphMemory(testReferenceBytes) @@ -166,7 +166,7 @@ func fixedDefaultBudgetAndDisable() throws { typeResolver: TypeResolver(config: disabledConfig), config: disabledConfig ) - try disabled.initGraphMemoryBudget() + disabled.remainingGraphMemoryBytes = Int.max try disabled.reserveGraphMemory(Int(defaultGraphMemoryBytes) + 1) } From fc9ce852583a85c68f09d89de4ac20399fb98e3e Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 03:14:37 +0800 Subject: [PATCH 34/54] fix: enforce positive graph memory budgets --- .agents/languages/cpp.md | 2 +- .agents/languages/dart.md | 6 +-- .agents/languages/go.md | 4 +- .agents/languages/java.md | 2 +- .agents/languages/javascript.md | 4 +- .agents/languages/python.md | 4 +- .agents/languages/rust.md | 5 +-- .agents/languages/scala.md | 6 ++- .../references/language-command-matrix.md | 3 +- AGENTS.md | 2 +- cpp/fory/serialization/config.h | 2 +- cpp/fory/serialization/context.cc | 14 ++++--- cpp/fory/serialization/context.h | 5 +-- cpp/fory/serialization/fory.h | 8 ++-- .../serialization/graph_memory_budget_test.cc | 16 ++------ csharp/src/Fory/Config.cs | 10 ++++- csharp/src/Fory/Fory.cs | 12 +++--- csharp/src/Fory/ReadContext.cs | 11 ++--- .../Fory.Tests/GraphMemoryBudgetTests.cs | 12 ++---- dart/packages/fory/lib/src/config.dart | 6 ++- .../fory/lib/src/context/read_context.dart | 11 +---- .../fory/test/graph_memory_budget_test.dart | 12 +++--- docs/guide/cpp/configuration.md | 8 ++-- docs/guide/csharp/configuration.md | 3 +- docs/guide/dart/configuration.md | 5 +-- docs/guide/go/configuration.md | 5 +-- docs/guide/java/configuration.md | 7 ++-- docs/guide/javascript/configuration.md | 4 +- docs/guide/python/configuration.md | 10 ++--- docs/guide/rust/configuration.md | 6 +-- docs/guide/swift/configuration.md | 3 +- docs/security/deserialization.md | 5 +-- .../xlang_implementation_guide.md | 6 +-- go/fory/README.md | 3 +- go/fory/fory.go | 40 +++++-------------- go/fory/graph_memory_budget_test.go | 13 ++---- go/fory/reader.go | 3 -- go/fory/stream.go | 18 ++------- .../org/apache/fory/config/ForyBuilder.java | 4 +- .../org/apache/fory/context/ReadContext.java | 13 +----- .../serializer/GraphMemoryBudgetTest.java | 17 ++------ javascript/packages/core/lib/context.ts | 9 +---- javascript/packages/core/lib/fory.ts | 6 ++- javascript/test/graphMemoryBudget.test.ts | 8 ++-- python/pyfory/_fory.py | 7 ++-- python/pyfory/context.pxi | 8 +--- python/pyfory/context.py | 7 +--- python/pyfory/serialization.pyx | 23 +++++------ .../pyfory/tests/test_graph_memory_budget.py | 10 +---- rust/fory-core/src/config.rs | 3 +- rust/fory-core/src/context.rs | 3 -- rust/fory-core/src/fory.rs | 32 +++++++-------- rust/tests/tests/test_graph_memory_budget.rs | 32 ++------------- scala/README.md | 5 ++- .../scala/CollectionSerializerTest.scala | 10 ++--- swift/Sources/Fory/Fory.swift | 9 +++-- swift/Sources/Fory/ReadContext.swift | 5 +-- .../ForyTests/GraphMemoryBudgetTests.swift | 11 +---- 58 files changed, 184 insertions(+), 334 deletions(-) diff --git a/.agents/languages/cpp.md b/.agents/languages/cpp.md index 564b7f1885..0ead61046e 100644 --- a/.agents/languages/cpp.md +++ b/.agents/languages/cpp.md @@ -20,7 +20,7 @@ Load this file when changing `cpp/`, Cython build plumbing, or C++ xlang behavio - Root deserialization graph budgets are owned by `ReadContext` and initialized by the root `Fory::deserialize` overload. Keep `max_graph_memory_bytes` as a fixed-default graph limit: unset/default is `128 MiB`, positive explicit values override it, and explicit non-positive - values intentionally disable graph-memory enforcement. Byte and stream roots use the same + values are invalid at config creation. Byte and stream roots use the same configured/default budget behavior. Reserve estimated shallow graph-owner memory before allocation while preserving existing byte-availability checks and their non-empty metadata ordering. `ReadContext` may expose only raw diff --git a/.agents/languages/dart.md b/.agents/languages/dart.md index 47017e29e8..b6e96751ff 100644 --- a/.agents/languages/dart.md +++ b/.agents/languages/dart.md @@ -16,9 +16,9 @@ Load this file when changing `dart/`. - In `Buffer`, cursor, serializer, and generated-code hot paths, prefer direct byte/local integer operations and conditional import/export files over callbacks, records, holder objects, wrapper round-trips, or runtime platform branches. - Root deserialization graph memory budgets are owned by `ReadContext`; `maxGraphMemoryBytes` defaults to fixed `128 MiB`, positive explicit values override it, and - explicit non-positive values intentionally disable graph-memory enforcement. Do not derive the - budget from `buffer.readableBytes`. `ReadContext` may expose only raw byte reservation; list, set, - map, array, struct, and object formulas + explicit non-positive values are invalid at config creation. Do not derive the budget from + `buffer.readableBytes`. `ReadContext` may expose only raw byte reservation; list, set, map, array, + struct, and object formulas belong in serializer owners. Reserve Dart list/set/object-array reference slots plus nonzero owner self cost, map key/value slots plus nonzero owner self cost, compatible array-to-list materialization, and generated object reads before diff --git a/.agents/languages/go.md b/.agents/languages/go.md index 87442b91af..afce10d12e 100644 --- a/.agents/languages/go.md +++ b/.agents/languages/go.md @@ -9,8 +9,8 @@ Load this file when changing `go/fory/` or Go xlang behavior. - The Go implementation focuses on reflection-based and codegen-based serialization. - Root deserialization graph memory budgets are owned by `ReadContext`. `WithMaxGraphMemoryBytes` uses a fixed `128 MiB` default; positive explicit - values override it, and explicit non-positive values intentionally disable - graph-memory enforcement. Byte-slice and stream roots use the same + values override it, and explicit non-positive values are invalid at config creation. + Byte-slice and stream roots use the same configured/default budget behavior. `ReadContext` may expose only raw byte reservation; slice, map, array, struct, and object formulas belong in handwritten or generated serializer owners. Reserve Go diff --git a/.agents/languages/java.md b/.agents/languages/java.md index a3f9aeacb6..6be988e2da 100644 --- a/.agents/languages/java.md +++ b/.agents/languages/java.md @@ -17,7 +17,7 @@ Load this file when changing anything under `java/` or when Java drives a cross- - Java root deserialization graph memory budgeting belongs to `ReadContext` and is initialized by `Fory` root APIs. Public config is `maxGraphMemoryBytes` with fixed `128 MiB` default. Positive explicit values override the default; - explicit non-positive values intentionally disable graph-memory enforcement. + explicit non-positive values are invalid and must be rejected at config creation. Byte-array, memory-buffer, and stream roots use the same configured/default budget behavior. `ReadContext` may expose only raw byte reservation; diff --git a/.agents/languages/javascript.md b/.agents/languages/javascript.md index d7e24ca968..9bf3f62b3f 100644 --- a/.agents/languages/javascript.md +++ b/.agents/languages/javascript.md @@ -16,8 +16,8 @@ Load this file when changing `javascript/`. - Normalize optional boolean config values at config construction; do not carry `null` through runtime paths when it means `false`. - JavaScript root deserialization graph memory budgeting belongs to `ReadContext`. `maxGraphMemoryBytes` uses a fixed `128 MiB` default, positive explicit limits override it, and - explicit non-positive values intentionally disable graph-memory enforcement. Do not derive the - budget from the `Uint8Array` root length. `ReadContext` may expose only raw + explicit non-positive values are invalid at config creation. Do not derive the budget from the + `Uint8Array` root length. `ReadContext` may expose only raw byte reservation; generated and dynamic list/set/map/array/struct/object readers must reserve before allocation while preserving existing byte checks. Lists/sets/object arrays reserve nonzero owner self cost plus 4-byte reference slots, diff --git a/.agents/languages/python.md b/.agents/languages/python.md index 08db459019..81222297f8 100644 --- a/.agents/languages/python.md +++ b/.agents/languages/python.md @@ -15,8 +15,8 @@ Load this file when changing `python/`, Cython serialization, or Python xlang be - Keep Python and Cython context/ref-tracking branch conditions and stack mutations semantically aligned unless a documented intentional difference exists. - Root deserialization graph memory budgets are owned by pure-Python and Cython `ReadContext`. Keep `max_graph_memory_bytes` public on `pyfory.Fory`/`Config`; the default effective limit is - fixed `128 MiB`, positive explicit values override it, and explicit non-positive values - intentionally disable graph-memory enforcement. Byte and stream roots use the same + fixed `128 MiB`, positive explicit values override it, and explicit non-positive values are + invalid at config creation. Byte and stream roots use the same configured/default budget behavior. `ReadContext` may expose only raw byte reservation; collection, dict, array, struct, and object formulas belong in the pure-Python or Cython serializer owner. Lists, tuples, sets, and diff --git a/.agents/languages/rust.md b/.agents/languages/rust.md index 2c14970133..79458c5126 100644 --- a/.agents/languages/rust.md +++ b/.agents/languages/rust.md @@ -20,9 +20,8 @@ Load this file when changing `rust/` or Rust xlang behavior. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Keep recursive matched-field shape classification owned by `fory-core/src/meta/type_meta.rs`; collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. - Root deserialization graph memory budget state belongs to `ReadContext` and is initialized by the root `Fory` read methods before the header is consumed. Use the fixed `128 MiB` default unless a - positive explicit value overrides it or an explicit non-positive value intentionally disables - graph-memory enforcement; do not derive the budget from root input size or add dynamic bytes-read - accounting. + positive explicit value overrides it; explicit non-positive values are invalid at config + creation. Do not derive the budget from root input size or add dynamic bytes-read accounting. `ReadContext` may expose only raw byte reservation; `Vec`, collection, map, array, struct, object, and derive codec formulas belong in their serializer owners. diff --git a/.agents/languages/scala.md b/.agents/languages/scala.md index 03f22b70b6..cdc9459bed 100644 --- a/.agents/languages/scala.md +++ b/.agents/languages/scala.md @@ -16,6 +16,8 @@ sbt compile # Run tests sbt test -# Format code -sbt scalafmt +# Repo-owned formatter pass for changed files +cd .. && ci/format.sh ``` + +The Scala module does not currently wire a `scalafmt` sbt command. diff --git a/.agents/skills/fory-performance-optimization/references/language-command-matrix.md b/.agents/skills/fory-performance-optimization/references/language-command-matrix.md index abd1fb8a6a..cf22665b8b 100644 --- a/.agents/skills/fory-performance-optimization/references/language-command-matrix.md +++ b/.agents/skills/fory-performance-optimization/references/language-command-matrix.md @@ -79,7 +79,8 @@ Canonical runtime-specific rules now live under `../../../languages/*.md` and `. - Build: `sbt compile` - Tests: `sbt test` -- Format: `sbt scalafmt` +- Format: run the repo formatter from the repository root with `ci/format.sh`; the Scala module + does not currently wire a `scalafmt` sbt command. ## Cross-Language Xlang Verification diff --git a/AGENTS.md b/AGENTS.md index a275acdbe7..fd0ccf815e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -33,7 +33,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Check the spec before implementation. For wire behavior and xlang mapping, use the specs as the source of truth and never copy one runtime's bug into another runtime just to make tests pass. - Do not make assumptions about runtime behavior, ownership, registration, metadata construction, protocol semantics, or test coverage. Read the current code, owning docs/specs, and relevant tests before making a design judgment or implementation decision. If the evidence is incomplete, inspect more or state the uncertainty explicitly instead of filling gaps from memory or analogy with another runtime. - For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve backing/output capacity from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. Root graph memory reservation is accounting only and may happen before that byte check, but it must not replace the byte check. -- Root deserialization graph memory budgets estimate lower-bound shallow memory for materialized graph owners, not exact heap accounting, input byte accounting, or raw element counts. `maxGraphMemoryBytes` defaults to fixed `128 MiB`; positive values override the default; explicit non-positive values intentionally disable graph-memory enforcement and must be documented as deserialization DoS risk for compact inputs that materialize large graphs. Do not derive this budget from root input size, and do not split known-length and stream root behavior. Read context/read state owns only raw byte reservation with `reserveGraphMemory(bytes)`; it must not expose counted arithmetic helpers or collection, map, array, struct, or object semantic reservation APIs. Do not add any non-memory-budget read-context/read-state API for this feature, including ref-publication controls, temporary-owner controls, serializer-owner controls, conversion helpers, or APIs that encode what kind of value is being materialized. Concrete serializers and generated serializers own counted formulas, overflow checks, allocation-owner decisions, and reference publication timing for their allocation path. Reserve self storage exactly once at the owner that stores or allocates the value: reference/object runtimes reserve parent owner self cost plus reference storage and every referenced heap owner reserves its own shallow self cost when materialized; inline/value runtimes reserve value storage in the root, struct/product, collection, map, set, array, smart-pointer, or box owner that actually owns that storage, and nested value serializers do not unconditionally reserve their own `sizeof(T)` again. Collection, map, set, and reference-array owners reserve nonzero shallow self cost when independently materialized, plus backing/reference/inline storage. Struct, record, POJO, tuple/product, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Reference fields use 4 bytes when reference size is not cheap or reliable to query; primitive/value fields use their storage width. Parents do not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive array, and primitive dense-array leaf owners. Do not guess allocator, bucket-table, node, debug, or map-entry overhead unless it is a documented lower-bound owner allocation. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. +- Root deserialization graph memory budgets estimate lower-bound shallow memory for materialized graph owners, not exact heap accounting, input byte accounting, or raw element counts. `maxGraphMemoryBytes` defaults to fixed `128 MiB`; positive values override the default; explicit non-positive values are invalid and must be rejected at config/Fory creation. Do not add a disabled-budget sentinel path, derive this budget from root input size, or split known-length and stream root behavior. Read context/read state owns only raw byte reservation with `reserveGraphMemory(bytes)`; it must not expose counted arithmetic helpers or collection, map, array, struct, or object semantic reservation APIs. Do not add any non-memory-budget read-context/read-state API for this feature, including ref-publication controls, temporary-owner controls, serializer-owner controls, conversion helpers, or APIs that encode what kind of value is being materialized. Concrete serializers and generated serializers own counted formulas, overflow checks, allocation-owner decisions, and reference publication timing for their allocation path. Reserve self storage exactly once at the owner that stores or allocates the value: reference/object runtimes reserve parent owner self cost plus reference storage and every referenced heap owner reserves its own shallow self cost when materialized; inline/value runtimes reserve value storage in the root, struct/product, collection, map, set, array, smart-pointer, or box owner that actually owns that storage, and nested value serializers do not unconditionally reserve their own `sizeof(T)` again. Collection, map, set, and reference-array owners reserve nonzero shallow self cost when independently materialized, plus backing/reference/inline storage. Struct, record, POJO, tuple/product, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Reference fields use 4 bytes when reference size is not cheap or reliable to query; primitive/value fields use their storage width. Parents do not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive array, and primitive dense-array leaf owners. Do not guess allocator, bucket-table, node, debug, or map-entry overhead unless it is a documented lower-bound owner allocation. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. - For remote TypeDef/TypeMeta reads, the checked metadata cache is the only owner of remote "already validated" state. Cache hit means the header was previously parsed, body/hash-validated, policy-checked, and published by that cache, so the hot path must skip the body and use cached metadata without extra validation, hashing, limit checks, exact-local checks, allocation, or policy work. A known expected local TypeDef/TypeMeta header/hash match is a local-schema hit, not a remote cache miss: it may skip the body and use the local TypeInfo/TypeMeta without schema-version counting or cache publish. Cache miss is the only path that parses and validates non-local metadata, enforces limits, performs exact-local byte comparison when needed, and publishes remote metadata to the cache. Do not add nullable accepted-header fields, sentinel headers, per-TypeInfo markers, pending metadata state, parallel header-low/header-high slots, or parallel acceptance state for this decision. If a runtime needs a metadata hit hint, cache the concrete checked metadata owner object, such as the TypeInfo, TypeDef, or TypeMeta used by that runtime, and compare its validated header identity directly. - When a user corrects a non-obvious invariant, encode it in the nearest source comment before continuing, and also update `AGENTS.md`, `.agents/**`, docs, or specs when the rule is reusable beyond one file. Do not rely only on chat history, task notes, commit messages, or benchmark logs for corrections that protect security, protocol behavior, ownership, naming, or hot-path performance. - Reject semantic hacks. Do not bypass broken semantics by deleting cases, simplifying callers, adding coercion hooks, or using workaround fallbacks; fix the underlying bug and prove it with focused tests. diff --git a/cpp/fory/serialization/config.h b/cpp/fory/serialization/config.h index 910f72c37a..84d3b89a96 100644 --- a/cpp/fory/serialization/config.h +++ b/cpp/fory/serialization/config.h @@ -53,7 +53,7 @@ struct Config { bool track_ref = true; /// Maximum estimated graph memory accepted during one root deserialization. - /// Positive values are byte limits; non-positive values disable enforcement. + /// Value must be a positive byte limit. int64_t max_graph_memory_bytes = 128LL * 1024LL * 1024LL; /// Maximum accepted field count in one received struct TypeMeta. diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index 3888882b5b..8e27c80491 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -434,11 +434,13 @@ uint32_t WriteContext::get_type_id_for_cache(const std::type_index &type_idx) { ReadContext::ReadContext(const Config &config, std::unique_ptr type_resolver) : buffer_(nullptr), config_(&config), - type_resolver_(std::move(type_resolver)), current_dyn_depth_(0), - graph_memory_limit_bytes_( - config.max_graph_memory_bytes > 0 - ? static_cast(config.max_graph_memory_bytes) - : size_t{0}) {} + type_resolver_(std::move(type_resolver)), current_dyn_depth_(0) { + FORY_CHECK(config.max_graph_memory_bytes > 0) + << "max_graph_memory_bytes must be positive"; + graph_memory_limit_bytes_ = + static_cast(config.max_graph_memory_bytes); + remaining_graph_memory_bytes_ = graph_memory_limit_bytes_; +} ReadContext::~ReadContext() = default; @@ -764,7 +766,7 @@ void ReadContext::reset() { } reading_type_infos_.clear(); current_dyn_depth_ = 0; - remaining_graph_memory_bytes_ = std::numeric_limits::max(); + remaining_graph_memory_bytes_ = 0; if (meta_string_table_active_) { meta_string_table_.reset(); meta_string_table_active_ = false; diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index 117d2ed27a..6555481a65 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -507,9 +507,6 @@ class ReadContext { FORY_ALWAYS_INLINE bool reserve_graph_memory(size_t bytes) { const size_t remaining = remaining_graph_memory_bytes_; - if (FORY_PREDICT_FALSE(remaining == std::numeric_limits::max())) { - return true; - } if (FORY_PREDICT_FALSE(bytes > remaining)) { return set_graph_memory_exceeded(bytes, remaining); } @@ -689,7 +686,7 @@ class ReadContext { RefReader ref_reader_; uint32_t current_dyn_depth_; size_t graph_memory_limit_bytes_ = 0; - size_t remaining_graph_memory_bytes_ = std::numeric_limits::max(); + size_t remaining_graph_memory_bytes_ = 0; // Meta sharing state (for compatible mode) // Persistent cache storage for TypeInfo objects keyed by meta header. diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index aa6a976059..ed8541895b 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -112,9 +112,9 @@ class ForyBuilder { /// Set maximum estimated graph memory for one root deserialization. /// - /// Defaults to 128 MiB. Positive values are explicit byte limits; - /// non-positive values intentionally disable this protection. + /// Defaults to 128 MiB. Values must be positive byte limits. ForyBuilder &max_graph_memory_bytes(int64_t max_bytes) { + FORY_CHECK(max_bytes > 0) << "max_graph_memory_bytes must be positive"; config_.max_graph_memory_bytes = max_bytes; return *this; } @@ -888,9 +888,7 @@ class Fory : public BaseFory { read_ctx_->attach(buffer); read_ctx_->remaining_graph_memory_bytes_ = - read_ctx_->graph_memory_limit_bytes_ != 0 - ? read_ctx_->graph_memory_limit_bytes_ - : std::numeric_limits::max(); + read_ctx_->graph_memory_limit_bytes_; if constexpr (needs_graph_budget_v) { constexpr size_t root_owner_bytes = graph_value_owner_self_bytes(); if constexpr (root_owner_bytes != 0) { diff --git a/cpp/fory/serialization/graph_memory_budget_test.cc b/cpp/fory/serialization/graph_memory_budget_test.cc index 89937b0f74..ab8e63e672 100644 --- a/cpp/fory/serialization/graph_memory_budget_test.cc +++ b/cpp/fory/serialization/graph_memory_budget_test.cc @@ -126,22 +126,12 @@ void expect_budget_boundary(const T &value, size_t required) { EXPECT_EQ(exact_result.value(), value); } -TEST(GraphMemoryBudgetTest, FixedDefaultBudgetAndDisable) { +TEST(GraphMemoryBudgetTest, FixedDefaultBudgetAndValidation) { Config config; EXPECT_EQ(config.max_graph_memory_bytes, kDefaultGraphMemoryBytes); - Config disabled_config; - disabled_config.max_graph_memory_bytes = 0; - EXPECT_EQ(disabled_config.max_graph_memory_bytes, 0); - - constexpr size_t count = 3; - std::vector> value(count); - auto bytes = serialize_value(value); - auto disabled_result = with_fory(0, [&](Fory &fory) { - return fory.deserialize>>(bytes); - }); - ASSERT_TRUE(disabled_result.ok()) << disabled_result.error().to_string(); - EXPECT_EQ(disabled_result.value(), value); + EXPECT_DEATH((void)Fory::builder().max_graph_memory_bytes(0), + "max_graph_memory_bytes"); } TEST(GraphMemoryBudgetTest, RootKindsShareConfiguredBudget) { diff --git a/csharp/src/Fory/Config.cs b/csharp/src/Fory/Config.cs index 4aeed16315..e84193af50 100644 --- a/csharp/src/Fory/Config.cs +++ b/csharp/src/Fory/Config.cs @@ -54,6 +54,10 @@ internal Config( { throw new ArgumentOutOfRangeException(nameof(maxAverageSchemaVersionsPerType), "MaxAverageSchemaVersionsPerType must be greater than 0."); } + if (maxGraphMemoryBytes <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxGraphMemoryBytes), "MaxGraphMemoryBytes must be greater than 0."); + } TrackRef = trackRef; Compatible = compatible; @@ -179,10 +183,14 @@ public ForyBuilder MaxDepth(int value) /// /// Sets the maximum estimated graph memory accepted during one root deserialization. - /// Positive values are byte limits. Explicit non-positive values disable this budget. /// public ForyBuilder MaxGraphMemoryBytes(long value) { + if (value <= 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "MaxGraphMemoryBytes must be greater than 0."); + } + _maxGraphMemoryBytes = value; return this; } diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs index 2143682e2f..9ac66fe5d8 100644 --- a/csharp/src/Fory/Fory.cs +++ b/csharp/src/Fory/Fory.cs @@ -193,8 +193,8 @@ public T Deserialize(ReadOnlySpan payload) ByteReader reader = _readContext.Reader; reader.Reset(payload); long graphLimit = Config.MaxGraphMemoryBytes; - _readContext._graphMemoryLimitBytes = graphLimit > 0 ? graphLimit : 0; - _readContext._remainingGraphMemoryBytes = graphLimit > 0 ? graphLimit : long.MaxValue; + _readContext._graphMemoryLimitBytes = graphLimit; + _readContext._remainingGraphMemoryBytes = graphLimit; T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -216,8 +216,8 @@ public T Deserialize(byte[] payload) ByteReader reader = _readContext.Reader; reader.Reset(payload); long graphLimit = Config.MaxGraphMemoryBytes; - _readContext._graphMemoryLimitBytes = graphLimit > 0 ? graphLimit : 0; - _readContext._remainingGraphMemoryBytes = graphLimit > 0 ? graphLimit : long.MaxValue; + _readContext._graphMemoryLimitBytes = graphLimit; + _readContext._remainingGraphMemoryBytes = graphLimit; T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -239,8 +239,8 @@ public T Deserialize(ref ReadOnlySequence payload) ByteReader reader = _readContext.Reader; reader.Reset(bytes); long graphLimit = Config.MaxGraphMemoryBytes; - _readContext._graphMemoryLimitBytes = graphLimit > 0 ? graphLimit : 0; - _readContext._remainingGraphMemoryBytes = graphLimit > 0 ? graphLimit : long.MaxValue; + _readContext._graphMemoryLimitBytes = graphLimit; + _readContext._remainingGraphMemoryBytes = graphLimit; T value = DeserializeFromReader(reader); payload = payload.Slice(reader.Cursor); return value; diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index 42da224a30..bd0283c9a4 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -45,8 +45,8 @@ public sealed class ReadContext private readonly Dictionary _remoteSchemaVersionsByType = []; private readonly Config _config; private int _totalAcceptedSchemaVersions; - internal long _graphMemoryLimitBytes = long.MaxValue; - internal long _remainingGraphMemoryBytes = long.MaxValue; + internal long _graphMemoryLimitBytes; + internal long _remainingGraphMemoryBytes; public ReadContext( ByteReader reader, @@ -63,6 +63,8 @@ public ReadContext( RefReader = new RefReader(); _maxDynamicReadDepth = config.MaxDepth; _config = config; + _graphMemoryLimitBytes = config.MaxGraphMemoryBytes; + _remainingGraphMemoryBytes = config.MaxGraphMemoryBytes; } public ByteReader Reader { get; private set; } @@ -105,11 +107,6 @@ private void ReserveGraphMemorySlow(long bytes, long remaining) throw new InvalidDataException("graph memory estimate overflows"); } - if (_graphMemoryLimitBytes <= 0) - { - return; - } - throw new InvalidDataException( $"estimated graph memory request {bytes} bytes exceeds MaxGraphMemoryBytes remaining budget {remaining} bytes out of effective limit {_graphMemoryLimitBytes} bytes"); } diff --git a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs index 2c73e744d5..60b65c0e30 100644 --- a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs +++ b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs @@ -142,23 +142,17 @@ private static long MapBudget(int count) } [Fact] - public void DefaultFixedBudgetAndDisable() + public void DefaultFixedBudgetAndValidation() { Assert.Equal(DefaultGraphMemoryBytes, NewFory().Config.MaxGraphMemoryBytes); - Assert.Equal(0, NewFory(0).Config.MaxGraphMemoryBytes); - Assert.Equal(-2, NewFory(-2).Config.MaxGraphMemoryBytes); + Assert.Throws(() => NewFory(0)); + Assert.Throws(() => NewFory(-2)); ReadContext context = new(new ByteReader([]), new TypeResolver(), NewFory().Config); context._graphMemoryLimitBytes = DefaultGraphMemoryBytes; context._remainingGraphMemoryBytes = DefaultGraphMemoryBytes; context.ReserveGraphMemory(DefaultGraphMemoryBytes); Assert.Throws(() => context.ReserveGraphMemory(ReferenceBytes)); - - ReadContext disabled = new(new ByteReader([]), new TypeResolver(), NewFory(0).Config); - disabled._graphMemoryLimitBytes = 0; - disabled._remainingGraphMemoryBytes = long.MaxValue; - disabled.ReserveGraphMemory(long.MaxValue); - Assert.Throws(() => disabled.ReserveGraphMemory(-1)); } [Fact] diff --git a/dart/packages/fory/lib/src/config.dart b/dart/packages/fory/lib/src/config.dart index 6232285dbe..2207819321 100644 --- a/dart/packages/fory/lib/src/config.dart +++ b/dart/packages/fory/lib/src/config.dart @@ -59,7 +59,7 @@ final class Config { /// Maximum estimated graph memory per root deserialization. /// - /// Positive values are explicit byte limits. Non-positive values disable enforcement. + /// Value must be a positive byte limit. final int maxGraphMemoryBytes; /// Creates an immutable configuration object. @@ -87,5 +87,9 @@ final class Config { assert( maxAverageSchemaVersionsPerType > 0, 'maxAverageSchemaVersionsPerType must be positive', + ), + assert( + maxGraphMemoryBytes > 0 && maxGraphMemoryBytes <= 9007199254740991, + 'maxGraphMemoryBytes must be a positive safe integer', ); } diff --git a/dart/packages/fory/lib/src/context/read_context.dart b/dart/packages/fory/lib/src/context/read_context.dart index 3a0cbc1b58..c978ad2193 100644 --- a/dart/packages/fory/lib/src/context/read_context.dart +++ b/dart/packages/fory/lib/src/context/read_context.dart @@ -56,7 +56,6 @@ final class ReadContext { late Buffer _buffer; final List _sharedTypes = []; int _depth = 0; - int _effectiveGraphMemoryBytes = 0; int _remainingGraphMemoryBytes = 0; @internal @@ -71,12 +70,10 @@ final class ReadContext { @pragma('vm:prefer-inline') void prepare(Buffer buffer) { _buffer = buffer; - final configured = config.maxGraphMemoryBytes; - final limit = configured > 0 ? configured : 0; + final limit = config.maxGraphMemoryBytes; if (limit > _maxSafeBudgetBytes) { _throwGraphMemoryOverflow(limit); } - _effectiveGraphMemoryBytes = limit; _remainingGraphMemoryBytes = limit; } @@ -86,7 +83,6 @@ final class ReadContext { _refReader.reset(); _metaStringReader.reset(); _depth = 0; - _effectiveGraphMemoryBytes = 0; _remainingGraphMemoryBytes = 0; } @@ -105,9 +101,6 @@ final class ReadContext { if (bytes < 0 || bytes > _maxSafeBudgetBytes) { _throwGraphMemoryOverflow(bytes); } - if (_effectiveGraphMemoryBytes <= 0) { - return; - } final remaining = _remainingGraphMemoryBytes - bytes; if (remaining < 0) { _throwGraphMemoryExceeded(bytes); @@ -127,7 +120,7 @@ final class ReadContext { throw StateError( 'maxGraphMemoryBytes exceeded: requested $bytes estimated graph bytes, ' '$_remainingGraphMemoryBytes remaining, effective limit ' - '$_effectiveGraphMemoryBytes.', + '${config.maxGraphMemoryBytes}.', ); } diff --git a/dart/packages/fory/test/graph_memory_budget_test.dart b/dart/packages/fory/test/graph_memory_budget_test.dart index eadddb5611..299f6992c5 100644 --- a/dart/packages/fory/test/graph_memory_budget_test.dart +++ b/dart/packages/fory/test/graph_memory_budget_test.dart @@ -145,19 +145,21 @@ void main() { expect(() => context.reserveGraphMemory(1), _throwsGraphBudget); }); - test('explicit config overrides default and non-positive disables', () { + test('explicit config overrides default and invalid config fails', () { final buffer = Buffer.wrap(Uint8List(4096)); final context = _readContext(buffer, maxGraphMemoryBytes: 31); expect(() => context.reserveGraphMemory(31), returnsNormally); expect(() => context.reserveGraphMemory(1), _throwsGraphBudget); - final disabled = _readContext(buffer, maxGraphMemoryBytes: 0); expect( - () => disabled.reserveGraphMemory(_defaultGraphMemoryBytes + 1), - returnsNormally, + () => Fory(maxGraphMemoryBytes: 0), + throwsA(isA()), + ); + expect( + () => Fory(maxGraphMemoryBytes: -2), + throwsA(isA()), ); - expect(() => Fory(maxGraphMemoryBytes: -2), returnsNormally); }); test('uses parent storage for nested empty containers', () { diff --git a/docs/guide/cpp/configuration.md b/docs/guide/cpp/configuration.md index 2c4ec2524b..4239fcaed9 100644 --- a/docs/guide/cpp/configuration.md +++ b/docs/guide/cpp/configuration.md @@ -108,9 +108,8 @@ auto fory = Fory::builder() ``` The default limit is a fixed `128 MiB` for byte-array, `Buffer`, and stream -roots. Positive values override the default. Explicit non-positive values -disable this budget and can expose deserialization DoS risk from compact inputs -that materialize large object graphs. +roots. Positive values override the default. Explicit non-positive values are +rejected when the runtime is created. This budget is a portable lower-bound estimate for shallow materialized graph owners such as dynamic collection backing storage, map key/value storage, @@ -247,8 +246,7 @@ Security-related configuration: - Register all structs and polymorphic implementations before deserializing untrusted payloads. - Use `check_struct_version(true)` with `compatible(false)` for intentional same-schema payloads. - Keep `max_graph_memory_bytes(...)` at the fixed `128 MiB` default for most inputs, or set a - positive value for a trusted workload that needs a different envelope. Avoid explicit - non-positive values for untrusted data because they disable graph-memory enforcement. + positive value for a trusted workload that needs a different envelope. - Keep `max_dyn_depth(...)` as low as your model permits to reject unexpectedly deep polymorphic graphs. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a diff --git a/docs/guide/csharp/configuration.md b/docs/guide/csharp/configuration.md index d08c622791..841964d955 100644 --- a/docs/guide/csharp/configuration.md +++ b/docs/guide/csharp/configuration.md @@ -108,8 +108,7 @@ Fory fory = Fory.Builder() ``` The default limit is a fixed `128 MiB` for all root input forms. A positive value overrides the -default. Passing an explicit non-positive value disables this budget and can expose deserialization -DoS risk from compact inputs that materialize large object graphs. +default. Explicit non-positive values are rejected when the runtime is created. ### `MaxTypeFields(int value)` diff --git a/docs/guide/dart/configuration.md b/docs/guide/dart/configuration.md index d832c28d7b..2fced7675f 100644 --- a/docs/guide/dart/configuration.md +++ b/docs/guide/dart/configuration.md @@ -124,8 +124,7 @@ payloads: final fory = Fory(maxGraphMemoryBytes: 256 * 1024 * 1024); ``` -Passing an explicit non-positive value disables this budget and can expose deserialization DoS risk -from compact inputs that materialize large object graphs. +Explicit non-positive values are rejected when the runtime is created. ## Defaults @@ -156,7 +155,7 @@ Security-related configuration: - Use `checkStructVersion: true` with `compatible: false` for intentional same-schema payloads. - Set `maxDepth` to reject unexpectedly deep payload shapes. - Keep `maxGraphMemoryBytes` at the default for most inputs, or set an explicit positive byte limit - for known trusted graph-heavy payloads. Avoid disabling it for untrusted data. + for known trusted graph-heavy payloads. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer generated schemas and explicit field metadata over broad dynamic fields for untrusted input. diff --git a/docs/guide/go/configuration.md b/docs/guide/go/configuration.md index 201c508df6..e29202acc7 100644 --- a/docs/guide/go/configuration.md +++ b/docs/guide/go/configuration.md @@ -138,9 +138,8 @@ f := fory.New(fory.WithMaxGraphMemoryBytes(256 * 1024 * 1024)) ``` The default limit is a fixed `128 MiB` for byte-slice and stream roots. A -positive value overrides the default. Passing an explicit non-positive value -disables this budget and can expose deserialization DoS risk from compact inputs -that materialize large object graphs. The budget covers lower-bound slice +positive value overrides the default. Explicit non-positive values are rejected +when the runtime is created. The budget covers lower-bound slice backing storage, map key/value storage, sets, generated object reads, and materialized struct field storage. Strings, binary blobs, and primitive dense array owners keep their byte-availability checks and are not reserved against diff --git a/docs/guide/java/configuration.md b/docs/guide/java/configuration.md index 73827559b6..f947ac6ed7 100644 --- a/docs/guide/java/configuration.md +++ b/docs/guide/java/configuration.md @@ -38,7 +38,7 @@ This page documents all configuration options available through `ForyBuilder`. | `registerGuavaTypes` | Whether to pre-register Guava types such as `RegularImmutableMap`/`RegularImmutableList`. These types are not public API, but seem pretty stable. | `true` | | `requireClassRegistration` | Disabling may allow unknown classes to be deserialized, potentially causing security risks. | `true` | | `maxDepth` | Set max depth for deserialization, when depth exceeds, an exception will be thrown. This can be used to refuse deserialization DDOS attack. | `50` | -| `maxGraphMemoryBytes` | Maximum estimated shallow graph memory accepted during one root deserialization. The default is a fixed `128 MiB`; positive values set an explicit byte limit. Explicit non-positive values disable this budget and can expose deserialization DoS risk from compact inputs that materialize large object graphs. | `134217728` | +| `maxGraphMemoryBytes` | Maximum estimated shallow graph memory accepted during one root deserialization. The default is a fixed `128 MiB`; positive values set an explicit byte limit. Explicit non-positive values are rejected when the runtime is created. | `134217728` | | `maxTypeFields` | Maximum fields accepted in one received remote struct metadata body. | `512` | | `maxTypeMetaBytes` | Maximum encoded body bytes accepted for one received TypeDef or TypeMeta body, excluding the 8-byte header and any extended-size varint. | `4096` | | `maxSchemaVersionsPerType` | Maximum accepted remote metadata versions for one logical type. | `10` | @@ -101,9 +101,8 @@ Security-related options: - `withMaxDepth(...)` rejects unexpectedly deep object graphs. - `withMaxGraphMemoryBytes(...)` bounds estimated shallow graph memory during one root deserialization. The default is a fixed `128 MiB`; set a positive byte limit when trusted - workloads need a larger or smaller limit. Passing an explicit non-positive value disables this - budget and can expose deserialization DoS risk from compact inputs that materialize large object - graphs. + workloads need a larger or smaller limit. Explicit non-positive values are rejected when the + runtime is created. - `withMaxTypeFields(...)` and `withMaxTypeMetaBytes(...)` bound the field count and encoded body size of one received remote metadata body. - `withMaxSchemaVersionsPerType(...)` and diff --git a/docs/guide/javascript/configuration.md b/docs/guide/javascript/configuration.md index 6b02d50828..6b96329c92 100644 --- a/docs/guide/javascript/configuration.md +++ b/docs/guide/javascript/configuration.md @@ -109,9 +109,7 @@ const fory = new Fory({ }); ``` -Passing an explicit non-positive value disables this budget and can expose -deserialization DoS risk from compact inputs that materialize large object -graphs. +Explicit non-positive values are rejected when the runtime is created. String, binary, and dedicated dense primitive array payloads keep their normal byte-size checks and do not consume this graph budget. Raise the limit only for diff --git a/docs/guide/python/configuration.md b/docs/guide/python/configuration.md index f9bfd1ac96..c798b4cea2 100644 --- a/docs/guide/python/configuration.md +++ b/docs/guide/python/configuration.md @@ -71,7 +71,7 @@ class ThreadSafeFory: | `max_type_meta_bytes` | `int` | `4096` | Maximum encoded body bytes accepted for one received TypeDef body, excluding the 8-byte header and any extended-size varint. | | `max_schema_versions_per_type` | `int` | `10` | Maximum accepted remote metadata versions for one logical type. | | `max_average_schema_versions_per_type` | `int` | `3` | Average accepted remote metadata versions across accepted remote types. The effective global floor is `8192` schemas. | -| `max_graph_memory_bytes` | `int` | `134217728` | Maximum estimated shallow graph memory for one root deserialization. Explicit non-positive values disable this budget. | +| `max_graph_memory_bytes` | `int` | `134217728` | Maximum estimated shallow graph memory for one root deserialization. Explicit non-positive values are rejected. | | `policy` | `DeserializationPolicy \| None` | `None` | Deserialization policy used for security checks. Strongly recommended when `strict=False`. | | `field_nullable` | `bool` | `False` | Treat dataclass fields as nullable by default. | | `meta_compressor` | `Any` | `None` | Optional metadata compressor used for compatible-mode metadata encoding. | @@ -228,9 +228,8 @@ Received remote metadata is also limited: - `max_graph_memory_bytes` limits estimated shallow graph memory created during one root deserialization, including materialized lists, tuples, sets, dicts, object arrays, structs, and Python objects. The default is a fixed `128 MiB` for all root input forms. Set a positive byte - value for trusted payloads that legitimately contain larger or smaller object graphs. Passing an - explicit non-positive value disables this budget and can expose deserialization DoS risk from - compact inputs that materialize large object graphs. + value for trusted payloads that legitimately contain larger or smaller object graphs. Explicit + non-positive values are rejected when the runtime is created. These limits do not change `strict`, `policy`, dynamic loading, unknown-class handling, or schema-evolution semantics. @@ -288,8 +287,7 @@ unchanged. - Use `DeserializationPolicy` when `strict=False` is necessary. - Keep `max_depth` low enough to reject unexpectedly deep payloads. - Keep `max_graph_memory_bytes` at the fixed `128 MiB` default for most inputs, or set a positive - explicit limit for trusted workloads with different legitimate object-graph sizes. Avoid - explicit non-positive values for untrusted data because they disable graph-memory enforcement. + explicit limit for trusted workloads with different legitimate object-graph sizes. - Do not treat xlang/native mode choice as a security control. ## Related Topics diff --git a/docs/guide/rust/configuration.md b/docs/guide/rust/configuration.md index 7426f968f6..37455840e5 100644 --- a/docs/guide/rust/configuration.md +++ b/docs/guide/rust/configuration.md @@ -124,8 +124,7 @@ let fory = Fory::builder() .build(); ``` -Passing an explicit non-positive value disables this budget and can expose deserialization DoS risk -from compact inputs that materialize large object graphs. +Explicit non-positive values are rejected when the runtime is created. ### Explicit Xlang Examples @@ -193,8 +192,7 @@ Security-related configuration: payloads. - Use `max_dyn_depth(...)` to reject unexpectedly deep dynamic object graphs. - Keep `max_graph_memory_bytes(...)` at the fixed `128 MiB` default for most inputs, or set a - positive byte limit for trusted workloads with different legitimate object-graph sizes. Avoid - explicit non-positive values for untrusted data because they disable graph-memory enforcement. + positive byte limit for trusted workloads with different legitimate object-graph sizes. - Keep the remote schema metadata limits at their defaults unless the data is not malicious and a trusted peer sends larger metadata or many schema versions. - Prefer concrete typed fields over `dyn Any` or broad trait-object fields for untrusted input. diff --git a/docs/guide/swift/configuration.md b/docs/guide/swift/configuration.md index 75fd8d281b..dcfb169f60 100644 --- a/docs/guide/swift/configuration.md +++ b/docs/guide/swift/configuration.md @@ -95,8 +95,7 @@ let fory = Fory(compatible: false, checkClassVersion: true) `maxGraphMemoryBytes` bounds estimated shallow graph memory accepted during one root deserialization. The default limit is a fixed `128 MiB` for all root input forms. A positive value -overrides the default. Passing an explicit non-positive value disables this budget and can expose -deserialization DoS risk from compact inputs that materialize large object graphs. +overrides the default. Explicit non-positive values are rejected when the runtime is created. Compatible-mode remote metadata is also limited: diff --git a/docs/security/deserialization.md b/docs/security/deserialization.md index 2ec48856bf..39edb7af54 100644 --- a/docs/security/deserialization.md +++ b/docs/security/deserialization.md @@ -213,9 +213,8 @@ not exact heap measurement and it is not a raw element-slot limit. The public configuration is `maxGraphMemoryBytes`. The default is a fixed `128 MiB` for all root input forms; positive user configuration overrides the default. Explicit non-positive configuration -disables this budget and can expose deserialization DoS risk from compact inputs that materialize -large object graphs. The budget is not derived from root input size, and stream budgeting should not -depend on dynamic bytes-read accounting. +is invalid and should be rejected when the runtime is created. The budget is not derived from root +input size, and stream budgeting should not depend on dynamic bytes-read accounting. Graph budget accounting should: diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index a738bd8ace..dd0f81bc0b 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -403,9 +403,9 @@ memory budget before allocation or size hinting. The budget belongs to `ReadContext` or the equivalent root read state, not to serializers and not to ambient thread-local state. `maxGraphMemoryBytes` defaults to a fixed `128 MiB`; positive configuration overrides the default; explicit non-positive -configuration disables graph-memory enforcement. Do not derive this budget from -root input size, and do not add dynamic stream bytes-read accounting for this -budget. +configuration is invalid and must be rejected when the runtime is created. Do +not derive this budget from root input size, and do not add dynamic stream +bytes-read accounting for this budget. Read context or equivalent read state owns only raw byte reservation. It must not expose counted arithmetic helpers or collection, map, array, struct, or diff --git a/go/fory/README.md b/go/fory/README.md index c3f7f9f8ed..5a18ddba06 100644 --- a/go/fory/README.md +++ b/go/fory/README.md @@ -95,13 +95,14 @@ f := fory.New(fory.WithMaxDepth(20)) // Set maximum estimated graph memory for one root read f := fory.New(fory.WithMaxGraphMemoryBytes(256 * 1024 * 1024)) +// The value must be positive. // Combine multiple options f := fory.New( fory.WithXlang(true), fory.WithTrackRef(true), fory.WithMaxDepth(20), - fory.WithMaxGraphMemoryBytes(-1), + fory.WithMaxGraphMemoryBytes(256 * 1024 * 1024), ) ``` diff --git a/go/fory/fory.go b/go/fory/fory.go index 874c7e3619..1014c6b5d1 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -118,8 +118,10 @@ func WithMaxDepth(depth int) Option { } // WithMaxGraphMemoryBytes sets the maximum estimated graph memory accepted during one root deserialization. -// Non-positive values disable graph-memory enforcement. func WithMaxGraphMemoryBytes(size int64) Option { + if size <= 0 { + panic("MaxGraphMemoryBytes must be positive") + } return func(f *Fory) { f.config.MaxGraphMemoryBytes = size } @@ -580,13 +582,8 @@ func (f *Fory) Deserialize(data []byte, v any) error { target := reflect.ValueOf(v).Elem() targetType := target.Type() limit := f.config.MaxGraphMemoryBytes - if limit <= 0 { - f.readCtx.graphMemoryLimitBytes = 0 - f.readCtx.remainingGraphMemoryBytes = MaxInt64 - } else { - f.readCtx.graphMemoryLimitBytes = limit - f.readCtx.remainingGraphMemoryBytes = limit - } + f.readCtx.graphMemoryLimitBytes = limit + f.readCtx.remainingGraphMemoryBytes = limit if bytes, ok := f.rootGraphBytesFor(targetType); ok && bytes > 0 { if !f.readCtx.ReserveGraphMemory(bytes) { return f.readCtx.TakeError() @@ -686,13 +683,8 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { target := reflect.ValueOf(v).Elem() targetType := target.Type() limit := f.config.MaxGraphMemoryBytes - if limit <= 0 { - f.readCtx.graphMemoryLimitBytes = 0 - f.readCtx.remainingGraphMemoryBytes = MaxInt64 - } else { - f.readCtx.graphMemoryLimitBytes = limit - f.readCtx.remainingGraphMemoryBytes = limit - } + f.readCtx.graphMemoryLimitBytes = limit + f.readCtx.remainingGraphMemoryBytes = limit if bytes, ok := f.rootGraphBytesFor(targetType); ok && bytes > 0 { if !f.readCtx.ReserveGraphMemory(bytes) { f.readCtx.buffer = origBuffer @@ -813,13 +805,8 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers target := rv.Elem() targetType := target.Type() limit := f.config.MaxGraphMemoryBytes - if limit <= 0 { - f.readCtx.graphMemoryLimitBytes = 0 - f.readCtx.remainingGraphMemoryBytes = MaxInt64 - } else { - f.readCtx.graphMemoryLimitBytes = limit - f.readCtx.remainingGraphMemoryBytes = limit - } + f.readCtx.graphMemoryLimitBytes = limit + f.readCtx.remainingGraphMemoryBytes = limit if bytes, ok := f.rootGraphBytesFor(targetType); ok && bytes > 0 { if !f.readCtx.ReserveGraphMemory(bytes) { return f.readCtx.TakeError() @@ -1083,13 +1070,8 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { f.readCtx.Reset() f.readCtx.SetData(data) limit := f.config.MaxGraphMemoryBytes - if limit <= 0 { - f.readCtx.graphMemoryLimitBytes = 0 - f.readCtx.remainingGraphMemoryBytes = MaxInt64 - } else { - f.readCtx.graphMemoryLimitBytes = limit - f.readCtx.remainingGraphMemoryBytes = limit - } + f.readCtx.graphMemoryLimitBytes = limit + f.readCtx.remainingGraphMemoryBytes = limit var targetVal reflect.Value var targetType reflect.Type diff --git a/go/fory/graph_memory_budget_test.go b/go/fory/graph_memory_budget_test.go index b7e3c18520..2fe91ae480 100644 --- a/go/fory/graph_memory_budget_test.go +++ b/go/fory/graph_memory_budget_test.go @@ -46,11 +46,11 @@ func graphOwnerSizeOf[T any]() int64 { func TestGraphMemoryBudgetConfig(t *testing.T) { require.Equal(t, int64(128*1024*1024), New().config.MaxGraphMemoryBytes) require.Equal(t, int64(123), New(WithMaxGraphMemoryBytes(123)).config.MaxGraphMemoryBytes) - require.Equal(t, int64(0), New(WithMaxGraphMemoryBytes(0)).config.MaxGraphMemoryBytes) - require.Equal(t, int64(-2), New(WithMaxGraphMemoryBytes(-2)).config.MaxGraphMemoryBytes) + require.Panics(t, func() { WithMaxGraphMemoryBytes(0) }) + require.Panics(t, func() { WithMaxGraphMemoryBytes(-2) }) } -func TestGraphMemoryBudgetFixedDefaultAndDisable(t *testing.T) { +func TestGraphMemoryBudgetFixedDefault(t *testing.T) { ctx := NewReadContext(false) ctx.graphMemoryLimitBytes = 128 * 1024 * 1024 ctx.remainingGraphMemoryBytes = 128 * 1024 * 1024 @@ -59,13 +59,6 @@ func TestGraphMemoryBudgetFixedDefaultAndDisable(t *testing.T) { require.False(t, ctx.ReserveGraphMemory(1)) require.Contains(t, ctx.CheckError().Error(), "maxGraphMemoryBytes") - ctx = NewReadContext(false) - ctx.graphMemoryLimitBytes = 0 - ctx.remainingGraphMemoryBytes = MaxInt64 - require.Equal(t, int64(0), ctx.graphMemoryLimitBytes) - require.True(t, ctx.ReserveGraphMemory(MaxInt64)) - require.False(t, ctx.HasError()) - ctx = NewReadContext(false) ctx.graphMemoryLimitBytes = 77 ctx.remainingGraphMemoryBytes = 77 diff --git a/go/fory/reader.go b/go/fory/reader.go index 01eecca3bc..321c5db4ea 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -119,9 +119,6 @@ func (c *ReadContext) Reset() { // ReserveGraphMemory reserves raw estimated graph-owner bytes. func (c *ReadContext) ReserveGraphMemory(bytes int64) bool { if bytes >= 0 { - if c.graphMemoryLimitBytes <= 0 { - return true - } remaining := c.remainingGraphMemoryBytes if bytes <= remaining { c.remainingGraphMemoryBytes = remaining - bytes diff --git a/go/fory/stream.go b/go/fory/stream.go index e7bb12bd8b..f2391d7152 100644 --- a/go/fory/stream.go +++ b/go/fory/stream.go @@ -98,13 +98,8 @@ func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { f.readCtx.buffer = is.buffer target := reflect.ValueOf(v).Elem() limit := f.config.MaxGraphMemoryBytes - if limit <= 0 { - f.readCtx.graphMemoryLimitBytes = 0 - f.readCtx.remainingGraphMemoryBytes = MaxInt64 - } else { - f.readCtx.graphMemoryLimitBytes = limit - f.readCtx.remainingGraphMemoryBytes = limit - } + f.readCtx.graphMemoryLimitBytes = limit + f.readCtx.remainingGraphMemoryBytes = limit if bytes, ok := f.rootGraphBytesFor(target.Type()); ok && bytes > 0 { if !f.readCtx.ReserveGraphMemory(bytes) { err := f.readCtx.TakeError() @@ -141,13 +136,8 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { f.readCtx.buffer.ResetWithReader(r, 0) target := reflect.ValueOf(v).Elem() limit := f.config.MaxGraphMemoryBytes - if limit <= 0 { - f.readCtx.graphMemoryLimitBytes = 0 - f.readCtx.remainingGraphMemoryBytes = MaxInt64 - } else { - f.readCtx.graphMemoryLimitBytes = limit - f.readCtx.remainingGraphMemoryBytes = limit - } + f.readCtx.graphMemoryLimitBytes = limit + f.readCtx.remainingGraphMemoryBytes = limit if bytes, ok := f.rootGraphBytesFor(target.Type()); ok && bytes > 0 { if !f.readCtx.ReserveGraphMemory(bytes) { return f.readCtx.TakeError() diff --git a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java index facd26013c..30b18257c7 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java @@ -575,10 +575,10 @@ public ForyBuilder withMaxAverageSchemaVersionsPerType(int maxAverageSchemaVersi /** * Sets the maximum estimated graph memory accepted during one root deserialization. * - *

The default is a fixed 128 MiB. Positive values are explicit byte limits. Explicit - * non-positive values disable this budget. + *

The default is a fixed 128 MiB. Values must be positive byte limits. */ public ForyBuilder withMaxGraphMemoryBytes(long maxGraphMemoryBytes) { + Preconditions.checkArgument(maxGraphMemoryBytes > 0, "maxGraphMemoryBytes must be positive"); this.maxGraphMemoryBytes = maxGraphMemoryBytes; recordAction(b -> b.withMaxGraphMemoryBytes(maxGraphMemoryBytes)); return this; diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index 4fb8ff1a0d..d55c6520d0 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -116,14 +116,8 @@ public void prepare( this.buffer = buffer; this.peerOutOfBandEnabled = peerOutOfBandEnabled; this.outOfBandBuffers = outOfBandBuffers == null ? null : outOfBandBuffers.iterator(); - long limit = maxGraphMemoryBytes; - if (limit <= 0) { - graphMemoryLimitBytes = 0; - remainingGraphMemoryBytes = Long.MAX_VALUE; - return; - } - graphMemoryLimitBytes = limit; - remainingGraphMemoryBytes = limit; + graphMemoryLimitBytes = maxGraphMemoryBytes; + remainingGraphMemoryBytes = maxGraphMemoryBytes; } /** @@ -332,9 +326,6 @@ public void reserveGraphMemory(long bytes) { if (bytes < 0) { throwNegativeGraphMemory(bytes); } - if (graphMemoryLimitBytes <= 0) { - return; - } long remaining = remainingGraphMemoryBytes; if (bytes > remaining) { throwGraphMemoryExceeded(bytes, remaining); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java index 867eba616d..5d26ea65b4 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java @@ -44,11 +44,11 @@ public class GraphMemoryBudgetTest extends ForyTestBase { private static final int OBJECT_SELF_BYTES = 1; @Test - public void testConfigDefaultsAndDisable() { + public void testConfigDefaultsAndValidation() { assertEquals(builder().build().getConfig().maxGraphMemoryBytes(), DEFAULT_GRAPH_MEMORY_BYTES); assertEquals(newFory(123).getConfig().maxGraphMemoryBytes(), 123); - assertEquals(newFory(0).getConfig().maxGraphMemoryBytes(), 0); - assertEquals(newFory(-2).getConfig().maxGraphMemoryBytes(), -2); + assertThrows(IllegalArgumentException.class, () -> newFory(0)); + assertThrows(IllegalArgumentException.class, () -> newFory(-2)); } @Test @@ -62,17 +62,6 @@ public void testDefaultFixedBudget() { } } - @Test - public void testDisabledBudget() { - ReadContext readContext = prepareContext(newFory(0)); - try { - readContext.reserveGraphMemory(DEFAULT_GRAPH_MEMORY_BYTES + 1); - readContext.reserveGraphMemory(Long.MAX_VALUE); - } finally { - readContext.reset(); - } - } - @Test public void testExplicitBudgetWins() { Fory fory = newFory(7); diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index 60f51de71e..1b34ebdd71 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -540,7 +540,6 @@ export class ReadContext { private _depth = 0; private _maxDepth: number; private readonly maxGraphMemoryBytes: number; - private effectiveGraphMemoryBytes = 0; private remainingGraphMemoryBytes = 0; private remoteSchemaVersionsByType: Map | undefined = undefined; @@ -561,17 +560,13 @@ export class ReadContext { this.metaStringReader.reset(); this.typeMeta = []; this._depth = 0; - this.effectiveGraphMemoryBytes = this.maxGraphMemoryBytes > 0 ? this.maxGraphMemoryBytes : 0; - this.remainingGraphMemoryBytes = this.effectiveGraphMemoryBytes; + this.remainingGraphMemoryBytes = this.maxGraphMemoryBytes; } reserveGraphMemory(bytes: number) { if (!Number.isSafeInteger(bytes) || bytes < 0) { this.throwGraphMemoryOverflow(bytes); } - if (this.effectiveGraphMemoryBytes <= 0) { - return; - } const remaining = this.remainingGraphMemoryBytes - bytes; if (remaining < 0) { this.throwGraphBudgetExceeded(bytes); @@ -587,7 +582,7 @@ export class ReadContext { throw new Error( `maxGraphMemoryBytes exceeded: requested ${bytes} estimated graph bytes, ` + `${this.remainingGraphMemoryBytes} remaining, effective limit ` + - `${this.effectiveGraphMemoryBytes}`, + `${this.maxGraphMemoryBytes}`, ); } diff --git a/javascript/packages/core/lib/fory.ts b/javascript/packages/core/lib/fory.ts index c1ba3fdb12..abef11da3a 100644 --- a/javascript/packages/core/lib/fory.ts +++ b/javascript/packages/core/lib/fory.ts @@ -90,8 +90,10 @@ export default class Fory { ); } const maxGraphMemoryBytes = config?.maxGraphMemoryBytes ?? DEFAULT_MAX_GRAPH_MEMORY_BYTES; - if (!Number.isSafeInteger(maxGraphMemoryBytes)) { - throw new Error(`maxGraphMemoryBytes must be a safe integer but got ${maxGraphMemoryBytes}`); + if (!Number.isSafeInteger(maxGraphMemoryBytes) || maxGraphMemoryBytes <= 0) { + throw new Error( + `maxGraphMemoryBytes must be a positive safe integer but got ${maxGraphMemoryBytes}`, + ); } return { ref: Boolean(config?.ref), diff --git a/javascript/test/graphMemoryBudget.test.ts b/javascript/test/graphMemoryBudget.test.ts index 15d271b34a..8a0f633cca 100644 --- a/javascript/test/graphMemoryBudget.test.ts +++ b/javascript/test/graphMemoryBudget.test.ts @@ -49,17 +49,15 @@ describe("graph memory budget", () => { expect(() => fory.readContext.reserveGraphMemory(1)).toThrow(/maxGraphMemoryBytes/); }); - test("handles explicit config and disable", () => { + test("handles explicit config and validation", () => { const fory = new Fory({ maxGraphMemoryBytes: 24 }); fory.readContext.reset(new Uint8Array(1)); expect(() => fory.readContext.reserveGraphMemory(0)).not.toThrow(); expect(() => fory.readContext.reserveGraphMemory(24)).not.toThrow(); expect(() => fory.readContext.reserveGraphMemory(1)).toThrow(/maxGraphMemoryBytes/); - const disabled = new Fory({ maxGraphMemoryBytes: 0 }); - disabled.readContext.reset(new Uint8Array(1)); - expect(() => disabled.readContext.reserveGraphMemory(Number.MAX_SAFE_INTEGER)).not.toThrow(); - expect(() => new Fory({ maxGraphMemoryBytes: -2 })).not.toThrow(); + expect(() => new Fory({ maxGraphMemoryBytes: 0 })).toThrow(/maxGraphMemoryBytes/); + expect(() => new Fory({ maxGraphMemoryBytes: -2 })).toThrow(/maxGraphMemoryBytes/); }); test("uses parent storage for nested empty containers", () => { diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index 27a5d56d8f..9f5615565f 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -186,8 +186,7 @@ def __init__( across accepted remote types. max_graph_memory_bytes: Maximum estimated graph memory per root - deserialization. Defaults to 128 MiB; positive values are explicit byte limits, - and non-positive values intentionally disable this protection. + deserialization. Defaults to 128 MiB and must be a positive byte limit. policy: Custom deserialization policy for security checks. When provided, it controls which types can be deserialized, overriding the default policy. @@ -219,8 +218,8 @@ def __init__( raise ValueError("max_schema_versions_per_type must be a positive integer") if not isinstance(max_average_schema_versions_per_type, int) or max_average_schema_versions_per_type <= 0: raise ValueError("max_average_schema_versions_per_type must be a positive integer") - if not isinstance(max_graph_memory_bytes, int) or max_graph_memory_bytes > (1 << 63) - 1 or max_graph_memory_bytes < -(1 << 63): - raise ValueError("max_graph_memory_bytes must be a 63-bit integer") + if not isinstance(max_graph_memory_bytes, int) or max_graph_memory_bytes <= 0 or max_graph_memory_bytes > (1 << 63) - 1: + raise ValueError("max_graph_memory_bytes must be a positive 63-bit integer") self.max_graph_memory_bytes = max_graph_memory_bytes self.config = Config( xlang=xlang, diff --git a/python/pyfory/context.pxi b/python/pyfory/context.pxi index 0a0a0511bb..03076f093e 100644 --- a/python/pyfory/context.pxi +++ b/python/pyfory/context.pxi @@ -791,15 +791,13 @@ cdef class ReadContext: unsupported_objects=None, bint peer_out_of_band_enabled=False, ): - cdef int64_t limit - limit = self.max_graph_memory_bytes if self.max_graph_memory_bytes > 0 else 0 self.buffer = buffer self.c_buffer = buffer.c_buffer self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled - self.graph_memory_limit_bytes = limit - self.remaining_graph_memory_bytes = limit if limit > 0 else _MAX_GRAPH_MEMORY_BYTES + self.graph_memory_limit_bytes = self.max_graph_memory_bytes + self.remaining_graph_memory_bytes = self.max_graph_memory_bytes self.depth = 0 cpdef inline reset(self): @@ -824,8 +822,6 @@ cdef class ReadContext: raise ValueError("Estimated graph memory is negative") if num_bytes > _MAX_GRAPH_MEMORY_BYTES: raise ValueError("Estimated graph memory overflow") - if self.graph_memory_limit_bytes <= 0: - return if num_bytes > self.remaining_graph_memory_bytes: used = self.graph_memory_limit_bytes - self.remaining_graph_memory_bytes raise ValueError( diff --git a/python/pyfory/context.py b/python/pyfory/context.py index d894516d80..edf30e31d6 100644 --- a/python/pyfory/context.py +++ b/python/pyfory/context.py @@ -527,13 +527,12 @@ def prepare( unsupported_objects=None, peer_out_of_band_enabled=False, ): - limit = self.max_graph_memory_bytes if self.max_graph_memory_bytes > 0 else 0 self.buffer = buffer self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled - self.graph_memory_limit_bytes = limit - self.remaining_graph_memory_bytes = limit if limit > 0 else _MAX_GRAPH_MEMORY_BYTES + self.graph_memory_limit_bytes = self.max_graph_memory_bytes + self.remaining_graph_memory_bytes = self.max_graph_memory_bytes self.depth = 0 def reset(self): @@ -556,8 +555,6 @@ def reserve_graph_memory(self, num_bytes): raise ValueError("Estimated graph memory is negative") if num_bytes > _MAX_GRAPH_MEMORY_BYTES: raise ValueError("Estimated graph memory overflow") - if self.graph_memory_limit_bytes <= 0: - return remaining = self.remaining_graph_memory_bytes if num_bytes > remaining: used = self.graph_memory_limit_bytes - remaining diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 56e46b32c6..9c1fe10643 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -114,8 +114,7 @@ cdef class Config: max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. max_graph_memory_bytes: Maximum estimated graph memory per root - deserialization. Defaults to 128 MiB; positive values are explicit byte limits, - and non-positive values intentionally disable this protection. + deserialization. Defaults to 128 MiB and must be a positive byte limit. field_nullable: Treats struct/dataclass fields as nullable by default. policy: Deserialization policy used for security-sensitive checks. meta_compressor: Optional typedef/meta compressor implementation. @@ -172,8 +171,7 @@ cdef class Config: max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. max_graph_memory_bytes: Maximum estimated graph memory per root - deserialization. Defaults to 128 MiB; positive values are explicit byte limits, - and non-positive values intentionally disable this protection. + deserialization. Defaults to 128 MiB and must be a positive byte limit. field_nullable: Treat all struct fields as nullable by default. policy: Deserialization policy implementation. meta_compressor: Optional typedef/meta compressor. @@ -195,10 +193,10 @@ cdef class Config: raise ValueError("max_average_schema_versions_per_type must be a positive integer") if ( not isinstance(max_graph_memory_bytes, int) + or max_graph_memory_bytes <= 0 or max_graph_memory_bytes > 9223372036854775807 - or max_graph_memory_bytes < -9223372036854775808 ): - raise ValueError("max_graph_memory_bytes must be a 63-bit integer") + raise ValueError("max_graph_memory_bytes must be a positive 63-bit integer") self.max_type_fields = max_type_fields self.max_type_meta_bytes = max_type_meta_bytes self.max_schema_versions_per_type = max_schema_versions_per_type @@ -883,8 +881,7 @@ cdef class Fory: max_schema_versions_per_type: Maximum accepted remote metadata versions for one logical type. max_average_schema_versions_per_type: Average remote schema versions allowed across accepted remote types. max_graph_memory_bytes: Maximum estimated graph memory per root - deserialization. Defaults to 128 MiB; positive values are explicit byte limits, - and non-positive values intentionally disable this protection. + deserialization. Defaults to 128 MiB and must be a positive byte limit. policy: Optional deserialization policy implementation. field_nullable: Treat struct fields as nullable by default. meta_compressor: Optional typedef/meta compressor implementation. @@ -904,10 +901,10 @@ cdef class Fory: self.max_depth = max_depth if ( not isinstance(max_graph_memory_bytes, int) + or max_graph_memory_bytes <= 0 or max_graph_memory_bytes > 9223372036854775807 - or max_graph_memory_bytes < -9223372036854775808 ): - raise ValueError("max_graph_memory_bytes must be a 63-bit integer") + raise ValueError("max_graph_memory_bytes must be a positive 63-bit integer") self.max_graph_memory_bytes = max_graph_memory_bytes self.config = Config( xlang=xlang, @@ -1079,7 +1076,6 @@ cdef class Fory: cdef int32_t reader_index cdef uint8_t bitmap cdef bint peer_out_of_band_enabled - cdef int64_t graph_memory_limit if isinstance(buffer, bytes): buffer = Buffer(buffer) read_buffer = buffer @@ -1095,7 +1091,6 @@ cdef class Fory: raise ValueError("Out-of-band buffers are required by the root header") if not peer_out_of_band_enabled and buffers is not None: raise ValueError("Out-of-band buffers were provided for an in-band root payload") - graph_memory_limit = self.max_graph_memory_bytes if self.max_graph_memory_bytes > 0 else 0 # Keep the root context setup inline. Top-level deserialize is a hot path, # so it should not pay an extra method call just to bind the active buffer. read_context.buffer = read_buffer @@ -1105,8 +1100,8 @@ cdef class Fory: iter(unsupported_objects) if unsupported_objects is not None else None ) read_context.peer_out_of_band_enabled = peer_out_of_band_enabled - read_context.graph_memory_limit_bytes = graph_memory_limit - read_context.remaining_graph_memory_bytes = graph_memory_limit if graph_memory_limit > 0 else _MAX_GRAPH_MEMORY_BYTES + read_context.graph_memory_limit_bytes = self.max_graph_memory_bytes + read_context.remaining_graph_memory_bytes = self.max_graph_memory_bytes read_context.depth = 0 return read_context.read_ref() diff --git a/python/pyfory/tests/test_graph_memory_budget.py b/python/pyfory/tests/test_graph_memory_budget.py index 514f1fbac2..bbaaec57f3 100644 --- a/python/pyfory/tests/test_graph_memory_budget.py +++ b/python/pyfory/tests/test_graph_memory_budget.py @@ -148,16 +148,10 @@ def test_stream_uses_fixed_default_budget(): fory.reset_read() -def test_explicit_config_and_disable(): +def test_explicit_config_overrides_default(): value = [1] budget = collection_memory(1) assert expect_budget(value, budget) == value - disabled = new_fory(0, xlang=False) - try: - disabled.read_context.prepare(Buffer(b"x")) - disabled.read_context.reserve_graph_memory(MAX_GRAPH_MEMORY_BYTES) - finally: - disabled.reset_read() def test_nested_empty_containers_use_parent_storage(): @@ -261,7 +255,7 @@ def test_declared_large_list_still_needs_bytes(): fory.reset_read() -@pytest.mark.parametrize("limit", [1 << 63, -(1 << 63) - 1]) +@pytest.mark.parametrize("limit", [0, -2, 1 << 63]) def test_invalid_config(limit): with pytest.raises(ValueError, match="max_graph_memory_bytes"): new_fory(limit) diff --git a/rust/fory-core/src/config.rs b/rust/fory-core/src/config.rs index 76f4d652c2..aeeae86625 100644 --- a/rust/fory-core/src/config.rs +++ b/rust/fory-core/src/config.rs @@ -41,8 +41,7 @@ pub struct Config { /// and preserved during serialization/deserialization. pub track_ref: bool, /// Maximum estimated graph memory accepted during one root deserialization. - /// Defaults to 128 MiB. Positive values are explicit limits; non-positive - /// values intentionally disable this protection. + /// Defaults to 128 MiB. Value must be a positive byte limit. pub max_graph_memory_bytes: i64, /// Maximum accepted field count in one received struct TypeMeta. pub max_type_fields: u32, diff --git a/rust/fory-core/src/context.rs b/rust/fory-core/src/context.rs index 0a69e601d8..13c48a3e47 100644 --- a/rust/fory-core/src/context.rs +++ b/rust/fory-core/src/context.rs @@ -452,9 +452,6 @@ impl<'a> ReadContext<'a> { #[inline(always)] #[doc(hidden)] pub fn reserve_graph_memory(&mut self, bytes: usize) -> Result<(), Error> { - if self.graph_memory_limit_bytes == 0 { - return Ok(()); - } let remaining = self.remaining_graph_memory_bytes; if bytes > remaining { return Err(graph_memory_exceeded( diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index f73aa0c3aa..8c759f6f21 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -263,9 +263,13 @@ impl ForyBuilder { /// Sets the maximum estimated graph memory accepted during one root deserialization. /// - /// Defaults to 128 MiB. Positive values are explicit byte limits; non-positive - /// values intentionally disable this protection. + /// Defaults to 128 MiB. Values must be positive byte limits. pub fn max_graph_memory_bytes(mut self, max_bytes: i64) -> Self { + assert!(max_bytes > 0, "max_graph_memory_bytes must be positive"); + assert!( + usize::try_from(max_bytes).is_ok(), + "max_graph_memory_bytes does not fit usize" + ); self.config.max_graph_memory_bytes = max_bytes; self } @@ -997,16 +1001,12 @@ impl Fory { self.with_read_context(|context| { let outlive_buffer = unsafe { mem::transmute::<&[u8], &[u8]>(bf) }; context.attach_reader(Reader::new(outlive_buffer)); - let result = match if context.max_graph_memory_bytes > 0 { - usize::try_from(context.max_graph_memory_bytes) - .map_err(|_| Error::invalid_data("max_graph_memory_bytes does not fit usize")) - } else { - Ok(0) - } { + let result = match usize::try_from(context.max_graph_memory_bytes) + .map_err(|_| Error::invalid_data("max_graph_memory_bytes does not fit usize")) + { Ok(limit) => { context.graph_memory_limit_bytes = limit; - context.remaining_graph_memory_bytes = - if limit > 0 { limit } else { usize::MAX }; + context.remaining_graph_memory_bytes = limit; self.deserialize_with_context(context) } Err(err) => { @@ -1076,16 +1076,12 @@ impl Fory { let mut new_reader = Reader::new(outlive_buffer); new_reader.set_cursor(reader.cursor); context.attach_reader(new_reader); - let result = match if context.max_graph_memory_bytes > 0 { - usize::try_from(context.max_graph_memory_bytes) - .map_err(|_| Error::invalid_data("max_graph_memory_bytes does not fit usize")) - } else { - Ok(0) - } { + let result = match usize::try_from(context.max_graph_memory_bytes) + .map_err(|_| Error::invalid_data("max_graph_memory_bytes does not fit usize")) + { Ok(limit) => { context.graph_memory_limit_bytes = limit; - context.remaining_graph_memory_bytes = - if limit > 0 { limit } else { usize::MAX }; + context.remaining_graph_memory_bytes = limit; self.deserialize_with_context(context) } Err(err) => { diff --git a/rust/tests/tests/test_graph_memory_budget.rs b/rust/tests/tests/test_graph_memory_budget.rs index 0241aa25e8..aa3ff1da4c 100644 --- a/rust/tests/tests/test_graph_memory_budget.rs +++ b/rust/tests/tests/test_graph_memory_budget.rs @@ -83,39 +83,15 @@ fn config_validation() { Fory::builder().build().config().max_graph_memory_bytes, DEFAULT_GRAPH_MEMORY_BYTES ); - assert_eq!( - Fory::builder() - .max_graph_memory_bytes(0) - .build() - .config() - .max_graph_memory_bytes, - 0 + assert!( + std::panic::catch_unwind(|| Fory::builder().max_graph_memory_bytes(0).build()).is_err() ); - assert_eq!( - Fory::builder() - .max_graph_memory_bytes(-2) - .build() - .config() - .max_graph_memory_bytes, - -2 + assert!( + std::panic::catch_unwind(|| Fory::builder().max_graph_memory_bytes(-2).build()).is_err() ); let _ = Fory::builder().max_graph_memory_bytes(1).build(); } -#[test] -fn non_positive_budget_disables_enforcement() { - let value: Vec = Vec::new(); - let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); - let bytes = writer.serialize(&value).unwrap(); - - assert!(fory_with_budget(1) - .deserialize::>(&bytes) - .is_err()); - assert!(fory_with_budget(0) - .deserialize::>(&bytes) - .is_ok()); -} - #[test] fn byte_root_uses_fixed_default_budget() { let value = compact_empty_lists(12000); diff --git a/scala/README.md b/scala/README.md index a1149ccf60..daffbc05a9 100644 --- a/scala/README.md +++ b/scala/README.md @@ -188,9 +188,12 @@ sbt test ## Code Format ```bash -sbt scalafmt +cd .. +ci/format.sh ``` +The Scala module does not currently wire a `scalafmt` sbt command. + ## Additional Notes - **Fory Reuse**: Always reuse Fory instances; creation is expensive diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala index 085d0cc5a9..6489118eef 100644 --- a/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/CollectionSerializerTest.scala @@ -92,22 +92,20 @@ class CollectionSerializerTest extends AnyWordSpec with Matchers { } "fory scala graph memory budget" should { - def runtime(maxGraphMemoryBytes: Long = -1): Fory = { + def runtime(maxGraphMemoryBytes: Option[Long] = None): Fory = { val builder = ForyScala.builder() .withXlang(false) .withRefTracking(true) .requireClassRegistration(false) .suppressClassRegistrationWarnings(false) .withSerializerFactory(new ScalaSerializerFactory()) - if (maxGraphMemoryBytes > 0) { - builder.withMaxGraphMemoryBytes(maxGraphMemoryBytes) - } + maxGraphMemoryBytes.foreach(builder.withMaxGraphMemoryBytes) builder.build() } "reserve scala collection storage" in { val writer = runtime() - val reader = runtime(maxGraphMemoryBytes = 23) + val reader = runtime(maxGraphMemoryBytes = Some(23)) intercept[InsecureException] { reader.deserialize(writer.serialize(List.fill(6)("v"))) } @@ -115,7 +113,7 @@ class CollectionSerializerTest extends AnyWordSpec with Matchers { "reserve scala map storage" in { val writer = runtime() - val reader = runtime(maxGraphMemoryBytes = 23) + val reader = runtime(maxGraphMemoryBytes = Some(23)) intercept[InsecureException] { reader.deserialize(writer.serialize(Map("a" -> 1, "b" -> 2, "c" -> 3))) } diff --git a/swift/Sources/Fory/Fory.swift b/swift/Sources/Fory/Fory.swift index 7e114a6a1a..4b653bcba7 100644 --- a/swift/Sources/Fory/Fory.swift +++ b/swift/Sources/Fory/Fory.swift @@ -45,6 +45,9 @@ public struct Config { precondition( maxAverageSchemaVersionsPerType > 0, "maxAverageSchemaVersionsPerType must be positive") + precondition( + maxGraphMemoryBytes > 0 && maxGraphMemoryBytes <= Int64(Int.max), + "maxGraphMemoryBytes must be a positive byte limit") let effectiveCompatible = compatible ?? true let effectiveCheckClassVersion = checkClassVersion ?? !effectiveCompatible self.trackRef = trackRef @@ -493,8 +496,7 @@ public final class Fory { _ body: (ReadContext) throws -> R ) throws -> R { readContext.buffer.replace(with: data) - readContext.remainingGraphMemoryBytes = - readContext.maxGraphMemoryBytes > 0 ? readContext.maxGraphMemoryBytes : Int.max + readContext.remainingGraphMemoryBytes = readContext.maxGraphMemoryBytes defer { readContext.reset() } @@ -556,8 +558,7 @@ public final class Fory { ) throws -> R { try typeResolver.finishRegistration() readContext.buffer.swapState(with: buffer) - readContext.remainingGraphMemoryBytes = - readContext.maxGraphMemoryBytes > 0 ? readContext.maxGraphMemoryBytes : Int.max + readContext.remainingGraphMemoryBytes = readContext.maxGraphMemoryBytes defer { readContext.buffer.swapState(with: buffer) readContext.reset() diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index 7cd50e6abd..0e15f5f481 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -36,7 +36,7 @@ public final class ReadContext { private var lastTypeInfo = TypeInfo.uncached private let config: Config let maxGraphMemoryBytes: Int - var remainingGraphMemoryBytes = Int.max + var remainingGraphMemoryBytes = 0 init( buffer: ByteBuffer, @@ -59,9 +59,6 @@ public final class ReadContext { if bytes < 0 { try throwGraphMemoryOverflow() } - if maxGraphMemoryBytes <= 0 { - return - } if bytes > remainingGraphMemoryBytes { try throwGraphMemoryExceeded(bytes: bytes) } diff --git a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift index 97b91aa7ab..9204220921 100644 --- a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift +++ b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift @@ -146,7 +146,7 @@ private func expectInvalidData(_ body: () throws -> Void) { } @Test -func fixedDefaultBudgetAndDisable() throws { +func fixedDefaultBudget() throws { let config = Config(trackRef: false, compatible: false) let context = ReadContext( buffer: ByteBuffer(), @@ -159,15 +159,6 @@ func fixedDefaultBudgetAndDisable() throws { expectInvalidData { try context.reserveGraphMemory(testReferenceBytes) } - - let disabledConfig = Config(trackRef: false, compatible: false, maxGraphMemoryBytes: 0) - let disabled = ReadContext( - buffer: ByteBuffer(), - typeResolver: TypeResolver(config: disabledConfig), - config: disabledConfig - ) - disabled.remainingGraphMemoryBytes = Int.max - try disabled.reserveGraphMemory(Int(defaultGraphMemoryBytes) + 1) } @Test From 04c70893ed17759d587bf7ec45c336a4b921e308 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 03:44:50 +0800 Subject: [PATCH 35/54] refactor: remove graph budget cleanup leftovers --- .../serialization/collection_serializer.h | 65 ----------- cpp/fory/serialization/map_serializer.h | 13 --- cpp/fory/serialization/struct_serializer.h | 9 -- cpp/fory/serialization/union_serializer.h | 6 - go/fory/array.go | 2 +- go/fory/map.go | 6 +- go/fory/map_primitive.go | 4 +- go/fory/reader.go | 11 +- go/fory/set.go | 11 +- go/fory/slice.go | 4 +- go/fory/slice_dyn.go | 4 +- go/fory/slice_primitive.go | 4 +- go/fory/slice_primitive_list.go | 11 +- .../src/main/java/org/apache/fory/Fory.java | 108 ++++++++---------- .../java/org/apache/fory/config/Config.java | 2 +- .../org/apache/fory/context/ReadContext.java | 7 +- 16 files changed, 75 insertions(+), 192 deletions(-) diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index 8bdd5bb7fc..d1d28979d9 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -457,14 +457,6 @@ inline bool reserve_collection(std::vector &result, return true; } -template -inline bool reserve_empty_collection(ReadContext &ctx) { - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return false; - } - return ctx.reserve_graph_memory(0); -} - // Helper to insert element into container (vector or set) template inline void collection_insert(Container &result, T &&elem) { @@ -480,9 +472,6 @@ template inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { Container result; if (length == 0) { - if (FORY_PREDICT_FALSE(!reserve_empty_collection(ctx))) { - return result; - } return result; } @@ -612,10 +601,6 @@ inline std::forward_list read_forward_list_data_slow(ReadContext &ctx, uint32_t length) { std::forward_list result; if (length == 0) { - if (FORY_PREDICT_FALSE( - (!reserve_empty_collection>(ctx)))) { - return result; - } return result; } @@ -1122,10 +1107,6 @@ struct Serializer< std::vector result; // Per xlang spec: header and type_info are omitted when length is 0 if (length == 0) { - if (FORY_PREDICT_FALSE( - (!reserve_empty_collection>(ctx)))) { - return result; - } return result; } // Fast path for non-polymorphic, non-shared-ref elements @@ -1253,10 +1234,6 @@ struct Serializer< std::vector result; if (size == 0) { - if (FORY_PREDICT_FALSE( - (!reserve_empty_collection>(ctx)))) { - return result; - } return result; } if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { @@ -1424,10 +1401,6 @@ template struct Serializer> { std::list result; // Per xlang spec: header and type_info are omitted when length is 0 if (length == 0) { - if (FORY_PREDICT_FALSE( - (!reserve_empty_collection>(ctx)))) { - return result; - } return result; } // Fast path for non-polymorphic, non-shared-ref elements @@ -1555,10 +1528,6 @@ template struct Serializer> { std::list result; if (size == 0) { - if (FORY_PREDICT_FALSE( - (!reserve_empty_collection>(ctx)))) { - return result; - } return result; } if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { @@ -1632,10 +1601,6 @@ template struct Serializer> { std::deque result; // Per xlang spec: header and type_info are omitted when length is 0 if (length == 0) { - if (FORY_PREDICT_FALSE( - (!reserve_empty_collection>(ctx)))) { - return result; - } return result; } // Fast path for non-polymorphic, non-shared-ref elements @@ -1763,10 +1728,6 @@ template struct Serializer> { std::deque result; if (size == 0) { - if (FORY_PREDICT_FALSE( - (!reserve_empty_collection>(ctx)))) { - return result; - } return result; } if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { @@ -1836,10 +1797,6 @@ struct Serializer> { std::forward_list result; // Per xlang spec: header and type_info are omitted when length is 0 if (length == 0) { - if (FORY_PREDICT_FALSE( - (!reserve_empty_collection>(ctx)))) { - return result; - } return result; } @@ -2198,10 +2155,6 @@ struct Serializer> { std::forward_list result; if (size == 0) { - if (FORY_PREDICT_FALSE( - (!reserve_empty_collection>(ctx)))) { - return result; - } return result; } if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { @@ -2308,10 +2261,6 @@ struct Serializer> { std::set result; // Per xlang spec: header and type_info are omitted when length is 0 if (size == 0) { - if (FORY_PREDICT_FALSE( - (!reserve_empty_collection>(ctx)))) { - return result; - } return result; } // Fast path for non-polymorphic, non-shared-ref elements @@ -2390,10 +2339,6 @@ struct Serializer> { std::set result; if (size == 0) { - if (FORY_PREDICT_FALSE( - (!reserve_empty_collection>(ctx)))) { - return result; - } return result; } if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { @@ -2501,11 +2446,6 @@ struct Serializer> { std::unordered_set result; // Per xlang spec: header and type_info are omitted when length is 0 if (size == 0) { - if (FORY_PREDICT_FALSE( - (!reserve_empty_collection>( - ctx)))) { - return result; - } return result; } // Fast path for non-polymorphic, non-shared-ref elements @@ -2584,11 +2524,6 @@ struct Serializer> { std::unordered_set result; if (size == 0) { - if (FORY_PREDICT_FALSE( - (!reserve_empty_collection>( - ctx)))) { - return result; - } return result; } if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { diff --git a/cpp/fory/serialization/map_serializer.h b/cpp/fory/serialization/map_serializer.h index 12fcdac775..44ea9bc8b9 100644 --- a/cpp/fory/serialization/map_serializer.h +++ b/cpp/fory/serialization/map_serializer.h @@ -127,13 +127,6 @@ inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { return true; } -template inline bool reserve_empty_map(ReadContext &ctx) { - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return false; - } - return ctx.reserve_graph_memory(0); -} - /// write chunk size at header offset inline void write_chunk_size(WriteContext &ctx, size_t header_offset, uint8_t size) { @@ -606,9 +599,6 @@ inline MapType read_map_data_fast(ReadContext &ctx, uint32_t length) { MapType result; if (length == 0) { - if (FORY_PREDICT_FALSE(!reserve_empty_map(ctx))) { - return result; - } return result; } if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { @@ -741,9 +731,6 @@ template inline MapType read_map_data_slow(ReadContext &ctx, uint32_t length) { MapType result; if (length == 0) { - if (FORY_PREDICT_FALSE(!reserve_empty_map(ctx))) { - return result; - } return result; } if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { diff --git a/cpp/fory/serialization/struct_serializer.h b/cpp/fory/serialization/struct_serializer.h index 00acf71b2a..a438a75c4e 100644 --- a/cpp/fory/serialization/struct_serializer.h +++ b/cpp/fory/serialization/struct_serializer.h @@ -897,9 +897,6 @@ Container read_configured_list_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); Container result; if (length == 0) { - if (FORY_PREDICT_FALSE(!reserve_empty_collection(ctx))) { - return result; - } return result; } uint8_t bitmap = ctx.read_uint8(ctx.error()); @@ -946,9 +943,6 @@ FORY_NOINLINE Container read_configured_list_data_as_array_field( return result; } if (length == 0) { - if (FORY_PREDICT_FALSE(!reserve_empty_collection(ctx))) { - return result; - } return result; } uint8_t bitmap = ctx.read_uint8(ctx.error()); @@ -1060,9 +1054,6 @@ MapType read_configured_map_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); MapType result; if (length == 0) { - if (FORY_PREDICT_FALSE(!reserve_empty_map(ctx))) { - return result; - } return result; } if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { diff --git a/cpp/fory/serialization/union_serializer.h b/cpp/fory/serialization/union_serializer.h index 8a8bc99fe3..6a2c4db55d 100644 --- a/cpp/fory/serialization/union_serializer.h +++ b/cpp/fory/serialization/union_serializer.h @@ -466,9 +466,6 @@ Container read_union_configured_list_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); Container result; if (length == 0) { - if (FORY_PREDICT_FALSE(!reserve_empty_collection(ctx))) { - return result; - } return result; } uint8_t bitmap = ctx.read_uint8(ctx.error()); @@ -556,9 +553,6 @@ MapType read_union_configured_map_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); MapType result; if (length == 0) { - if (FORY_PREDICT_FALSE(!reserve_empty_map(ctx))) { - return result; - } return result; } if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { diff --git a/go/fory/array.go b/go/fory/array.go index 7c05463320..f9c917bb1e 100644 --- a/go/fory/array.go +++ b/go/fory/array.go @@ -320,7 +320,7 @@ func (s arrayDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { sliceType := reflect.SliceOf(value.Type().Elem()) elemBytes := int64(value.Type().Elem().Size()) if int64(value.Len()) > maxGraphCount(elemBytes) { - ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", value.Len(), elemBytes) + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", value.Len(), elemBytes)) return } if !ctx.ReserveGraphMemory(int64(value.Len()) * elemBytes) { diff --git a/go/fory/map.go b/go/fory/map.go index e2cab2ad71..9761efc2a8 100644 --- a/go/fory/map.go +++ b/go/fory/map.go @@ -307,15 +307,15 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { valueBytes := int64(mapType.Elem().Size()) elemBytes := keyBytes + valueBytes if elemBytes < keyBytes { - ctx.setGraphMemoryLimitError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + ctx.SetError(DeserializationErrorf("map entry size overflows: key=%d value=%d", keyBytes, valueBytes)) return } if size < 0 { - ctx.setGraphMemoryLimitError("negative graph element count: %d", size) + ctx.SetError(DeserializationErrorf("negative graph element count: %d", size)) return } if int64(size) > maxGraphCount(elemBytes) { - ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", size, elemBytes) + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", size, elemBytes)) return } if !ctx.ReserveGraphMemory(int64(size) * elemBytes) { diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go index 2ba6b55b4d..ee98f310b6 100644 --- a/go/fory/map_primitive.go +++ b/go/fory/map_primitive.go @@ -95,11 +95,11 @@ func readTypedMapSize(ctx *ReadContext, elemBytes int64, maxLength int64) (int, return 0, false } if size < 0 { - ctx.setGraphMemoryLimitError("negative graph element count: %d", size) + ctx.SetError(DeserializationErrorf("negative graph element count: %d", size)) return 0, false } if int64(size) > maxLength { - ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", size, elemBytes) + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", size, elemBytes)) return 0, false } if !ctx.ReserveGraphMemory(int64(size) * elemBytes) { diff --git a/go/fory/reader.go b/go/fory/reader.go index 321c5db4ea..2821e6b42b 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -131,15 +131,10 @@ func (c *ReadContext) ReserveGraphMemory(bytes int64) bool { //go:noinline func (c *ReadContext) rejectGraphMemoryBytes(bytes int64) bool { - c.setGraphMemoryLimitError("estimated graph memory must be non-negative, got %d bytes", bytes) + c.SetError(DeserializationErrorf("estimated graph memory must be non-negative, got %d bytes", bytes)) return false } -//go:noinline -func (c *ReadContext) setGraphMemoryLimitError(format string, args ...any) { - c.SetError(DeserializationErrorf(format, args...)) -} - //go:noinline func (c *ReadContext) rejectGraphMemoryExceeded(bytes int64, remaining int64) bool { c.SetError(DeserializationErrorf( @@ -615,11 +610,11 @@ func (c *ReadContext) ReadStringSlice(refMode RefMode, readType bool) []string { return nil } if length < 0 { - c.setGraphMemoryLimitError("negative graph element count: %d", length) + c.SetError(DeserializationErrorf("negative graph element count: %d", length)) return nil } if int64(length) > stringMaxLength { - c.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, stringElementBytes) + c.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, stringElementBytes)) return nil } if !c.ReserveGraphMemory(int64(length) * stringElementBytes) { diff --git a/go/fory/set.go b/go/fory/set.go index 0bd22f68fc..7ff903fcf2 100644 --- a/go/fory/set.go +++ b/go/fory/set.go @@ -322,10 +322,7 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { valueBytes := int64(type_.Elem().Size()) elemBytes := keyBytes + valueBytes if elemBytes < keyBytes { - ctx.setGraphMemoryLimitError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) - return - } - if !ctx.ReserveGraphMemory(0) { + ctx.SetError(DeserializationErrorf("map entry size overflows: key=%d value=%d", keyBytes, valueBytes)) return } // Initialize empty set if length is 0 @@ -370,15 +367,15 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { valueBytes := int64(type_.Elem().Size()) elemBytes := keyBytes + valueBytes if elemBytes < keyBytes { - ctx.setGraphMemoryLimitError("map entry size overflows: key=%d value=%d", keyBytes, valueBytes) + ctx.SetError(DeserializationErrorf("map entry size overflows: key=%d value=%d", keyBytes, valueBytes)) return } if length < 0 { - ctx.setGraphMemoryLimitError("negative graph element count: %d", length) + ctx.SetError(DeserializationErrorf("negative graph element count: %d", length)) return } if int64(length) > maxGraphCount(elemBytes) { - ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, elemBytes) + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, elemBytes)) return } if !ctx.ReserveGraphMemory(int64(length) * elemBytes) { diff --git a/go/fory/slice.go b/go/fory/slice.go index 95cf93e1e7..6a59df49a1 100644 --- a/go/fory/slice.go +++ b/go/fory/slice.go @@ -321,11 +321,11 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } if !isArrayType { if length < 0 { - ctx.setGraphMemoryLimitError("negative graph element count: %d", length) + ctx.SetError(DeserializationErrorf("negative graph element count: %d", length)) return } if int64(length) > s.maxLength { - ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes) + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes)) return } if !ctx.ReserveGraphMemory(int64(length) * s.elemBytes) { diff --git a/go/fory/slice_dyn.go b/go/fory/slice_dyn.go index 80df03f0e8..076ada5071 100644 --- a/go/fory/slice_dyn.go +++ b/go/fory/slice_dyn.go @@ -289,11 +289,11 @@ func (s *sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, exp } if !allocatedByCaller { if length < 0 { - ctx.setGraphMemoryLimitError("negative graph element count: %d", length) + ctx.SetError(DeserializationErrorf("negative graph element count: %d", length)) return } if int64(length) > s.maxLength { - ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes) + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes)) return } if !ctx.ReserveGraphMemory(int64(length) * s.elemBytes) { diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index c7089776d5..6fc371075a 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -653,11 +653,11 @@ func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } ptr := (*[]string)(value.Addr().UnsafePointer()) if length < 0 { - ctx.setGraphMemoryLimitError("negative graph element count: %d", length) + ctx.SetError(DeserializationErrorf("negative graph element count: %d", length)) return } if int64(length) > stringMaxLength { - ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, stringElementBytes) + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, stringElementBytes)) return } if !ctx.ReserveGraphMemory(int64(length) * stringElementBytes) { diff --git a/go/fory/slice_primitive_list.go b/go/fory/slice_primitive_list.go index c327815aa6..55bb5d5f23 100644 --- a/go/fory/slice_primitive_list.go +++ b/go/fory/slice_primitive_list.go @@ -180,11 +180,11 @@ func (s primitiveListSerializer) ReadData(ctx *ReadContext, value reflect.Value) return } if length < 0 { - ctx.setGraphMemoryLimitError("negative graph element count: %d", length) + ctx.SetError(DeserializationErrorf("negative graph element count: %d", length)) return } if int64(length) > s.maxLength { - ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes) + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, s.elemBytes)) return } if !ctx.ReserveGraphMemory(int64(length) * s.elemBytes) { @@ -251,9 +251,6 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val } if length == 0 { if value.Kind() == reflect.Slice { - if !ctx.ReserveGraphMemory(0) { - return - } value.Set(reflect.MakeSlice(value.Type(), 0, 0)) } else if value.Len() != 0 { ctx.SetError(DeserializationErrorf("array-compatible list length %d does not match array length %d", length, value.Len())) @@ -293,11 +290,11 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val } if value.Kind() == reflect.Slice { if length < 0 { - ctx.setGraphMemoryLimitError("negative graph element count: %d", length) + ctx.SetError(DeserializationErrorf("negative graph element count: %d", length)) return } if int64(length) > s.listReader.maxLength { - ctx.setGraphMemoryLimitError("graph memory estimate overflows: length=%d elementBytes=%d", length, s.listReader.elemBytes) + ctx.SetError(DeserializationErrorf("graph memory estimate overflows: length=%d elementBytes=%d", length, s.listReader.elemBytes)) return } if !ctx.ReserveGraphMemory(int64(length) * s.listReader.elemBytes) { diff --git a/java/fory-core/src/main/java/org/apache/fory/Fory.java b/java/fory-core/src/main/java/org/apache/fory/Fory.java index 1db63df9b6..a7916400ae 100644 --- a/java/fory-core/src/main/java/org/apache/fory/Fory.java +++ b/java/fory-core/src/main/java/org/apache/fory/Fory.java @@ -420,18 +420,38 @@ public Object deserialize(ByteBuffer byteBuffer) { @Override public T deserialize(byte[] bytes, Class type) { - return deserializeRoot(MemoryUtils.wrap(bytes), type); + return deserialize(MemoryUtils.wrap(bytes), type); } @Override public T deserialize(MemoryBuffer buffer, Class type) { - return deserializeRoot(buffer, type); + ensureRegistrationFinished(); + byte bitmap = buffer.readByte(); + if (bitmap != headerBitmap) { + checkHeaderBitmapWithoutOutOfBand(bitmap); + } + readContext.prepare(buffer, null, false); + try { + try { + jitContext.lock(); + if (readContext.getDepth() > 0) { + throwDepthDeserializationException(); + } + return deserializeByType(buffer, type); + } finally { + jitContext.unlock(); + } + } catch (Throwable t) { + throw ExceptionUtils.handleReadFailed(this, t); + } finally { + readContext.reset(); + } } @Override public T deserialize(ForyInputStream inputStream, Class type) { try { - return deserializeRoot(inputStream.getBuffer(), type); + return deserialize(inputStream.getBuffer(), type); } finally { inputStream.shrinkBuffer(); } @@ -439,7 +459,7 @@ public T deserialize(ForyInputStream inputStream, Class type) { @Override public T deserialize(ForyReadableChannel channel, Class type) { - return deserializeRoot(channel.getBuffer(), type); + return deserialize(channel.getBuffer(), type); } @Override @@ -467,60 +487,6 @@ public Object deserialize(MemoryBuffer buffer) { */ @Override public Object deserialize(MemoryBuffer buffer, Iterable outOfBandBuffers) { - return deserializeRoot(buffer, outOfBandBuffers); - } - - @Override - public Object deserialize(ForyInputStream inputStream) { - return deserialize(inputStream, (Iterable) null); - } - - @Override - public Object deserialize(ForyInputStream inputStream, Iterable outOfBandBuffers) { - try { - MemoryBuffer buf = inputStream.getBuffer(); - return deserializeRoot(buf, outOfBandBuffers); - } finally { - inputStream.shrinkBuffer(); - } - } - - @Override - public Object deserialize(ForyReadableChannel channel) { - return deserialize(channel, (Iterable) null); - } - - @Override - public Object deserialize(ForyReadableChannel channel, Iterable outOfBandBuffers) { - MemoryBuffer buf = channel.getBuffer(); - return deserializeRoot(buf, outOfBandBuffers); - } - - private T deserializeRoot(MemoryBuffer buffer, Class type) { - ensureRegistrationFinished(); - byte bitmap = buffer.readByte(); - if (bitmap != headerBitmap) { - checkHeaderBitmapWithoutOutOfBand(bitmap); - } - readContext.prepare(buffer, null, false); - try { - try { - jitContext.lock(); - if (readContext.getDepth() > 0) { - throwDepthDeserializationException(); - } - return deserializeByType(buffer, type); - } finally { - jitContext.unlock(); - } - } catch (Throwable t) { - throw ExceptionUtils.handleReadFailed(this, t); - } finally { - readContext.reset(); - } - } - - private Object deserializeRoot(MemoryBuffer buffer, Iterable outOfBandBuffers) { ensureRegistrationFinished(); byte bitmap = buffer.readByte(); boolean peerOutOfBandEnabled = false; @@ -557,6 +523,32 @@ private Object deserializeRoot(MemoryBuffer buffer, Iterable outOf } } + @Override + public Object deserialize(ForyInputStream inputStream) { + return deserialize(inputStream, (Iterable) null); + } + + @Override + public Object deserialize(ForyInputStream inputStream, Iterable outOfBandBuffers) { + try { + MemoryBuffer buf = inputStream.getBuffer(); + return deserialize(buf, outOfBandBuffers); + } finally { + inputStream.shrinkBuffer(); + } + } + + @Override + public Object deserialize(ForyReadableChannel channel) { + return deserialize(channel, (Iterable) null); + } + + @Override + public Object deserialize(ForyReadableChannel channel, Iterable outOfBandBuffers) { + MemoryBuffer buf = channel.getBuffer(); + return deserialize(buf, outOfBandBuffers); + } + @SuppressWarnings("unchecked") private T deserializeByType(MemoryBuffer buffer, Class type) { readContext diff --git a/java/fory-core/src/main/java/org/apache/fory/config/Config.java b/java/fory-core/src/main/java/org/apache/fory/config/Config.java index 869bce3ef1..8f07945706 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/Config.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/Config.java @@ -322,7 +322,7 @@ public int maxAverageSchemaVersionsPerType() { return maxAverageSchemaVersionsPerType; } - /** Returns the root-operation estimated graph memory limit in bytes. Non-positive disables it. */ + /** Returns the root-operation estimated graph memory limit in bytes. */ public long maxGraphMemoryBytes() { return maxGraphMemoryBytes; } diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index d55c6520d0..ff30a0950a 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -617,12 +617,7 @@ public T readRef(Serializer serializer) { return (T) readNonRef(serializer); } - /** - * Reads the root object for one deserialization operation. - * - *

Root no-ref deserialization owns the null marker and type metadata directly; using the - * generic ref-reader path here makes scalar roots pay reference dispatch they can never use. - */ + /** Reads the root object for one deserialization operation. */ public Object readRootRef() { if (trackingRef) { return readRef(rootTypeInfoHolder); From 6456dce012865367c430db9f35ae15703e9cd3c1 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 04:31:51 +0800 Subject: [PATCH 36/54] refactor: tighten graph memory cleanup --- .../serialization/graph_memory_budget_test.cc | 2 +- .../Fory.Tests/GraphMemoryBudgetTests.cs | 6 +-- .../test/generated_round_trip_test.dart | 4 +- dart/packages/fory/lib/src/config.dart | 53 ++++++++++++------- .../fory/test/graph_memory_budget_test.dart | 7 +-- ...calar_and_typed_array_serializer_test.dart | 2 +- .../fory/test/xlang_protocol_test.dart | 22 ++++---- go/fory/fory.go | 35 ++---------- go/fory/graph_memory_budget_test.go | 12 ++--- .../serializer/GraphMemoryBudgetTest.java | 2 +- .../pyfory/tests/test_graph_memory_budget.py | 14 ++--- rust/tests/tests/test_graph_memory_budget.rs | 2 +- .../ForyTests/GraphMemoryBudgetTests.swift | 10 ++-- 13 files changed, 79 insertions(+), 92 deletions(-) diff --git a/cpp/fory/serialization/graph_memory_budget_test.cc b/cpp/fory/serialization/graph_memory_budget_test.cc index ab8e63e672..3410cc0dd9 100644 --- a/cpp/fory/serialization/graph_memory_budget_test.cc +++ b/cpp/fory/serialization/graph_memory_budget_test.cc @@ -265,7 +265,7 @@ TEST(GraphMemoryBudgetTest, EmptyStructRootChargesOwner) { expect_budget_boundary(value, sizeof(BudgetEmpty)); } -TEST(GraphMemoryBudgetTest, NestedEmptyContainersUseParentStorage) { +TEST(GraphMemoryBudgetTest, NestedEmptyContainers) { std::vector> value(1); auto bytes = serialize_value(value); const size_t required = sizeof(std::vector>) + diff --git a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs index 60b65c0e30..05061ed855 100644 --- a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs +++ b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs @@ -220,7 +220,7 @@ public void MapBudgetIsCharged() } [Fact] - public void ReferenceArrayAndInlineValueListAreCharged() + public void ArrayAndInlineListBudget() { BudgetArrayHolder holder = new() { @@ -310,7 +310,7 @@ void Check(T value, Func assertValue) } [Fact] - public void DenseStringBinaryAndPrimitiveArraysAreSkipped() + public void DenseLeafOwnersAreSkipped() { Assert.Equal("budget", NewFory(1).Deserialize(Serialize("budget"))); Assert.Equal(new byte[] { 1, 2, 3 }, NewFory(1).Deserialize(Serialize(new byte[] { 1, 2, 3 }))); @@ -335,7 +335,7 @@ public void CompatibleListToDenseArrayIsSkipped() } [Fact] - public void ByteAvailabilityCheckStillRejectsLargeLength() + public void ByteChecksRejectLargeLength() { byte[] bytes = [64, 0]; ReadContext context = new(new ByteReader(bytes), new TypeResolver(), NewFory().Config); diff --git a/dart/packages/fory-test/test/generated_round_trip_test.dart b/dart/packages/fory-test/test/generated_round_trip_test.dart index 2d8b603511..6dd2026309 100644 --- a/dart/packages/fory-test/test/generated_round_trip_test.dart +++ b/dart/packages/fory-test/test/generated_round_trip_test.dart @@ -28,8 +28,8 @@ import 'package:test/test.dart'; void main() { group('configuration', () { test('defaults to compatible mode unless explicitly set', () { - expect(const Config().compatible, isTrue); - expect(const Config(compatible: false).compatible, isFalse); + expect(Config().compatible, isTrue); + expect(Config(compatible: false).compatible, isFalse); }); }); diff --git a/dart/packages/fory/lib/src/config.dart b/dart/packages/fory/lib/src/config.dart index 2207819321..88c3f36a0d 100644 --- a/dart/packages/fory/lib/src/config.dart +++ b/dart/packages/fory/lib/src/config.dart @@ -66,30 +66,45 @@ final class Config { /// /// Invalid numeric limits fail fast. When [compatible] is `true`, /// [checkStructVersion] is normalized to `false`. - const Config({ + Config({ this.compatible = true, bool checkStructVersion = true, - this.maxDepth = defaultMaxDepth, - this.maxTypeFields = defaultMaxTypeFields, - this.maxTypeMetaBytes = defaultMaxTypeMetaBytes, - this.maxSchemaVersionsPerType = defaultMaxSchemaVersionsPerType, - this.maxAverageSchemaVersionsPerType = + int maxDepth = defaultMaxDepth, + int maxTypeFields = defaultMaxTypeFields, + int maxTypeMetaBytes = defaultMaxTypeMetaBytes, + int maxSchemaVersionsPerType = defaultMaxSchemaVersionsPerType, + int maxAverageSchemaVersionsPerType = defaultMaxAverageSchemaVersionsPerType, - this.maxGraphMemoryBytes = defaultMaxGraphMemoryBytes, + int maxGraphMemoryBytes = defaultMaxGraphMemoryBytes, }) : checkStructVersion = compatible ? false : checkStructVersion, - assert(maxDepth > 0, 'maxDepth must be positive'), - assert(maxTypeFields > 0, 'maxTypeFields must be positive'), - assert(maxTypeMetaBytes > 0, 'maxTypeMetaBytes must be positive'), - assert( - maxSchemaVersionsPerType > 0, - 'maxSchemaVersionsPerType must be positive', + maxDepth = _positive(maxDepth, 'maxDepth'), + maxTypeFields = _positive(maxTypeFields, 'maxTypeFields'), + maxTypeMetaBytes = _positive(maxTypeMetaBytes, 'maxTypeMetaBytes'), + maxSchemaVersionsPerType = _positive( + maxSchemaVersionsPerType, + 'maxSchemaVersionsPerType', ), - assert( - maxAverageSchemaVersionsPerType > 0, - 'maxAverageSchemaVersionsPerType must be positive', + maxAverageSchemaVersionsPerType = _positive( + maxAverageSchemaVersionsPerType, + 'maxAverageSchemaVersionsPerType', ), - assert( - maxGraphMemoryBytes > 0 && maxGraphMemoryBytes <= 9007199254740991, - 'maxGraphMemoryBytes must be a positive safe integer', + maxGraphMemoryBytes = _positiveSafeInteger( + maxGraphMemoryBytes, + 'maxGraphMemoryBytes', ); + + static int _positive(int value, String name) { + if (value <= 0) { + throw ArgumentError.value(value, name, 'must be positive'); + } + return value; + } + + static int _positiveSafeInteger(int value, String name) { + const maxSafeInteger = 9007199254740991; + if (value <= 0 || value > maxSafeInteger) { + throw ArgumentError.value(value, name, 'must be a positive safe integer'); + } + return value; + } } diff --git a/dart/packages/fory/test/graph_memory_budget_test.dart b/dart/packages/fory/test/graph_memory_budget_test.dart index 299f6992c5..1d3672c942 100644 --- a/dart/packages/fory/test/graph_memory_budget_test.dart +++ b/dart/packages/fory/test/graph_memory_budget_test.dart @@ -152,13 +152,10 @@ void main() { expect(() => context.reserveGraphMemory(31), returnsNormally); expect(() => context.reserveGraphMemory(1), _throwsGraphBudget); - expect( - () => Fory(maxGraphMemoryBytes: 0), - throwsA(isA()), - ); + expect(() => Fory(maxGraphMemoryBytes: 0), throwsA(isA())); expect( () => Fory(maxGraphMemoryBytes: -2), - throwsA(isA()), + throwsA(isA()), ); }); diff --git a/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart b/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart index 7ef820c666..a3d64234e2 100644 --- a/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart +++ b/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart @@ -62,7 +62,7 @@ FieldInfo _compatibleScalarField({ } ReadContext _compatibleReadContext(Buffer buffer) { - const config = Config(); + final config = Config(); final resolver = TypeResolver(config); return ReadContext(config, resolver, RefReader(), MetaStringReader(resolver)) ..prepare(buffer); diff --git a/dart/packages/fory/test/xlang_protocol_test.dart b/dart/packages/fory/test/xlang_protocol_test.dart index 21e2ada7d3..74b2bcd140 100644 --- a/dart/packages/fory/test/xlang_protocol_test.dart +++ b/dart/packages/fory/test/xlang_protocol_test.dart @@ -151,7 +151,7 @@ void _rememberLateHolder() { } Uint8List _lateHolderTypeDefBytes({required bool registerExtFirst}) { - final resolver = TypeResolver(const Config()); + final resolver = TypeResolver(Config()); _rememberLateHolder(); if (registerExtFirst) { resolver.registerSerializer( @@ -183,7 +183,7 @@ Uint8List _typeMetaBytes( String name, List fields, ) { - final resolver = TypeResolver(const Config()); + final resolver = TypeResolver(Config()); _rememberSchema(type, fields); final parts = name.split('.'); resolver.registerGenerated( @@ -203,7 +203,7 @@ Uint8List _typeMetaBytes( } Uint8List _enumTypeMetaBytes(Type type, String name) { - final resolver = TypeResolver(const Config()); + final resolver = TypeResolver(Config()); _rememberEnum(type); final parts = name.split('.'); resolver.registerGenerated( @@ -372,7 +372,7 @@ void main() { test('remote schema limit rejects extra versions', () { const name = 'example.Unknown'; - final reader = TypeResolver(const Config(maxSchemaVersionsPerType: 1)); + final reader = TypeResolver(Config(maxSchemaVersionsPerType: 1)); _rememberSchema(_SchemaLocal, []); reader.registerGenerated( _SchemaLocal, @@ -393,7 +393,7 @@ void main() { test('named enum TypeDef uses metadata byte limit', () { const name = 'example.RemoteEnum'; - final reader = TypeResolver(const Config(maxTypeMetaBytes: 1)); + final reader = TypeResolver(Config(maxTypeMetaBytes: 1)); _rememberEnum(_SchemaLocal); final bytes = _enumTypeMetaBytes(_SchemaRemoteA, name); @@ -402,7 +402,7 @@ void main() { test('registered named enum TypeDef uses metadata byte limit', () { const name = 'example.RemoteEnum'; - final reader = TypeResolver(const Config(maxTypeMetaBytes: 1)); + final reader = TypeResolver(Config(maxTypeMetaBytes: 1)); _rememberEnum(_SchemaLocal); reader.registerGenerated( _SchemaLocal, @@ -416,7 +416,7 @@ void main() { test('exact local named enum TypeDef is accepted', () { const name = 'example.SharedEnum'; - final reader = TypeResolver(const Config(maxSchemaVersionsPerType: 1)); + final reader = TypeResolver(Config(maxSchemaVersionsPerType: 1)); _rememberEnum(_SchemaLocal); reader.registerGenerated( _SchemaLocal, @@ -429,7 +429,7 @@ void main() { }); test('type meta field limit rejects large struct', () { - final reader = TypeResolver(const Config(maxTypeFields: 1)); + final reader = TypeResolver(Config(maxTypeFields: 1)); final bytes = _typeMetaBytes( _SchemaRemoteA, 'example.TooManyFields', @@ -443,7 +443,7 @@ void main() { }); test('type meta body limit rejects large metadata', () { - final reader = TypeResolver(const Config(maxTypeMetaBytes: 1)); + final reader = TypeResolver(Config(maxTypeMetaBytes: 1)); final bytes = _typeMetaBytes( _SchemaRemoteA, 'example.LargeTypeMeta', @@ -454,7 +454,7 @@ void main() { }); test('remote schema limit keeps unknown types separate', () { - final reader = TypeResolver(const Config(maxSchemaVersionsPerType: 1)); + final reader = TypeResolver(Config(maxSchemaVersionsPerType: 1)); _rememberSchema(_SchemaLocal, []); reader.registerGenerated( _SchemaLocal, @@ -484,7 +484,7 @@ void main() { test('failed remote schema does not consume schema limit', () { const name = 'example.Accepted'; - final reader = TypeResolver(const Config(maxSchemaVersionsPerType: 1)); + final reader = TypeResolver(Config(maxSchemaVersionsPerType: 1)); _rememberSchema(_SchemaLocal, [ _generatedField('value'), ]); diff --git a/go/fory/fory.go b/go/fory/fory.go index 1014c6b5d1..d9613eeb34 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -48,11 +48,6 @@ func splitRegisteredName(name string) (string, string, error) { return namespace, typeName, nil } -type ifaceWords struct { - typ unsafe.Pointer - data unsafe.Pointer -} - // ============================================================================ // Constants // ============================================================================ @@ -201,10 +196,8 @@ type Fory struct { typeResolver *TypeResolver refResolver *RefResolver - rootGraphType reflect.Type - rootGraphBytes int64 - rootReadTypeID unsafe.Pointer - rootReadSerializer Serializer + rootGraphType reflect.Type + rootGraphBytes int64 } // New creates a new Fory instance with the given options @@ -578,7 +571,6 @@ func (f *Fory) Serialize(value any) ([]byte, error) { func (f *Fory) Deserialize(data []byte, v any) error { defer f.resetReadState() f.readCtx.SetData(data) - typeID := (*ifaceWords)(unsafe.Pointer(&v)).typ target := reflect.ValueOf(v).Elem() targetType := target.Type() limit := f.config.MaxGraphMemoryBytes @@ -595,8 +587,9 @@ func (f *Fory) Deserialize(data []byte, v any) error { return f.readCtx.TakeError() } - // Deserialize the value - TypeMeta is read inline using streaming protocol - f.readRootValue(target, typeID) + // Root writes include type metadata, so keep the root ReadValue path. + // Calling a cached serializer directly would read that metadata byte as payload. + f.readCtx.ReadValue(target, RefModeTracking, true) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -1257,21 +1250,3 @@ func (f *Fory) rootGraphBytesFor(targetType reflect.Type) (int64, bool) { f.rootGraphBytes = bytes return bytes, true } - -func (f *Fory) readRootValue(target reflect.Value, typeID unsafe.Pointer) { - serializer := f.rootReadSerializer - if typeID == f.rootReadTypeID && serializer != nil { - serializer.Read(f.readCtx, RefModeTracking, true, false, target) - return - } - targetType := target.Type() - if targetType.Kind() == reflect.Struct { - if typeInfo := f.readCtx.getTypeInfoByType(targetType); typeInfo != nil && typeInfo.Serializer != nil { - f.rootReadTypeID = typeID - f.rootReadSerializer = typeInfo.Serializer - typeInfo.Serializer.Read(f.readCtx, RefModeTracking, true, false, target) - return - } - } - f.readCtx.ReadValue(target, RefModeTracking, true) -} diff --git a/go/fory/graph_memory_budget_test.go b/go/fory/graph_memory_budget_test.go index 2fe91ae480..57536abc9b 100644 --- a/go/fory/graph_memory_budget_test.go +++ b/go/fory/graph_memory_budget_test.go @@ -65,7 +65,7 @@ func TestGraphMemoryBudgetFixedDefault(t *testing.T) { require.Equal(t, int64(77), ctx.graphMemoryLimitBytes) } -func TestGraphMemoryBudgetRootKindsShareDefault(t *testing.T) { +func TestGraphBudgetRootKinds(t *testing.T) { writer := New(WithCompatible(false)) values := make([]any, 12000) for i := range values { @@ -103,7 +103,7 @@ func TestGraphMemoryBudgetBufferRoots(t *testing.T) { require.Equal(t, value, fromBuffer) } -func TestGraphMemoryBudgetExplicitOverride(t *testing.T) { +func TestGraphBudgetOverride(t *testing.T) { writer := New(WithCompatible(false)) values := make([]any, 12000) for i := range values { @@ -118,7 +118,7 @@ func TestGraphMemoryBudgetExplicitOverride(t *testing.T) { require.Len(t, out, len(values)) } -func TestGraphMemoryBudgetEmptyAndCumulative(t *testing.T) { +func TestGraphBudgetCumulative(t *testing.T) { data, err := New(WithCompatible(false)).Serialize([]any{}) require.NoError(t, err) var empty []any @@ -158,7 +158,7 @@ func TestGraphMemoryBudgetMapAndOverflow(t *testing.T) { require.Contains(t, ctx.CheckError().Error(), "non-negative") } -func TestGraphMemoryBudgetSlicesAndInlineValues(t *testing.T) { +func TestGraphBudgetSlices(t *testing.T) { data, err := New().Serialize([]string{"a"}) require.NoError(t, err) var stringsOut []string @@ -236,7 +236,7 @@ func TestGraphMemoryBudgetStructOwners(t *testing.T) { require.Equal(t, int32(7), outValue.A) } -func TestGraphMemoryBudgetSkipsDenseOwners(t *testing.T) { +func TestGraphBudgetSkipsDense(t *testing.T) { f := New(WithMaxGraphMemoryBytes(1)) stringData, err := New().Serialize(strings.Repeat("x", 128)) @@ -258,7 +258,7 @@ func TestGraphMemoryBudgetSkipsDenseOwners(t *testing.T) { require.Equal(t, []int32{1, 2, 3, 4}, ints) } -func TestGraphMemoryBudgetPreservesByteChecks(t *testing.T) { +func TestGraphBudgetByteChecks(t *testing.T) { buf := NewByteBuffer(nil) buf.WriteByte_(XLangFlag) buf.WriteInt8(NotNullValueFlag) diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java index 5d26ea65b4..b4aa372fee 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/GraphMemoryBudgetTest.java @@ -75,7 +75,7 @@ public void testExplicitBudgetWins() { } @Test - public void testNestedEmptyContainersUseParentStorage() { + public void testNestedEmptyContainers() { List value = emptyLists(1); byte[] bytes = builder().build().serialize(value); long required = collectionBytes(1) + collectionBytes(0); diff --git a/python/pyfory/tests/test_graph_memory_budget.py b/python/pyfory/tests/test_graph_memory_budget.py index bbaaec57f3..aa83a406e5 100644 --- a/python/pyfory/tests/test_graph_memory_budget.py +++ b/python/pyfory/tests/test_graph_memory_budget.py @@ -138,7 +138,7 @@ def test_fixed_default_budget(): fory.reset_read() -def test_stream_uses_fixed_default_budget(): +def test_stream_default_budget(): fory = new_fory(xlang=False) try: buffer = Buffer.from_stream(OneByteStream(b"streamed")) @@ -148,19 +148,19 @@ def test_stream_uses_fixed_default_budget(): fory.reset_read() -def test_explicit_config_overrides_default(): +def test_explicit_budget(): value = [1] budget = collection_memory(1) assert expect_budget(value, budget) == value -def test_nested_empty_containers_use_parent_storage(): +def test_nested_empty_containers(): value = [[]] budget = collection_memory(1) + collection_memory(0) assert expect_budget(value, budget) == value -def test_sibling_nested_containers_are_cumulative(): +def test_sibling_cumulative_budget(): value = [[], [], []] budget = collection_memory(3) + 3 * collection_memory(0) assert expect_budget(value, budget) == value @@ -181,7 +181,7 @@ def test_empty_object_owner_is_charged(): assert reader.deserialize(data) == value -def test_dynamic_object_owner_is_charged(): +def test_dynamic_object_budget(): value = BudgetObject() value.left = 1 value.right = "x" @@ -226,7 +226,7 @@ def test_object_ndarray_budget(): np.testing.assert_array_equal(restored, value) -def test_string_binary_and_dense_arrays_skip_budget(): +def test_dense_leaf_owners_skipped(): values = [ "x" * 256, b"x" * 256, @@ -243,7 +243,7 @@ def test_string_binary_and_dense_arrays_skip_budget(): assert restored == value -def test_declared_large_list_still_needs_bytes(): +def test_large_list_needs_bytes(): fory = new_fory(10_000_000, xlang=False) serializer = ListSerializer(fory.type_resolver, list) try: diff --git a/rust/tests/tests/test_graph_memory_budget.rs b/rust/tests/tests/test_graph_memory_budget.rs index aa3ff1da4c..8abd761558 100644 --- a/rust/tests/tests/test_graph_memory_budget.rs +++ b/rust/tests/tests/test_graph_memory_budget.rs @@ -102,7 +102,7 @@ fn byte_root_uses_fixed_default_budget() { } #[test] -fn reader_root_uses_fixed_default_budget() { +fn reader_root_default_budget() { let value = compact_empty_lists(12000); let writer = fory_with_budget(DEFAULT_GRAPH_MEMORY_BYTES); let bytes = writer.serialize(&value).unwrap(); diff --git a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift index 9204220921..0ecc4b857e 100644 --- a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift +++ b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift @@ -162,7 +162,7 @@ func fixedDefaultBudget() throws { } @Test -func byteBufferRootUsesFixedDefaultBudget() throws { +func byteBufferRootDefaultBudget() throws { let count = 6 let value = Array(repeating: [String](), count: count) let bytes = try makeBudgetFory().serialize(value) @@ -254,7 +254,7 @@ func emptyTypedMapOwnerIsCharged() throws { } @Test -func referenceAndInlineValueArraysAreCharged() throws { +func arrayInlineValueBudget() throws { let nodes = (0..<4).map { BudgetNode(id: Int32($0)) } let nodeBytes = try makeBudgetFory().serialize(nodes) let nodeBudget = rootArrayBudget( @@ -296,7 +296,7 @@ func setConversionOwnerChargedOnce() throws { } @Test -func stringBinaryAndPrimitiveDenseArrayOwnersAreSkipped() throws { +func denseLeafOwnersSkipped() throws { let value = BudgetDenseHolder( text: "budget", data: Data([1, 2, 3]), @@ -426,7 +426,7 @@ func dynamicAnyArrayBudget() throws { } @Test -func compatibleListToDenseArraySkipsLeafOwner() throws { +func compatibleDenseArraySkip() throws { let writer = makeCompatibleBudgetFory() writer.register(BudgetListDenseWriter.self, id: 9804) let reader = makeCompatibleBudgetFory( @@ -447,7 +447,7 @@ func compatibleListToDenseArraySkipsLeafOwner() throws { } @Test -func byteAvailabilityCheckStillRejectsLargeLength() throws { +func byteCheckRejectsLargeLength() throws { let buffer = ByteBuffer() buffer.writeVarUInt32(64) buffer.writeUInt8(CollectionHeader.sameType | CollectionHeader.declaredElementType) From 216301f2c38694ce68ab7a224fc22bba2b7e979e Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 05:03:12 +0800 Subject: [PATCH 37/54] refactor(rust): remove unused resolver harness paths --- rust/fory-core/src/resolver/type_resolver.rs | 109 ------------------- 1 file changed, 109 deletions(-) diff --git a/rust/fory-core/src/resolver/type_resolver.rs b/rust/fory-core/src/resolver/type_resolver.rs index 4813180594..39424ba499 100644 --- a/rust/fory-core/src/resolver/type_resolver.rs +++ b/rust/fory-core/src/resolver/type_resolver.rs @@ -21,7 +21,6 @@ use crate::meta::{ MetaString, TypeMeta, NAMESPACE_ENCODER, NAMESPACE_ENCODINGS, TYPE_NAME_ENCODER, TYPE_NAME_ENCODINGS, }; -use crate::resolver::RefMode; use crate::serializer::{ForyDefault, Serializer, StructSerializer}; use crate::type_id::{get_ext_actual_type_id, is_enum_type_id}; use crate::types::{Date, Duration, Timestamp}; @@ -50,16 +49,6 @@ fn supports_type_def(type_id: u32) -> bool { ) } -type WriteFn = fn( - &dyn Any, - &mut WriteContext, - ref_mode: RefMode, - write_type_info: bool, - has_enerics: bool, -) -> Result<(), Error>; -type ReadFn = - fn(&mut ReadContext, ref_mode: RefMode, read_type_info: bool) -> Result, Error>; - type WriteDataFn = fn(&dyn Any, &mut WriteContext, has_generics: bool) -> Result<(), Error>; type ReadDataFn = fn(&mut ReadContext) -> Result, Error>; type ReadDataAsSendSyncAnyFn = fn(&mut ReadContext) -> Result, Error>; @@ -94,8 +83,6 @@ fn split_named_registration<'a>(name: &'a str, api: &str) -> Result<(&'a str, &' #[derive(Clone, Debug)] pub struct Harness { - write_fn: WriteFn, - read_fn: ReadFn, write_data_fn: WriteDataFn, read_data_fn: ReadDataFn, read_data_as_send_sync_any_fn: ReadDataAsSendSyncAnyFn, @@ -108,8 +95,6 @@ pub struct Harness { impl Harness { pub fn stub() -> Harness { Harness { - write_fn: stub_write_fn, - read_fn: stub_read_fn, write_data_fn: stub_write_data_fn, read_data_fn: stub_read_data_fn, read_data_as_send_sync_any_fn: stub_read_data_as_send_sync_any_fn, @@ -120,16 +105,6 @@ impl Harness { } } - #[inline(always)] - pub fn get_write_fn(&self) -> WriteFn { - self.write_fn - } - - #[inline(always)] - pub fn get_read_fn(&self) -> ReadFn { - self.read_fn - } - #[inline(always)] pub fn get_write_data_fn(&self) -> WriteDataFn { self.write_data_fn @@ -338,8 +313,6 @@ impl TypeInfo { } else { // Create a stub harness that returns errors when called Harness { - write_fn: stub_write_fn, - read_fn: stub_read_fn, write_data_fn: stub_write_data_fn, read_data_fn: stub_read_data_fn, read_data_as_send_sync_any_fn: stub_read_data_as_send_sync_any_fn, @@ -365,24 +338,6 @@ impl TypeInfo { } // Stub functions for when a type doesn't exist locally -fn stub_write_fn( - _: &dyn Any, - _: &mut WriteContext, - _: RefMode, - _: bool, - _: bool, -) -> Result<(), Error> { - Err(Error::type_error( - "Cannot serialize unknown remote type - type not registered locally", - )) -} - -fn stub_read_fn(_: &mut ReadContext, _: RefMode, _: bool) -> Result, Error> { - Err(Error::type_error( - "Cannot deserialize unknown remote type - type not registered locally", - )) -} - fn stub_write_data_fn(_: &dyn Any, _: &mut WriteContext, _: bool) -> Result<(), Error> { Err(Error::type_error( "Cannot serialize unknown remote type - type not registered locally", @@ -926,36 +881,6 @@ impl TypeResolver { || x == TypeId::NAMED_COMPATIBLE_STRUCT as u32 ); - fn write( - this: &dyn Any, - context: &mut WriteContext, - ref_mode: RefMode, - write_type_info: bool, - has_generics: bool, - ) -> Result<(), Error> { - let this = this.downcast_ref::(); - match this { - Some(v) => T2::fory_write(v, context, ref_mode, write_type_info, has_generics), - None => Err(Error::type_error(format!( - "Cast type to {:?} error when writing: {:?}", - std::any::type_name::(), - T2::fory_static_type_id() - ))), - } - } - - fn read( - context: &mut ReadContext, - ref_mode: RefMode, - read_type_info: bool, - ) -> Result, Error> { - let graph_self_size = T2::fory_graph_self_size(); - if graph_self_size != 0 { - context.reserve_graph_memory(graph_self_size)?; - } - Ok(Box::new(T2::fory_read(context, ref_mode, read_type_info)?)) - } - fn write_data( this: &dyn Any, context: &mut WriteContext, @@ -1025,8 +950,6 @@ impl TypeResolver { } let harness = Harness { - write_fn: write::, - read_fn: read::, write_data_fn: write_data::, read_data_fn: read_data::, read_data_as_send_sync_any_fn: read_data_as_send_sync_any::, @@ -1181,36 +1104,6 @@ impl TypeResolver { ))); } - fn write( - this: &dyn Any, - context: &mut WriteContext, - ref_mode: RefMode, - write_type_info: bool, - has_generics: bool, - ) -> Result<(), Error> { - let this = this.downcast_ref::(); - match this { - Some(v) => v.fory_write(context, ref_mode, write_type_info, has_generics), - None => Err(Error::type_error(format!( - "Cast type to {:?} error when writing: {:?}", - std::any::type_name::(), - T2::fory_static_type_id() - ))), - } - } - - fn read( - context: &mut ReadContext, - ref_mode: RefMode, - read_type_info: bool, - ) -> Result, Error> { - let graph_self_size = T2::fory_graph_self_size(); - if graph_self_size != 0 { - context.reserve_graph_memory(graph_self_size)?; - } - Ok(Box::new(T2::fory_read(context, ref_mode, read_type_info)?)) - } - fn write_data( this: &dyn Any, context: &mut WriteContext, @@ -1272,8 +1165,6 @@ impl TypeResolver { // EXT types don't support fory_read_compatible let harness = Harness { - write_fn: write::, - read_fn: read::, write_data_fn: write_data::, read_data_fn: read_data::, read_data_as_send_sync_any_fn: read_data_as_send_sync_any::, From bb0e0dc82acd4e14534153954d6d0f087e91fef0 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 07:07:46 +0800 Subject: [PATCH 38/54] refactor(csharp): simplify union read return --- csharp/src/Fory/UnionSerializer.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csharp/src/Fory/UnionSerializer.cs b/csharp/src/Fory/UnionSerializer.cs index 71a4834d84..2cd549068c 100644 --- a/csharp/src/Fory/UnionSerializer.cs +++ b/csharp/src/Fory/UnionSerializer.cs @@ -67,8 +67,7 @@ public override TUnion ReadData(ReadContext context) caseValue = DynamicAnyCodec.ReadAny(context, RefMode.Tracking, true); } - TUnion value = Factory(caseId, caseValue); - return value; + return Factory(caseId, caseValue); } private static void CheckWireCaseId(int caseId) From 2be1e51023704a6d895da8467dcedfdb335e3359 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 07:11:07 +0800 Subject: [PATCH 39/54] refactor(dart): simplify graph budget checks --- .../fory/lib/src/context/read_context.dart | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/dart/packages/fory/lib/src/context/read_context.dart b/dart/packages/fory/lib/src/context/read_context.dart index c978ad2193..f15022dcfd 100644 --- a/dart/packages/fory/lib/src/context/read_context.dart +++ b/dart/packages/fory/lib/src/context/read_context.dart @@ -45,8 +45,6 @@ import 'package:fory/src/types/uint64.dart'; /// deserialization operation. Application code normally interacts with [Fory] /// instead of preparing contexts directly. final class ReadContext { - static const int _maxSafeBudgetBytes = 9007199254740991; - /// Effective runtime configuration for the active operation. final Config config; final TypeResolver _typeResolver; @@ -70,11 +68,7 @@ final class ReadContext { @pragma('vm:prefer-inline') void prepare(Buffer buffer) { _buffer = buffer; - final limit = config.maxGraphMemoryBytes; - if (limit > _maxSafeBudgetBytes) { - _throwGraphMemoryOverflow(limit); - } - _remainingGraphMemoryBytes = limit; + _remainingGraphMemoryBytes = config.maxGraphMemoryBytes; } @internal @@ -98,14 +92,13 @@ final class ReadContext { @internal @pragma('vm:prefer-inline') void reserveGraphMemory(int bytes) { - if (bytes < 0 || bytes > _maxSafeBudgetBytes) { + if (bytes < 0) { _throwGraphMemoryOverflow(bytes); } - final remaining = _remainingGraphMemoryBytes - bytes; - if (remaining < 0) { + if (bytes > _remainingGraphMemoryBytes) { _throwGraphMemoryExceeded(bytes); } - _remainingGraphMemoryBytes = remaining; + _remainingGraphMemoryBytes -= bytes; } @pragma('vm:never-inline') From 6d71b10ef53e3e2ca10e619c543659667aeefd42 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 07:14:37 +0800 Subject: [PATCH 40/54] refactor(go): remove graph budget cleanup drift --- go/fory/array.go | 4 ++-- go/fory/buffer.go | 11 ----------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/go/fory/array.go b/go/fory/array.go index f9c917bb1e..dc99f50c1f 100644 --- a/go/fory/array.go +++ b/go/fory/array.go @@ -290,7 +290,7 @@ func (s *arrayConcreteValueSerializer) ReadWithTypeInfo(ctx *ReadContext, refMod // arrayDynSerializer wraps sliceDynSerializer for arrays with interface element types. // It converts arrays to slices and delegates to sliceDynSerializer. type arrayDynSerializer struct { - sliceSerializer *sliceDynSerializer + sliceSerializer sliceDynSerializer } func newArrayDynSerializer(elemType reflect.Type) (arrayDynSerializer, error) { @@ -298,7 +298,7 @@ func newArrayDynSerializer(elemType reflect.Type) (arrayDynSerializer, error) { if err != nil { return arrayDynSerializer{}, err } - return arrayDynSerializer{sliceSerializer: sliceSer}, nil + return arrayDynSerializer{sliceSerializer: *sliceSer}, nil } func (s arrayDynSerializer) WriteData(ctx *WriteContext, value reflect.Value) { diff --git a/go/fory/buffer.go b/go/fory/buffer.go index 1a1a067b7e..89e29f938d 100644 --- a/go/fory/buffer.go +++ b/go/fory/buffer.go @@ -482,17 +482,6 @@ func (b *ByteBuffer) ReaderIndex() int { return b.readerIndex } -func (b *ByteBuffer) readableBytes() int { - end := b.writerIndex - if len(b.data) > end { - end = len(b.data) - } - if b.readerIndex >= end { - return 0 - } - return end - b.readerIndex -} - func (b *ByteBuffer) SetReaderIndex(index int) { b.readerIndex = index } From d97dc71e955663390871beb4c76d94e51d818ddb Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 07:48:45 +0800 Subject: [PATCH 41/54] refactor: move graph budget root charges to owners --- .agents/languages/cpp.md | 3 +- .agents/languages/csharp.md | 10 +-- .agents/languages/go.md | 5 +- .agents/languages/java.md | 3 +- .agents/languages/rust.md | 4 +- .agents/languages/swift.md | 15 +++-- .agents/repo-reference.md | 5 +- AGENTS.md | 2 +- .../serialization/collection_serializer.h | 55 +++++++++++----- cpp/fory/serialization/fory.h | 9 --- cpp/fory/serialization/map_serializer.h | 31 ++++----- .../serialization/smart_ptr_serializers.h | 8 +++ cpp/fory/serialization/struct_serializer.h | 4 ++ csharp/src/Fory/Fory.cs | 8 +-- csharp/src/Fory/GraphMemory.cs | 10 --- csharp/src/Fory/Serializer.cs | 6 ++ csharp/src/Fory/TypeInfo.cs | 10 +-- docs/security/deserialization.md | 7 +- .../xlang_implementation_guide.md | 21 +++--- go/fory/fory.go | 65 ++++++------------- go/fory/pointer.go | 4 +- go/fory/reader.go | 17 ++--- go/fory/stream.go | 25 +++---- rust/fory-core/src/fory.rs | 5 -- rust/fory-core/src/serializer/core.rs | 4 ++ rust/fory-derive/src/object/read.rs | 4 ++ 26 files changed, 165 insertions(+), 175 deletions(-) diff --git a/.agents/languages/cpp.md b/.agents/languages/cpp.md index 0ead61046e..8ebc82ae70 100644 --- a/.agents/languages/cpp.md +++ b/.agents/languages/cpp.md @@ -21,7 +21,8 @@ Load this file when changing `cpp/`, Cython build plumbing, or C++ xlang behavio `Fory::deserialize` overload. Keep `max_graph_memory_bytes` as a fixed-default graph limit: unset/default is `128 MiB`, positive explicit values override it, and explicit non-positive values are invalid at config creation. Byte and stream roots use the same - configured/default budget behavior. + configured/default budget behavior. Root `Fory` overloads reset the budget only; they must not + pre-reserve root type or root self bytes. Reserve estimated shallow graph-owner memory before allocation while preserving existing byte-availability checks and their non-empty metadata ordering. `ReadContext` may expose only raw byte reservation; collection, map, array, struct, and object diff --git a/.agents/languages/csharp.md b/.agents/languages/csharp.md index 4d82f2518b..3c8b4a14dd 100644 --- a/.agents/languages/csharp.md +++ b/.agents/languages/csharp.md @@ -14,6 +14,7 @@ Load this file when changing `csharp/` or C# xlang behavior. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. - Root deserialization graph memory budget state belongs to `ReadContext`. C# public roots are memory-backed today, but the graph budget uses the same fixed default for every root shape. + Root APIs reset the budget only; they must not pre-reserve root type or root self bytes. `ReadContext` may expose only raw byte reservation; concrete serializers and generated serializers must compute list, array, map, struct, and object byte formulas before calling it. - `ReadContext` must not expose ref-publication pause/resume APIs or any non-budget owner @@ -22,10 +23,11 @@ Load this file when changing `csharp/` or C# xlang behavior. - For C# graph budget formulas, distinguish inline value storage from reference storage: use cheap value-type size for `List`/`T[]` value paths and the 4-byte reference fallback for reference paths. Class/reference serializers reserve their own shallow self cost plus field storage when - materialized; struct/value serializers do not unconditionally charge self storage because root, - field, list, array, map, set, or box owners reserve the inline storage they own. Maps reserve key - plus value storage, linked/hash/tree conversions must not add guessed node or entry overhead, and - independently materialized collection/map/array owners reserve nonzero shallow self cost. + materialized; struct/value serializers reserve self storage only on standalone serializer, + dynamic/boxed, or root materialization entries because field, list, array, map, set, and box + holders reserve the inline storage they own. Maps reserve key plus value storage, linked/hash/tree + conversions must not add guessed node or entry overhead, and independently materialized + collection/map/array owners reserve nonzero shallow self cost. Dedicated string, binary, primitive scalar, and primitive dense-array serializers stay skipped and rely on byte availability checks. - When extending C# tests from Java references, prioritize xlang spec behavior and the public C# contract before adding complex Java-specific parity cases. diff --git a/.agents/languages/go.md b/.agents/languages/go.md index afce10d12e..c0c00716c5 100644 --- a/.agents/languages/go.md +++ b/.agents/languages/go.md @@ -11,12 +11,13 @@ Load this file when changing `go/fory/` or Go xlang behavior. `WithMaxGraphMemoryBytes` uses a fixed `128 MiB` default; positive explicit values override it, and explicit non-positive values are invalid at config creation. Byte-slice and stream roots use the same - configured/default budget behavior. `ReadContext` may expose only raw byte + configured/default budget behavior. Root APIs reset the budget only; they must not pre-reserve + root type or root self bytes. `ReadContext` may expose only raw byte reservation; slice, map, array, struct, and object formulas belong in handwritten or generated serializer owners. Reserve Go slices as `len * elemBytes`, maps as `len * (keyBytes + valueBytes)`, map-backed sets, and LIST-encoded inline/value slices in the owner that - allocates that storage. Root struct owners, pointer allocations, and generated + allocates that storage. Struct root materialization paths, pointer allocations, and generated allocation entries reserve shallow value storage exactly once; nested inline struct serializers do not charge their own self storage again. Fixed arrays are caller-owned unless a read path materializes a temporary owner. Skip diff --git a/.agents/languages/java.md b/.agents/languages/java.md index 6be988e2da..9c9f0f3a30 100644 --- a/.agents/languages/java.md +++ b/.agents/languages/java.md @@ -19,7 +19,8 @@ Load this file when changing anything under `java/` or when Java drives a cross- with fixed `128 MiB` default. Positive explicit values override the default; explicit non-positive values are invalid and must be rejected at config creation. Byte-array, memory-buffer, and stream roots use the same configured/default - budget behavior. `ReadContext` + budget behavior. Root APIs reset the budget only; they must not pre-reserve + root type or root self bytes. `ReadContext` may expose only raw byte reservation; collection, map, array, struct, and object formulas belong in the concrete serializer or generated serializer owner. Java collection, map, and diff --git a/.agents/languages/rust.md b/.agents/languages/rust.md index 79458c5126..29b14032a5 100644 --- a/.agents/languages/rust.md +++ b/.agents/languages/rust.md @@ -21,7 +21,9 @@ Load this file when changing `rust/` or Rust xlang behavior. - Root deserialization graph memory budget state belongs to `ReadContext` and is initialized by the root `Fory` read methods before the header is consumed. Use the fixed `128 MiB` default unless a positive explicit value overrides it; explicit non-positive values are invalid at config - creation. Do not derive the budget from root input size or add dynamic bytes-read accounting. + creation. Root `Fory` read methods reset the budget only; they must not pre-reserve root type or + root self bytes. Do not derive the budget from root input size or add dynamic bytes-read + accounting. `ReadContext` may expose only raw byte reservation; `Vec`, collection, map, array, struct, object, and derive codec formulas belong in their serializer owners. diff --git a/.agents/languages/swift.md b/.agents/languages/swift.md index 8371b107f0..2b73998a52 100644 --- a/.agents/languages/swift.md +++ b/.agents/languages/swift.md @@ -19,16 +19,17 @@ Load this file when changing `swift/` or Swift xlang behavior. - Compatible scalar, list-array, and binary/uint8-array adaptations are immediate-field-only. Recursive matched-field comparison for collection elements, array elements, map keys, and map values must require exact nullability, ref tracking, generic arity, and type shape except documented user-type family normalization. - Root deserialization graph memory budget state belongs to `ReadContext`. Swift public roots are `Data` and `ByteBuffer`, and both use the same fixed default graph budget; do not add stream - bytes-read accounting or serializer-local budget state. `ReadContext` may expose only raw byte - reservation; array, set, map, struct, and object formulas belong in serializer and field-codec - owners. + bytes-read accounting or serializer-local budget state. Root APIs reset the budget only; they must + not pre-reserve root type or root self bytes. `ReadContext` may expose only raw byte reservation; + array, set, map, struct, and object formulas belong in serializer and field-codec owners. - For Swift graph budget formulas, distinguish inline/value storage from reference storage: use `MemoryLayout.stride` for value arrays/lists/sets/maps and the 4-byte reference fallback for `Serializer.isRefType` / `FieldCodec.isRefType` paths. Class/reference paths reserve their own - shallow self cost plus field storage when materialized; value serializers do not unconditionally - charge self storage because root, field, array, set, map, box, or generated owners reserve inline - storage exactly once. Independently materialized collection/map/array owners reserve nonzero - shallow self cost plus backing/reference/inline storage. Dedicated `String`, `Data`/binary, + shallow self cost plus field storage when materialized; value serializers reserve self storage + only on standalone serializer, generated, or root materialization entries because field, array, + set, map, and box holders reserve inline storage exactly once. Independently materialized + collection/map/array owners reserve nonzero shallow self cost plus backing/reference/inline + storage. Dedicated `String`, `Data`/binary, primitive scalar, and primitive packed-array owners stay skipped, except compatible packed-array-to-list reads must reserve the target list materialization before allocation. diff --git a/.agents/repo-reference.md b/.agents/repo-reference.md index 3934449a9d..f110bf9bb3 100644 --- a/.agents/repo-reference.md +++ b/.agents/repo-reference.md @@ -84,10 +84,11 @@ Apache Fory is a multi-language serialization framework with multiple wire forma Root graph memory budgeting is a read-state accounting feature only. Read context or equivalent read state may expose raw byte reservation and, when a runtime cannot reasonably avoid it, -root-operation budget setup/reset. It must not grow semantic APIs for collection, map, array, +root-operation budget setup/reset. Root facades may reset the per-operation budget, but must not +pre-reserve root type or root self bytes. It must not grow semantic APIs for collection, map, array, struct, object, temporary-owner, serializer-owner, conversion, counted-allocation, or ref-publication control. Concrete serializers and generated serializers own allocation formulas, -overflow checks, and reference publication timing. +overflow checks, root value storage reservation, and reference publication timing. ## Runtime Map diff --git a/AGENTS.md b/AGENTS.md index fd0ccf815e..a6e1d9a915 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -33,7 +33,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Check the spec before implementation. For wire behavior and xlang mapping, use the specs as the source of truth and never copy one runtime's bug into another runtime just to make tests pass. - Do not make assumptions about runtime behavior, ownership, registration, metadata construction, protocol semantics, or test coverage. Read the current code, owning docs/specs, and relevant tests before making a design judgment or implementation decision. If the evidence is incomplete, inspect more or state the uncertainty explicitly instead of filling gaps from memory or analogy with another runtime. - For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve backing/output capacity from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. Root graph memory reservation is accounting only and may happen before that byte check, but it must not replace the byte check. -- Root deserialization graph memory budgets estimate lower-bound shallow memory for materialized graph owners, not exact heap accounting, input byte accounting, or raw element counts. `maxGraphMemoryBytes` defaults to fixed `128 MiB`; positive values override the default; explicit non-positive values are invalid and must be rejected at config/Fory creation. Do not add a disabled-budget sentinel path, derive this budget from root input size, or split known-length and stream root behavior. Read context/read state owns only raw byte reservation with `reserveGraphMemory(bytes)`; it must not expose counted arithmetic helpers or collection, map, array, struct, or object semantic reservation APIs. Do not add any non-memory-budget read-context/read-state API for this feature, including ref-publication controls, temporary-owner controls, serializer-owner controls, conversion helpers, or APIs that encode what kind of value is being materialized. Concrete serializers and generated serializers own counted formulas, overflow checks, allocation-owner decisions, and reference publication timing for their allocation path. Reserve self storage exactly once at the owner that stores or allocates the value: reference/object runtimes reserve parent owner self cost plus reference storage and every referenced heap owner reserves its own shallow self cost when materialized; inline/value runtimes reserve value storage in the root, struct/product, collection, map, set, array, smart-pointer, or box owner that actually owns that storage, and nested value serializers do not unconditionally reserve their own `sizeof(T)` again. Collection, map, set, and reference-array owners reserve nonzero shallow self cost when independently materialized, plus backing/reference/inline storage. Struct, record, POJO, tuple/product, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Reference fields use 4 bytes when reference size is not cheap or reliable to query; primitive/value fields use their storage width. Parents do not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive array, and primitive dense-array leaf owners. Do not guess allocator, bucket-table, node, debug, or map-entry overhead unless it is a documented lower-bound owner allocation. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. +- Root deserialization graph memory budgets estimate lower-bound shallow memory for materialized graph owners, not exact heap accounting, input byte accounting, or raw element counts. `maxGraphMemoryBytes` defaults to fixed `128 MiB`; positive values override the default; explicit non-positive values are invalid and must be rejected at config/Fory creation. Do not add a disabled-budget sentinel path, derive this budget from root input size, or split known-length and stream root behavior. Read context/read state owns only raw byte reservation with `reserveGraphMemory(bytes)`; it must not expose counted arithmetic helpers or collection, map, array, struct, or object semantic reservation APIs. Do not add any non-memory-budget read-context/read-state API for this feature, including ref-publication controls, temporary-owner controls, serializer-owner controls, conversion helpers, or APIs that encode what kind of value is being materialized. Root facades may set/reset the per-operation budget, but they must not pre-reserve root type or root self bytes; the serializer or materialization owner reserves root value storage. Concrete serializers and generated serializers own counted formulas, overflow checks, allocation-owner decisions, and reference publication timing for their allocation path. Reserve self storage exactly once at the owner that stores or allocates the value: reference/object runtimes reserve parent owner self cost plus reference storage and every referenced heap owner reserves its own shallow self cost when materialized; inline/value runtimes reserve value storage in the struct/product, collection, map, set, array, smart-pointer, box, or root materialization owner that actually owns that storage, and nested value serializers do not unconditionally reserve their own `sizeof(T)` again. Collection, map, set, and reference-array owners reserve nonzero shallow self cost when independently materialized, plus backing/reference/inline storage. Struct, record, POJO, tuple/product, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Reference fields use 4 bytes when reference size is not cheap or reliable to query; primitive/value fields use their storage width. Parents do not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive array, and primitive dense-array leaf owners. Do not guess allocator, bucket-table, node, debug, or map-entry overhead unless it is a documented lower-bound owner allocation. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. - For remote TypeDef/TypeMeta reads, the checked metadata cache is the only owner of remote "already validated" state. Cache hit means the header was previously parsed, body/hash-validated, policy-checked, and published by that cache, so the hot path must skip the body and use cached metadata without extra validation, hashing, limit checks, exact-local checks, allocation, or policy work. A known expected local TypeDef/TypeMeta header/hash match is a local-schema hit, not a remote cache miss: it may skip the body and use the local TypeInfo/TypeMeta without schema-version counting or cache publish. Cache miss is the only path that parses and validates non-local metadata, enforces limits, performs exact-local byte comparison when needed, and publishes remote metadata to the cache. Do not add nullable accepted-header fields, sentinel headers, per-TypeInfo markers, pending metadata state, parallel header-low/header-high slots, or parallel acceptance state for this decision. If a runtime needs a metadata hit hint, cache the concrete checked metadata owner object, such as the TypeInfo, TypeDef, or TypeMeta used by that runtime, and compare its validated header identity directly. - When a user corrects a non-obvious invariant, encode it in the nearest source comment before continuing, and also update `AGENTS.md`, `.agents/**`, docs, or specs when the rule is reusable beyond one file. Do not rely only on chat history, task notes, commit messages, or benchmark logs for corrections that protect security, protocol behavior, ownership, naming, or hot-path performance. - Reject semantic hacks. Do not bypass broken semantics by deleting cases, simplifying callers, adding coercion hooks, or using workaround fallbacks; fix the underlying bug and prove it with focused tests. diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index d1d28979d9..c54ca9cbb9 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -392,22 +392,15 @@ constexpr size_t collection_element_memory_bytes() { template inline bool reserve_collection_storage(ReadContext &ctx, uint32_t length) { - constexpr size_t kMaxLength = - static_cast(std::numeric_limits::max()); - if constexpr (elem_bytes <= std::numeric_limits::max() / kMaxLength) { - return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); - } else { - if (FORY_PREDICT_FALSE(elem_bytes != 0 && - static_cast(length) > - std::numeric_limits::max() / - elem_bytes)) { - ctx.set_error(Error::invalid_data( - "graph memory estimate overflows: length=" + std::to_string(length) + - " elementBytes=" + std::to_string(elem_bytes))); - return false; - } - return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); + if (FORY_PREDICT_FALSE(elem_bytes != 0 && + static_cast(length) > + std::numeric_limits::max() / elem_bytes)) { + ctx.set_error(Error::invalid_data( + "graph memory estimate overflows: length=" + std::to_string(length) + + " elementBytes=" + std::to_string(elem_bytes))); + return false; } + return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); } template @@ -1079,6 +1072,11 @@ struct Serializer< if (ctx.has_error() || !has_value) { return std::vector(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE( + (!reserve_allocated_value_owner>(ctx)))) { + return std::vector(); + } // Optional type info for polymorphic containers if (read_type) { @@ -1373,6 +1371,11 @@ template struct Serializer> { if (ctx.has_error() || !has_value) { return std::list(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE( + (!reserve_allocated_value_owner>(ctx)))) { + return std::list(); + } // Optional type info for polymorphic containers if (read_type) { @@ -1573,6 +1576,11 @@ template struct Serializer> { if (ctx.has_error() || !has_value) { return std::deque(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE( + (!reserve_allocated_value_owner>(ctx)))) { + return std::deque(); + } // Optional type info for polymorphic containers if (read_type) { @@ -1774,6 +1782,12 @@ struct Serializer> { if (ctx.has_error() || !has_value) { return std::forward_list(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE( + (!reserve_allocated_value_owner>( + ctx)))) { + return std::forward_list(); + } // Optional type info for polymorphic containers if (read_type) { @@ -2233,6 +2247,11 @@ struct Serializer> { if (ctx.has_error() || !has_value) { return std::set(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE( + (!reserve_allocated_value_owner>(ctx)))) { + return std::set(); + } // Read type info if (read_type) { @@ -2417,6 +2436,12 @@ struct Serializer> { if (ctx.has_error() || !has_value) { return std::unordered_set(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE( + (!reserve_allocated_value_owner>( + ctx)))) { + return std::unordered_set(); + } // Read type info if (read_type) { diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index ed8541895b..13e52b2eb1 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -889,15 +889,6 @@ class Fory : public BaseFory { read_ctx_->attach(buffer); read_ctx_->remaining_graph_memory_bytes_ = read_ctx_->graph_memory_limit_bytes_; - if constexpr (needs_graph_budget_v) { - constexpr size_t root_owner_bytes = graph_value_owner_self_bytes(); - if constexpr (root_owner_bytes != 0) { - if (FORY_PREDICT_FALSE( - !read_ctx_->reserve_graph_memory(root_owner_bytes))) { - return Unexpected(read_ctx_->take_error()); - } - } - } ReadContextGuard guard(*read_ctx_); return deserialize_impl(buffer); } diff --git a/cpp/fory/serialization/map_serializer.h b/cpp/fory/serialization/map_serializer.h index 44ea9bc8b9..28db52cc53 100644 --- a/cpp/fory/serialization/map_serializer.h +++ b/cpp/fory/serialization/map_serializer.h @@ -84,22 +84,15 @@ struct MapReserver inline bool reserve_map_storage(ReadContext &ctx, uint32_t length) { - constexpr size_t kMaxLength = - static_cast(std::numeric_limits::max()); - if constexpr (elem_bytes <= std::numeric_limits::max() / kMaxLength) { - return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); - } else { - if (FORY_PREDICT_FALSE(elem_bytes != 0 && - static_cast(length) > - std::numeric_limits::max() / - elem_bytes)) { - ctx.set_error(Error::invalid_data( - "graph memory estimate overflows: length=" + std::to_string(length) + - " elementBytes=" + std::to_string(elem_bytes))); - return false; - } - return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); + if (FORY_PREDICT_FALSE(elem_bytes != 0 && + static_cast(length) > + std::numeric_limits::max() / elem_bytes)) { + ctx.set_error(Error::invalid_data( + "graph memory estimate overflows: length=" + std::to_string(length) + + " elementBytes=" + std::to_string(elem_bytes))); + return false; } + return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); } template @@ -1051,6 +1044,10 @@ struct Serializer> { if (!has_value) { return MapType{}; } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return MapType{}; + } if (read_type) { uint32_t type_id_read = ctx.read_uint8(ctx.error()); @@ -1159,6 +1156,10 @@ struct Serializer> { if (!has_value) { return MapType{}; } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return MapType{}; + } if (read_type) { uint32_t type_id_read = ctx.read_uint8(ctx.error()); diff --git a/cpp/fory/serialization/smart_ptr_serializers.h b/cpp/fory/serialization/smart_ptr_serializers.h index 6796fb2e4a..91e506325a 100644 --- a/cpp/fory/serialization/smart_ptr_serializers.h +++ b/cpp/fory/serialization/smart_ptr_serializers.h @@ -511,6 +511,10 @@ template struct Serializer> { } reserved_ref_id = ctx.ref_reader().reserve_ref_id(); } + if (FORY_PREDICT_FALSE( + !ctx.reserve_graph_memory(sizeof(std::shared_ptr)))) { + return nullptr; + } // For polymorphic types, read type info AFTER handling ref flags if constexpr (is_polymorphic) { @@ -961,6 +965,10 @@ template struct Serializer> { std::to_string(static_cast(flag)))); return nullptr; } + if (FORY_PREDICT_FALSE( + !ctx.reserve_graph_memory(sizeof(std::unique_ptr)))) { + return nullptr; + } // For polymorphic types, read type info AFTER handling ref flags if constexpr (is_polymorphic) { diff --git a/cpp/fory/serialization/struct_serializer.h b/cpp/fory/serialization/struct_serializer.h index a438a75c4e..f2c3488031 100644 --- a/cpp/fory/serialization/struct_serializer.h +++ b/cpp/fory/serialization/struct_serializer.h @@ -4584,6 +4584,10 @@ struct Serializer>> { if (ctx.track_ref() && ref_flag == ref_value_flag) { ctx.ref_reader().reserve_ref_id(); } + if (ref_mode != RefMode::None && + FORY_PREDICT_FALSE(!reserve_allocated_value_owner(ctx))) { + return T{}; + } // In compatible mode: use meta sharing (matches Rust behavior) if (ctx.is_compatible()) { // In compatible mode: always use remote TypeMeta for schema evolution diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs index 9ac66fe5d8..54387c5e87 100644 --- a/csharp/src/Fory/Fory.cs +++ b/csharp/src/Fory/Fory.cs @@ -290,11 +290,6 @@ private T DeserializeFromReader(ByteReader reader) Serializer serializer = _typeResolver.GetSerializer(); ReadContext readContext = _readContext; readContext.ResetFor(reader); - if (typeof(T).IsValueType) - { - GraphMemory.ReserveRootValue(readContext); - } - T value = _trackRef ? serializer.Read(readContext, RefMode.Tracking, true) : ReadRootNoRef(serializer, readContext); @@ -322,8 +317,7 @@ private static T ReadRootNoRef(Serializer serializer, ReadContext context) RefFlag flag = (RefFlag)context.Reader.ReadInt8(); if (flag == RefFlag.NotNullValue) { - context.TypeResolver.ReadTypeInfo(serializer, context); - return serializer.ReadData(context); + return serializer.Read(context, RefMode.None, true); } if (flag == RefFlag.Null) diff --git a/csharp/src/Fory/GraphMemory.cs b/csharp/src/Fory/GraphMemory.cs index 3997388798..e1cd3820c9 100644 --- a/csharp/src/Fory/GraphMemory.cs +++ b/csharp/src/Fory/GraphMemory.cs @@ -24,16 +24,6 @@ internal static class GraphMemory [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static long ValueOwnerBytes() => ValueOwner.Bytes; - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal static void ReserveRootValue(ReadContext context) - { - long bytes = ValueOwner.Bytes; - if (bytes != 0) - { - context.ReserveGraphMemory(bytes); - } - } - private static class ValueOwner { internal static readonly long Bytes = Compute(); diff --git a/csharp/src/Fory/Serializer.cs b/csharp/src/Fory/Serializer.cs index 33e874208e..40d027a49f 100644 --- a/csharp/src/Fory/Serializer.cs +++ b/csharp/src/Fory/Serializer.cs @@ -136,6 +136,12 @@ public virtual T Read(ReadContext context, RefMode refMode, bool readTypeInfo) context.TypeResolver.ReadTypeInfo(this, context); } + long graphBytes = GraphMemory.ValueOwnerBytes(); + if (graphBytes != 0) + { + context.ReserveGraphMemory(graphBytes); + } + return ReadData(context); } } diff --git a/csharp/src/Fory/TypeInfo.cs b/csharp/src/Fory/TypeInfo.cs index 0b9a5f8630..6969eb3573 100644 --- a/csharp/src/Fory/TypeInfo.cs +++ b/csharp/src/Fory/TypeInfo.cs @@ -122,7 +122,7 @@ internal static TypeInfo Create(Type type, Serializer serializer) context => ReadDataObject(serializer, context, boxedValueBytes), (context, value, refMode, writeTypeInfo, hasGenerics) => WriteObject(serializer, context, value, refMode, writeTypeInfo, hasGenerics), - (context, refMode, readTypeInfo) => ReadObject(serializer, context, refMode, readTypeInfo, boxedValueBytes), + (context, refMode, readTypeInfo) => ReadObject(serializer, context, refMode, readTypeInfo), typeMetaFields, builtInTypeId, null); @@ -204,14 +204,8 @@ private static void WriteObject( Serializer serializer, ReadContext context, RefMode refMode, - bool readTypeInfo, - long boxedValueBytes) + bool readTypeInfo) { - if (boxedValueBytes != 0) - { - context.ReserveGraphMemory(boxedValueBytes); - } - return serializer.Read(context, refMode, readTypeInfo); } diff --git a/docs/security/deserialization.md b/docs/security/deserialization.md index 39edb7af54..0f615c9712 100644 --- a/docs/security/deserialization.md +++ b/docs/security/deserialization.md @@ -234,9 +234,10 @@ Graph budget accounting should: array, and primitive dense-array leaf owners. Each runtime must inspect the concrete owner path before choosing formulas. Reserve self storage -exactly once at the owner that stores or allocates the value. Reference-backed paths reserve parent -owner self cost plus reference storage, while each referenced heap owner reserves its own shallow -self cost when materialized. Inline/value paths reserve inline element, field, or root storage in the +exactly once at the owner that stores or allocates the value. Root facades may reset the budget, but +must not pre-reserve root type or root self bytes. Reference-backed paths reserve parent owner self +cost plus reference storage, while each referenced heap owner reserves its own shallow self cost when +materialized. Inline/value paths reserve inline element, field, or root value storage in the holder/allocation owner; nested value serializers must not charge their own self storage again. Parents must not recursively include child object, collection, map, string, binary, or primitive dense-array contents; the child owner reserves its own shallow memory when it is materialized. diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index dd0f81bc0b..bc05252720 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -399,13 +399,14 @@ the container/map serializer and should be validated only when they protect a real owner invariant. Materializing readers should also reserve a root-operation estimated graph -memory budget before allocation or size hinting. The budget belongs to -`ReadContext` or the equivalent root read state, not to serializers and not to -ambient thread-local state. `maxGraphMemoryBytes` defaults to a fixed `128 MiB`; -positive configuration overrides the default; explicit non-positive -configuration is invalid and must be rejected when the runtime is created. Do -not derive this budget from root input size, and do not add dynamic stream -bytes-read accounting for this budget. +memory budget before allocation or size hinting. The budget state belongs to +`ReadContext` or the equivalent root read state, not to ambient thread-local +state. Root facades set or reset the per-operation budget only; they must not +pre-reserve root type or root self bytes. `maxGraphMemoryBytes` defaults to a +fixed `128 MiB`; positive configuration overrides the default; explicit +non-positive configuration is invalid and must be rejected when the runtime is +created. Do not derive this budget from root input size, and do not add dynamic +stream bytes-read accounting for this budget. Read context or equivalent read state owns only raw byte reservation. It must not expose counted arithmetic helpers or collection, map, array, struct, or @@ -423,9 +424,9 @@ or allocates the value. Reference-backed containers, maps, sets, and object/reference arrays reserve nonzero owner self cost plus reference slots; each referenced heap owner then reserves its own shallow self cost when materialized. Inline/value containers reserve element storage; inline/value maps -reserve key plus value storage; root/product/box owners reserve value self -storage; and nested value serializers reserve only additional dynamic storage -they allocate. Struct/record/POJO/tuple, compatible, generated, and dynamic +reserve key plus value storage; root materialization, product, and box owners +reserve value self storage; and nested value serializers reserve only additional +dynamic storage they allocate. Struct/record/POJO/tuple, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Parents must not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and diff --git a/go/fory/fory.go b/go/fory/fory.go index d9613eeb34..e7f8f270f6 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -195,9 +195,6 @@ type Fory struct { // Resolvers shared between contexts typeResolver *TypeResolver refResolver *RefResolver - - rootGraphType reflect.Type - rootGraphBytes int64 } // New creates a new Fory instance with the given options @@ -572,15 +569,9 @@ func (f *Fory) Deserialize(data []byte, v any) error { defer f.resetReadState() f.readCtx.SetData(data) target := reflect.ValueOf(v).Elem() - targetType := target.Type() limit := f.config.MaxGraphMemoryBytes f.readCtx.graphMemoryLimitBytes = limit f.readCtx.remainingGraphMemoryBytes = limit - if bytes, ok := f.rootGraphBytesFor(targetType); ok && bytes > 0 { - if !f.readCtx.ReserveGraphMemory(bytes) { - return f.readCtx.TakeError() - } - } readHeader(f.readCtx) if f.readCtx.HasError() { @@ -589,7 +580,11 @@ func (f *Fory) Deserialize(data []byte, v any) error { // Root writes include type metadata, so keep the root ReadValue path. // Calling a cached serializer directly would read that metadata byte as payload. - f.readCtx.ReadValue(target, RefModeTracking, true) + if target.Kind() == reflect.Struct && target.Type() != dateReflectType && target.Type() != timeReflectType && target.Type() != decimalType { + f.readCtx.ReadStruct(target) + } else { + f.readCtx.ReadValue(target, RefModeTracking, true) + } if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -674,16 +669,9 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = buf target := reflect.ValueOf(v).Elem() - targetType := target.Type() limit := f.config.MaxGraphMemoryBytes f.readCtx.graphMemoryLimitBytes = limit f.readCtx.remainingGraphMemoryBytes = limit - if bytes, ok := f.rootGraphBytesFor(targetType); ok && bytes > 0 { - if !f.readCtx.ReserveGraphMemory(bytes) { - f.readCtx.buffer = origBuffer - return f.readCtx.TakeError() - } - } readHeader(f.readCtx) if f.readCtx.HasError() { @@ -692,7 +680,11 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { } // Deserialize the value - TypeMeta is read inline using streaming protocol - f.readCtx.ReadValue(target, RefModeTracking, true) + if target.Kind() == reflect.Struct && target.Type() != dateReflectType && target.Type() != timeReflectType && target.Type() != decimalType { + f.readCtx.ReadStruct(target) + } else { + f.readCtx.ReadValue(target, RefModeTracking, true) + } if f.readCtx.HasError() { f.readCtx.buffer = origBuffer return f.readCtx.TakeError() @@ -796,15 +788,9 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers } target := rv.Elem() - targetType := target.Type() limit := f.config.MaxGraphMemoryBytes f.readCtx.graphMemoryLimitBytes = limit f.readCtx.remainingGraphMemoryBytes = limit - if bytes, ok := f.rootGraphBytesFor(targetType); ok && bytes > 0 { - if !f.readCtx.ReserveGraphMemory(bytes) { - return f.readCtx.TakeError() - } - } // ReadData and validate header readHeader(f.readCtx) @@ -813,7 +799,11 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers } // Deserialize the value - TypeMeta is read inline using streaming protocol - f.readCtx.ReadValue(target, RefModeTracking, true) + if target.Kind() == reflect.Struct && target.Type() != dateReflectType && target.Type() != timeReflectType && target.Type() != decimalType { + f.readCtx.ReadStruct(target) + } else { + f.readCtx.ReadValue(target, RefModeTracking, true) + } if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -1074,11 +1064,6 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { default: targetVal = reflect.ValueOf(target).Elem() targetType = targetVal.Type() - if bytes, ok := f.rootGraphBytesFor(targetType); ok && bytes > 0 { - if !f.readCtx.ReserveGraphMemory(bytes) { - return f.readCtx.TakeError() - } - } } // ReadData and validate header @@ -1222,6 +1207,10 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { targetVal = reflect.ValueOf(target).Elem() targetType = targetVal.Type() } + if targetType.Kind() == reflect.Struct && targetType != dateReflectType && targetType != timeReflectType && targetType != decimalType { + f.readCtx.ReadStruct(targetVal) + return f.readCtx.CheckError() + } // Get serializer for the target type serializer, err := f.typeResolver.getSerializerByType(targetType, false) @@ -1234,19 +1223,3 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { return f.readCtx.CheckError() } } - -func (f *Fory) rootGraphBytesFor(targetType reflect.Type) (int64, bool) { - if targetType == nil || targetType.Kind() != reflect.Struct { - return 0, false - } - if targetType == dateReflectType || targetType == timeReflectType { - return 0, true - } - if targetType == f.rootGraphType { - return f.rootGraphBytes, true - } - bytes := structGraphBytes(targetType) - f.rootGraphType = targetType - f.rootGraphBytes = bytes - return bytes, true -} diff --git a/go/fory/pointer.go b/go/fory/pointer.go index 5180deb627..199ea54e0f 100644 --- a/go/fory/pointer.go +++ b/go/fory/pointer.go @@ -140,7 +140,7 @@ func (s *ptrToValueSerializer) ReadData(ctx *ReadContext, value reflect.Value) { var newVal reflect.Value if value.IsNil() { // Allocate new value - if !reserveStructGraph(ctx, value.Type().Elem()) { + if !ctx.ReserveGraphMemory(structGraphBytes(value.Type().Elem())) { return } newVal = reflect.New(value.Type().Elem()) @@ -198,7 +198,7 @@ func (s *ptrToValueSerializer) Read(ctx *ReadContext, refMode RefMode, readType if structSer, ok := typeInfo.Serializer.(*structSerializer); ok && len(structSer.fieldDefs) > 0 { // Allocate the pointer value if needed if value.IsNil() { - if !reserveStructGraph(ctx, value.Type().Elem()) { + if !ctx.ReserveGraphMemory(structGraphBytes(value.Type().Elem())) { return } value.Set(reflect.New(value.Type().Elem())) diff --git a/go/fory/reader.go b/go/fory/reader.go index 2821e6b42b..c6f0eca1a2 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -75,14 +75,6 @@ func structGraphBytes(type_ reflect.Type) int64 { return bytes } -func reserveStructGraph(ctx *ReadContext, type_ reflect.Type) bool { - bytes := structGraphBytes(type_) - if bytes == 0 { - return true - } - return ctx.ReserveGraphMemory(bytes) -} - // IsXlang returns whether cross-language serialization mode is enabled func (c *ReadContext) IsXlang() bool { return c.xlang @@ -901,13 +893,13 @@ func (c *ReadContext) ReadValue(value reflect.Value, refMode RefMode, readType b } else if isNamedStruct { // For named struct types, create a pointer to support circular references // Create *A instead of A - if !reserveStructGraph(c, actualType) { + if !c.ReserveGraphMemory(structGraphBytes(actualType)) { return } newValue = reflect.New(actualType) valueToSet = newValue } else { - if !reserveStructGraph(c, actualType) { + if !c.ReserveGraphMemory(structGraphBytes(actualType)) { return } newValue = reflect.New(actualType).Elem() @@ -1041,7 +1033,7 @@ func (c *ReadContext) ReadStruct(value reflect.Value) { var readTarget reflect.Value if isPtr { if value.IsNil() { - if !reserveStructGraph(c, structType) { + if !c.ReserveGraphMemory(structGraphBytes(structType)) { return } value.Set(reflect.New(structType)) @@ -1050,6 +1042,9 @@ func (c *ReadContext) ReadStruct(value reflect.Value) { // Register reference before reading (for circular references) refResolver.SetReadObject(refID, value) } else { + if !c.ReserveGraphMemory(structGraphBytes(structType)) { + return + } readTarget = value // For non-pointer structs, register a pointer to enable circular ref resolution if value.CanAddr() { diff --git a/go/fory/stream.go b/go/fory/stream.go index f2391d7152..cea27cec61 100644 --- a/go/fory/stream.go +++ b/go/fory/stream.go @@ -100,14 +100,6 @@ func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { limit := f.config.MaxGraphMemoryBytes f.readCtx.graphMemoryLimitBytes = limit f.readCtx.remainingGraphMemoryBytes = limit - if bytes, ok := f.rootGraphBytesFor(target.Type()); ok && bytes > 0 { - if !f.readCtx.ReserveGraphMemory(bytes) { - err := f.readCtx.TakeError() - f.readCtx.buffer = origBuffer - f.resetReadState() - return err - } - } defer func() { f.readCtx.buffer = origBuffer f.resetReadState() @@ -118,7 +110,11 @@ func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { return f.readCtx.TakeError() } - f.readCtx.ReadValue(target, RefModeTracking, true) + if target.Kind() == reflect.Struct && target.Type() != dateReflectType && target.Type() != timeReflectType && target.Type() != decimalType { + f.readCtx.ReadStruct(target) + } else { + f.readCtx.ReadValue(target, RefModeTracking, true) + } if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -138,18 +134,17 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { limit := f.config.MaxGraphMemoryBytes f.readCtx.graphMemoryLimitBytes = limit f.readCtx.remainingGraphMemoryBytes = limit - if bytes, ok := f.rootGraphBytesFor(target.Type()); ok && bytes > 0 { - if !f.readCtx.ReserveGraphMemory(bytes) { - return f.readCtx.TakeError() - } - } readHeader(f.readCtx) if f.readCtx.HasError() { return f.readCtx.TakeError() } - f.readCtx.ReadValue(target, RefModeTracking, true) + if target.Kind() == reflect.Struct && target.Type() != dateReflectType && target.Type() != timeReflectType && target.Type() != decimalType { + f.readCtx.ReadStruct(target) + } else { + f.readCtx.ReadValue(target, RefModeTracking, true) + } if f.readCtx.HasError() { return f.readCtx.TakeError() } diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index 8c759f6f21..c217726776 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -1146,11 +1146,6 @@ impl Fory { } else { RefMode::NullOnly }; - // TypeMeta is read inline during deserialization (streaming protocol) - let root_graph_self_size = T::fory_graph_self_size(); - if root_graph_self_size != 0 { - context.reserve_graph_memory(root_graph_self_size)?; - } let result = ::fory_read(context, ref_mode, true); context.ref_reader.resolve_callbacks(); result diff --git a/rust/fory-core/src/serializer/core.rs b/rust/fory-core/src/serializer/core.rs index 255f9c8a86..09a3d2adf8 100644 --- a/rust/fory-core/src/serializer/core.rs +++ b/rust/fory-core/src/serializer/core.rs @@ -585,6 +585,10 @@ pub trait Serializer: 'static { if read_type_info { Self::fory_read_type_info(context)?; } + let graph_self_size = Self::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } Self::fory_read_data(context) } diff --git a/rust/fory-derive/src/object/read.rs b/rust/fory-derive/src/object/read.rs index c4addc8593..9d3a20d819 100644 --- a/rust/fory-derive/src/object/read.rs +++ b/rust/fory-derive/src/object/read.rs @@ -195,6 +195,10 @@ pub fn gen_read(_struct_ident: &Ident) -> TokenStream { if ref_flag == (fory_core::RefFlag::RefValue as i8) && ref_mode == fory_core::RefMode::Tracking { context.ref_reader.reserve_ref_id(); } + let graph_self_size = ::fory_graph_self_size(); + if graph_self_size != 0 { + context.reserve_graph_memory(graph_self_size)?; + } if context.is_compatible() { let type_info = if read_type_info { context.read_any_type_info()? From 5e4304a3a624bb34d302242e12820a503a715d15 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 07:53:12 +0800 Subject: [PATCH 42/54] fix(go): keep union roots on generic read path --- go/fory/fory.go | 57 +++++++++++++++++++++++++---------------------- go/fory/stream.go | 12 ++-------- 2 files changed, 32 insertions(+), 37 deletions(-) diff --git a/go/fory/fory.go b/go/fory/fory.go index e7f8f270f6..294cd8bbef 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -578,13 +578,7 @@ func (f *Fory) Deserialize(data []byte, v any) error { return f.readCtx.TakeError() } - // Root writes include type metadata, so keep the root ReadValue path. - // Calling a cached serializer directly would read that metadata byte as payload. - if target.Kind() == reflect.Struct && target.Type() != dateReflectType && target.Type() != timeReflectType && target.Type() != decimalType { - f.readCtx.ReadStruct(target) - } else { - f.readCtx.ReadValue(target, RefModeTracking, true) - } + f.readRootValue(target) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -680,11 +674,7 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { } // Deserialize the value - TypeMeta is read inline using streaming protocol - if target.Kind() == reflect.Struct && target.Type() != dateReflectType && target.Type() != timeReflectType && target.Type() != decimalType { - f.readCtx.ReadStruct(target) - } else { - f.readCtx.ReadValue(target, RefModeTracking, true) - } + f.readRootValue(target) if f.readCtx.HasError() { f.readCtx.buffer = origBuffer return f.readCtx.TakeError() @@ -799,11 +789,7 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers } // Deserialize the value - TypeMeta is read inline using streaming protocol - if target.Kind() == reflect.Struct && target.Type() != dateReflectType && target.Type() != timeReflectType && target.Type() != decimalType { - f.readCtx.ReadStruct(target) - } else { - f.readCtx.ReadValue(target, RefModeTracking, true) - } + f.readRootValue(target) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -1207,19 +1193,36 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { targetVal = reflect.ValueOf(target).Elem() targetType = targetVal.Type() } - if targetType.Kind() == reflect.Struct && targetType != dateReflectType && targetType != timeReflectType && targetType != decimalType { + // Get serializer for the target type + if f.rootUsesReadStruct(targetType) { f.readCtx.ReadStruct(targetVal) - return f.readCtx.CheckError() - } + } else { + serializer, err := f.typeResolver.getSerializerByType(targetType, false) + if err != nil { + return fmt.Errorf("failed to get serializer for type %v: %w", targetType, err) + } - // Get serializer for the target type - serializer, err := f.typeResolver.getSerializerByType(targetType, false) - if err != nil { - return fmt.Errorf("failed to get serializer for type %v: %w", targetType, err) + // Use Read to deserialize directly into target + serializer.Read(f.readCtx, RefModeTracking, true, false, targetVal) } - - // Use Read to deserialize directly into target - serializer.Read(f.readCtx, RefModeTracking, true, false, targetVal) return f.readCtx.CheckError() } } + +func (f *Fory) readRootValue(target reflect.Value) { + if f.rootUsesReadStruct(target.Type()) { + f.readCtx.ReadStruct(target) + return + } + // Root writes include type metadata, so generic roots must keep ReadValue. + // Calling a cached serializer directly would read that metadata byte as payload. + f.readCtx.ReadValue(target, RefModeTracking, true) +} + +func (f *Fory) rootUsesReadStruct(targetType reflect.Type) bool { + return targetType.Kind() == reflect.Struct && + targetType != dateReflectType && + targetType != timeReflectType && + targetType != decimalType && + !f.typeResolver.IsUnionType(targetType) +} diff --git a/go/fory/stream.go b/go/fory/stream.go index cea27cec61..c98362ef67 100644 --- a/go/fory/stream.go +++ b/go/fory/stream.go @@ -110,11 +110,7 @@ func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { return f.readCtx.TakeError() } - if target.Kind() == reflect.Struct && target.Type() != dateReflectType && target.Type() != timeReflectType && target.Type() != decimalType { - f.readCtx.ReadStruct(target) - } else { - f.readCtx.ReadValue(target, RefModeTracking, true) - } + f.readRootValue(target) if f.readCtx.HasError() { return f.readCtx.TakeError() } @@ -140,11 +136,7 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { return f.readCtx.TakeError() } - if target.Kind() == reflect.Struct && target.Type() != dateReflectType && target.Type() != timeReflectType && target.Type() != decimalType { - f.readCtx.ReadStruct(target) - } else { - f.readCtx.ReadValue(target, RefModeTracking, true) - } + f.readRootValue(target) if f.readCtx.HasError() { return f.readCtx.TakeError() } From 80d652ee8cf270800727ebc17bbb825653ea0222 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 08:08:22 +0800 Subject: [PATCH 43/54] refactor: remove graph budget limit mirrors --- .agents/languages/cpp.md | 2 ++ .agents/languages/csharp.md | 2 ++ .agents/languages/go.md | 3 ++- .agents/languages/java.md | 3 ++- .agents/languages/python.md | 4 +++- .agents/languages/rust.md | 2 ++ AGENTS.md | 2 +- cpp/fory/serialization/context.cc | 8 +------- cpp/fory/serialization/context.h | 2 -- cpp/fory/serialization/fory.h | 2 +- csharp/src/Fory/Fory.cs | 12 +++--------- csharp/src/Fory/ReadContext.cs | 4 +--- .../tests/Fory.Tests/GraphMemoryBudgetTests.cs | 1 - docs/specification/xlang_implementation_guide.md | 4 ++++ go/fory/fory.go | 16 ++++------------ go/fory/graph_memory_budget_test.go | 8 +++----- go/fory/reader.go | 6 ++---- go/fory/stream.go | 8 ++------ .../org/apache/fory/context/ReadContext.java | 9 ++------- python/pyfory/context.pxi | 8 ++------ python/pyfory/context.py | 8 ++------ python/pyfory/serialization.pyx | 1 - python/pyfory/tests/test_graph_memory_budget.py | 4 ++-- rust/fory-core/src/context.rs | 6 ++---- rust/fory-core/src/fory.rs | 2 -- 25 files changed, 45 insertions(+), 82 deletions(-) diff --git a/.agents/languages/cpp.md b/.agents/languages/cpp.md index 8ebc82ae70..d1047934b2 100644 --- a/.agents/languages/cpp.md +++ b/.agents/languages/cpp.md @@ -23,6 +23,8 @@ Load this file when changing `cpp/`, Cython build plumbing, or C++ xlang behavio values are invalid at config creation. Byte and stream roots use the same configured/default budget behavior. Root `Fory` overloads reset the budget only; they must not pre-reserve root type or root self bytes. + Do not mirror the configured max into a second active-limit field; use config plus mutable + remaining budget. Reserve estimated shallow graph-owner memory before allocation while preserving existing byte-availability checks and their non-empty metadata ordering. `ReadContext` may expose only raw byte reservation; collection, map, array, struct, and object diff --git a/.agents/languages/csharp.md b/.agents/languages/csharp.md index 3c8b4a14dd..09924e7637 100644 --- a/.agents/languages/csharp.md +++ b/.agents/languages/csharp.md @@ -15,6 +15,8 @@ Load this file when changing `csharp/` or C# xlang behavior. - Root deserialization graph memory budget state belongs to `ReadContext`. C# public roots are memory-backed today, but the graph budget uses the same fixed default for every root shape. Root APIs reset the budget only; they must not pre-reserve root type or root self bytes. + Do not mirror the configured max into a second active-limit field; use config plus mutable + remaining budget. `ReadContext` may expose only raw byte reservation; concrete serializers and generated serializers must compute list, array, map, struct, and object byte formulas before calling it. - `ReadContext` must not expose ref-publication pause/resume APIs or any non-budget owner diff --git a/.agents/languages/go.md b/.agents/languages/go.md index c0c00716c5..bad052077a 100644 --- a/.agents/languages/go.md +++ b/.agents/languages/go.md @@ -12,7 +12,8 @@ Load this file when changing `go/fory/` or Go xlang behavior. values override it, and explicit non-positive values are invalid at config creation. Byte-slice and stream roots use the same configured/default budget behavior. Root APIs reset the budget only; they must not pre-reserve - root type or root self bytes. `ReadContext` may expose only raw byte + root type or root self bytes. Do not mirror the configured max into a second active-limit field; + root setup should update only the mutable remaining budget. `ReadContext` may expose only raw byte reservation; slice, map, array, struct, and object formulas belong in handwritten or generated serializer owners. Reserve Go slices as `len * elemBytes`, maps as `len * (keyBytes + valueBytes)`, diff --git a/.agents/languages/java.md b/.agents/languages/java.md index 9c9f0f3a30..7021dc18be 100644 --- a/.agents/languages/java.md +++ b/.agents/languages/java.md @@ -20,7 +20,8 @@ Load this file when changing anything under `java/` or when Java drives a cross- explicit non-positive values are invalid and must be rejected at config creation. Byte-array, memory-buffer, and stream roots use the same configured/default budget behavior. Root APIs reset the budget only; they must not pre-reserve - root type or root self bytes. `ReadContext` + root type or root self bytes. Do not mirror the configured max into a second + active-limit field; use config plus mutable remaining budget. `ReadContext` may expose only raw byte reservation; collection, map, array, struct, and object formulas belong in the concrete serializer or generated serializer owner. Java collection, map, and diff --git a/.agents/languages/python.md b/.agents/languages/python.md index 81222297f8..5e8005c9b2 100644 --- a/.agents/languages/python.md +++ b/.agents/languages/python.md @@ -17,7 +17,9 @@ Load this file when changing `python/`, Cython serialization, or Python xlang be Keep `max_graph_memory_bytes` public on `pyfory.Fory`/`Config`; the default effective limit is fixed `128 MiB`, positive explicit values override it, and explicit non-positive values are invalid at config creation. Byte and stream roots use the same - configured/default budget behavior. `ReadContext` may expose only raw + configured/default budget behavior. Do not mirror the configured max into a + second active-limit field; keep one configured max plus mutable remaining + budget. `ReadContext` may expose only raw byte reservation; collection, dict, array, struct, and object formulas belong in the pure-Python or Cython serializer owner. Lists, tuples, sets, and object-dtype ndarray item storage reserve nonzero owner self cost plus `count * PyObject*`; dicts diff --git a/.agents/languages/rust.md b/.agents/languages/rust.md index 29b14032a5..ea57f19624 100644 --- a/.agents/languages/rust.md +++ b/.agents/languages/rust.md @@ -24,6 +24,8 @@ Load this file when changing `rust/` or Rust xlang behavior. creation. Root `Fory` read methods reset the budget only; they must not pre-reserve root type or root self bytes. Do not derive the budget from root input size or add dynamic bytes-read accounting. + Do not mirror the configured max into a second active-limit field; keep one configured max plus + mutable remaining budget. `ReadContext` may expose only raw byte reservation; `Vec`, collection, map, array, struct, object, and derive codec formulas belong in their serializer owners. diff --git a/AGENTS.md b/AGENTS.md index a6e1d9a915..421ea8c81f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -33,7 +33,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Check the spec before implementation. For wire behavior and xlang mapping, use the specs as the source of truth and never copy one runtime's bug into another runtime just to make tests pass. - Do not make assumptions about runtime behavior, ownership, registration, metadata construction, protocol semantics, or test coverage. Read the current code, owning docs/specs, and relevant tests before making a design judgment or implementation decision. If the evidence is incomplete, inspect more or state the uncertainty explicitly instead of filling gaps from memory or analogy with another runtime. - For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve backing/output capacity from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. Root graph memory reservation is accounting only and may happen before that byte check, but it must not replace the byte check. -- Root deserialization graph memory budgets estimate lower-bound shallow memory for materialized graph owners, not exact heap accounting, input byte accounting, or raw element counts. `maxGraphMemoryBytes` defaults to fixed `128 MiB`; positive values override the default; explicit non-positive values are invalid and must be rejected at config/Fory creation. Do not add a disabled-budget sentinel path, derive this budget from root input size, or split known-length and stream root behavior. Read context/read state owns only raw byte reservation with `reserveGraphMemory(bytes)`; it must not expose counted arithmetic helpers or collection, map, array, struct, or object semantic reservation APIs. Do not add any non-memory-budget read-context/read-state API for this feature, including ref-publication controls, temporary-owner controls, serializer-owner controls, conversion helpers, or APIs that encode what kind of value is being materialized. Root facades may set/reset the per-operation budget, but they must not pre-reserve root type or root self bytes; the serializer or materialization owner reserves root value storage. Concrete serializers and generated serializers own counted formulas, overflow checks, allocation-owner decisions, and reference publication timing for their allocation path. Reserve self storage exactly once at the owner that stores or allocates the value: reference/object runtimes reserve parent owner self cost plus reference storage and every referenced heap owner reserves its own shallow self cost when materialized; inline/value runtimes reserve value storage in the struct/product, collection, map, set, array, smart-pointer, box, or root materialization owner that actually owns that storage, and nested value serializers do not unconditionally reserve their own `sizeof(T)` again. Collection, map, set, and reference-array owners reserve nonzero shallow self cost when independently materialized, plus backing/reference/inline storage. Struct, record, POJO, tuple/product, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Reference fields use 4 bytes when reference size is not cheap or reliable to query; primitive/value fields use their storage width. Parents do not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive array, and primitive dense-array leaf owners. Do not guess allocator, bucket-table, node, debug, or map-entry overhead unless it is a documented lower-bound owner allocation. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. +- Root deserialization graph memory budgets estimate lower-bound shallow memory for materialized graph owners, not exact heap accounting, input byte accounting, or raw element counts. `maxGraphMemoryBytes` defaults to fixed `128 MiB`; positive values override the default; explicit non-positive values are invalid and must be rejected at config/Fory creation. Do not add a disabled-budget sentinel path, derive this budget from root input size, or split known-length and stream root behavior. Read context/read state owns only raw byte reservation with `reserveGraphMemory(bytes)`; it must not expose counted arithmetic helpers or collection, map, array, struct, or object semantic reservation APIs. Do not add any non-memory-budget read-context/read-state API for this feature, including ref-publication controls, temporary-owner controls, serializer-owner controls, conversion helpers, or APIs that encode what kind of value is being materialized. Root facades may set/reset the per-operation budget, but they must not pre-reserve root type or root self bytes; the serializer or materialization owner reserves root value storage. Because the budget is fixed per root, read state must not mirror the configured max into a second active-limit field; use the existing config or one configured max field plus the mutable remaining budget. Concrete serializers and generated serializers own counted formulas, overflow checks, allocation-owner decisions, and reference publication timing for their allocation path. Reserve self storage exactly once at the owner that stores or allocates the value: reference/object runtimes reserve parent owner self cost plus reference storage and every referenced heap owner reserves its own shallow self cost when materialized; inline/value runtimes reserve value storage in the struct/product, collection, map, set, array, smart-pointer, box, or root materialization owner that actually owns that storage, and nested value serializers do not unconditionally reserve their own `sizeof(T)` again. Collection, map, set, and reference-array owners reserve nonzero shallow self cost when independently materialized, plus backing/reference/inline storage. Struct, record, POJO, tuple/product, compatible, generated, and dynamic object owners reserve a nonzero shallow self cost plus shallow field storage. Reference fields use 4 bytes when reference size is not cheap or reliable to query; primitive/value fields use their storage width. Parents do not recursively include child object, collection, map, string, binary, or primitive dense-array contents. Skip enum/union as separate owners and skip dedicated string, binary, primitive scalar, primitive array, and primitive dense-array leaf owners. Do not guess allocator, bucket-table, node, debug, or map-entry overhead unless it is a documented lower-bound owner allocation. Do not add dynamic stream bytes-read accounting or nested hot-path cleanup just for this budget. - For remote TypeDef/TypeMeta reads, the checked metadata cache is the only owner of remote "already validated" state. Cache hit means the header was previously parsed, body/hash-validated, policy-checked, and published by that cache, so the hot path must skip the body and use cached metadata without extra validation, hashing, limit checks, exact-local checks, allocation, or policy work. A known expected local TypeDef/TypeMeta header/hash match is a local-schema hit, not a remote cache miss: it may skip the body and use the local TypeInfo/TypeMeta without schema-version counting or cache publish. Cache miss is the only path that parses and validates non-local metadata, enforces limits, performs exact-local byte comparison when needed, and publishes remote metadata to the cache. Do not add nullable accepted-header fields, sentinel headers, per-TypeInfo markers, pending metadata state, parallel header-low/header-high slots, or parallel acceptance state for this decision. If a runtime needs a metadata hit hint, cache the concrete checked metadata owner object, such as the TypeInfo, TypeDef, or TypeMeta used by that runtime, and compare its validated header identity directly. - When a user corrects a non-obvious invariant, encode it in the nearest source comment before continuing, and also update `AGENTS.md`, `.agents/**`, docs, or specs when the rule is reusable beyond one file. Do not rely only on chat history, task notes, commit messages, or benchmark logs for corrections that protect security, protocol behavior, ownership, naming, or hot-path performance. - Reject semantic hacks. Do not bypass broken semantics by deleting cases, simplifying callers, adding coercion hooks, or using workaround fallbacks; fix the underlying bug and prove it with focused tests. diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index 8e27c80491..c2837fc84b 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -437,9 +437,8 @@ ReadContext::ReadContext(const Config &config, type_resolver_(std::move(type_resolver)), current_dyn_depth_(0) { FORY_CHECK(config.max_graph_memory_bytes > 0) << "max_graph_memory_bytes must be positive"; - graph_memory_limit_bytes_ = + remaining_graph_memory_bytes_ = static_cast(config.max_graph_memory_bytes); - remaining_graph_memory_bytes_ = graph_memory_limit_bytes_; } ReadContext::~ReadContext() = default; @@ -745,11 +744,6 @@ const TypeInfo *ReadContext::read_any_type_info(Error &error) { return result.value(); } -bool ReadContext::set_graph_memory_limit_error(const std::string &message) { - set_error(Error::invalid_data(message)); - return false; -} - bool ReadContext::set_graph_memory_exceeded(size_t bytes, size_t remaining) { set_error(Error::invalid_data( "estimated graph memory request " + std::to_string(bytes) + diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index 6555481a65..3ddbf8b4af 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -674,7 +674,6 @@ class ReadContext { FORY_NOINLINE Result check_remote_type_meta_limit(const TypeMeta &type_meta); void record_remote_type_meta(const std::string &type_key); - FORY_NOINLINE bool set_graph_memory_limit_error(const std::string &message); FORY_NOINLINE bool set_graph_memory_exceeded(size_t bytes, size_t remaining); // Error state - accumulated during deserialization, checked at the end @@ -685,7 +684,6 @@ class ReadContext { std::unique_ptr type_resolver_; RefReader ref_reader_; uint32_t current_dyn_depth_; - size_t graph_memory_limit_bytes_ = 0; size_t remaining_graph_memory_bytes_ = 0; // Meta sharing state (for compatible mode) diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index 13e52b2eb1..527a385e32 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -888,7 +888,7 @@ class Fory : public BaseFory { read_ctx_->attach(buffer); read_ctx_->remaining_graph_memory_bytes_ = - read_ctx_->graph_memory_limit_bytes_; + static_cast(read_ctx_->config_->max_graph_memory_bytes); ReadContextGuard guard(*read_ctx_); return deserialize_impl(buffer); } diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs index 54387c5e87..e2698d8b0e 100644 --- a/csharp/src/Fory/Fory.cs +++ b/csharp/src/Fory/Fory.cs @@ -192,9 +192,7 @@ public T Deserialize(ReadOnlySpan payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); - long graphLimit = Config.MaxGraphMemoryBytes; - _readContext._graphMemoryLimitBytes = graphLimit; - _readContext._remainingGraphMemoryBytes = graphLimit; + _readContext._remainingGraphMemoryBytes = Config.MaxGraphMemoryBytes; T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -215,9 +213,7 @@ public T Deserialize(byte[] payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); - long graphLimit = Config.MaxGraphMemoryBytes; - _readContext._graphMemoryLimitBytes = graphLimit; - _readContext._remainingGraphMemoryBytes = graphLimit; + _readContext._remainingGraphMemoryBytes = Config.MaxGraphMemoryBytes; T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -238,9 +234,7 @@ public T Deserialize(ref ReadOnlySequence payload) byte[] bytes = payload.ToArray(); ByteReader reader = _readContext.Reader; reader.Reset(bytes); - long graphLimit = Config.MaxGraphMemoryBytes; - _readContext._graphMemoryLimitBytes = graphLimit; - _readContext._remainingGraphMemoryBytes = graphLimit; + _readContext._remainingGraphMemoryBytes = Config.MaxGraphMemoryBytes; T value = DeserializeFromReader(reader); payload = payload.Slice(reader.Cursor); return value; diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index bd0283c9a4..05d1bec5ca 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -45,7 +45,6 @@ public sealed class ReadContext private readonly Dictionary _remoteSchemaVersionsByType = []; private readonly Config _config; private int _totalAcceptedSchemaVersions; - internal long _graphMemoryLimitBytes; internal long _remainingGraphMemoryBytes; public ReadContext( @@ -63,7 +62,6 @@ public ReadContext( RefReader = new RefReader(); _maxDynamicReadDepth = config.MaxDepth; _config = config; - _graphMemoryLimitBytes = config.MaxGraphMemoryBytes; _remainingGraphMemoryBytes = config.MaxGraphMemoryBytes; } @@ -108,7 +106,7 @@ private void ReserveGraphMemorySlow(long bytes, long remaining) } throw new InvalidDataException( - $"estimated graph memory request {bytes} bytes exceeds MaxGraphMemoryBytes remaining budget {remaining} bytes out of effective limit {_graphMemoryLimitBytes} bytes"); + $"estimated graph memory request {bytes} bytes exceeds MaxGraphMemoryBytes remaining budget {remaining} bytes out of effective limit {_config.MaxGraphMemoryBytes} bytes"); } internal void ResetFor(ByteReader reader) diff --git a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs index 05061ed855..fb9773795b 100644 --- a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs +++ b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs @@ -149,7 +149,6 @@ public void DefaultFixedBudgetAndValidation() Assert.Throws(() => NewFory(-2)); ReadContext context = new(new ByteReader([]), new TypeResolver(), NewFory().Config); - context._graphMemoryLimitBytes = DefaultGraphMemoryBytes; context._remainingGraphMemoryBytes = DefaultGraphMemoryBytes; context.ReserveGraphMemory(DefaultGraphMemoryBytes); Assert.Throws(() => context.ReserveGraphMemory(ReferenceBytes)); diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index bc05252720..5fb5e299aa 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -407,6 +407,10 @@ fixed `128 MiB`; positive configuration overrides the default; explicit non-positive configuration is invalid and must be rejected when the runtime is created. Do not derive this budget from root input size, and do not add dynamic stream bytes-read accounting for this budget. +Because the budget is fixed per root, read state should not mirror the +configured maximum into a second active-limit field. Use the existing +configuration, or one configured maximum field when the config is not otherwise +available, plus the mutable remaining budget. Read context or equivalent read state owns only raw byte reservation. It must not expose counted arithmetic helpers or collection, map, array, struct, or diff --git a/go/fory/fory.go b/go/fory/fory.go index 294cd8bbef..e0986bbebb 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -569,9 +569,7 @@ func (f *Fory) Deserialize(data []byte, v any) error { defer f.resetReadState() f.readCtx.SetData(data) target := reflect.ValueOf(v).Elem() - limit := f.config.MaxGraphMemoryBytes - f.readCtx.graphMemoryLimitBytes = limit - f.readCtx.remainingGraphMemoryBytes = limit + f.readCtx.remainingGraphMemoryBytes = f.config.MaxGraphMemoryBytes readHeader(f.readCtx) if f.readCtx.HasError() { @@ -663,9 +661,7 @@ func (f *Fory) DeserializeFrom(buf *ByteBuffer, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = buf target := reflect.ValueOf(v).Elem() - limit := f.config.MaxGraphMemoryBytes - f.readCtx.graphMemoryLimitBytes = limit - f.readCtx.remainingGraphMemoryBytes = limit + f.readCtx.remainingGraphMemoryBytes = f.config.MaxGraphMemoryBytes readHeader(f.readCtx) if f.readCtx.HasError() { @@ -778,9 +774,7 @@ func (f *Fory) DeserializeWithCallbackBuffers(buffer *ByteBuffer, v any, buffers } target := rv.Elem() - limit := f.config.MaxGraphMemoryBytes - f.readCtx.graphMemoryLimitBytes = limit - f.readCtx.remainingGraphMemoryBytes = limit + f.readCtx.remainingGraphMemoryBytes = f.config.MaxGraphMemoryBytes // ReadData and validate header readHeader(f.readCtx) @@ -1038,9 +1032,7 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { // Reuse context, reset and set new data f.readCtx.Reset() f.readCtx.SetData(data) - limit := f.config.MaxGraphMemoryBytes - f.readCtx.graphMemoryLimitBytes = limit - f.readCtx.remainingGraphMemoryBytes = limit + f.readCtx.remainingGraphMemoryBytes = f.config.MaxGraphMemoryBytes var targetVal reflect.Value var targetType reflect.Type diff --git a/go/fory/graph_memory_budget_test.go b/go/fory/graph_memory_budget_test.go index 57536abc9b..f75546cf8e 100644 --- a/go/fory/graph_memory_budget_test.go +++ b/go/fory/graph_memory_budget_test.go @@ -52,17 +52,15 @@ func TestGraphMemoryBudgetConfig(t *testing.T) { func TestGraphMemoryBudgetFixedDefault(t *testing.T) { ctx := NewReadContext(false) - ctx.graphMemoryLimitBytes = 128 * 1024 * 1024 ctx.remainingGraphMemoryBytes = 128 * 1024 * 1024 - require.Equal(t, int64(128*1024*1024), ctx.graphMemoryLimitBytes) - require.True(t, ctx.ReserveGraphMemory(ctx.graphMemoryLimitBytes)) + require.Equal(t, int64(128*1024*1024), ctx.remainingGraphMemoryBytes) + require.True(t, ctx.ReserveGraphMemory(ctx.remainingGraphMemoryBytes)) require.False(t, ctx.ReserveGraphMemory(1)) require.Contains(t, ctx.CheckError().Error(), "maxGraphMemoryBytes") ctx = NewReadContext(false) - ctx.graphMemoryLimitBytes = 77 ctx.remainingGraphMemoryBytes = 77 - require.Equal(t, int64(77), ctx.graphMemoryLimitBytes) + require.Equal(t, int64(77), ctx.remainingGraphMemoryBytes) } func TestGraphBudgetRootKinds(t *testing.T) { diff --git a/go/fory/reader.go b/go/fory/reader.go index c6f0eca1a2..f550e427d3 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -44,7 +44,6 @@ type ReadContext struct { err Error // Accumulated error state for deferred checking lastTypePtr uintptr lastTypeInfo *TypeInfo - graphMemoryLimitBytes int64 remainingGraphMemoryBytes int64 } @@ -87,7 +86,6 @@ func NewReadContext(trackRef bool) *ReadContext { refReader: NewRefReader(trackRef), trackRef: trackRef, maxDepth: 128, // Default maximum nesting depth - graphMemoryLimitBytes: 128 * 1024 * 1024, remainingGraphMemoryBytes: 128 * 1024 * 1024, } } @@ -130,8 +128,8 @@ func (c *ReadContext) rejectGraphMemoryBytes(bytes int64) bool { //go:noinline func (c *ReadContext) rejectGraphMemoryExceeded(bytes int64, remaining int64) bool { c.SetError(DeserializationErrorf( - "estimated graph memory request %d bytes exceeds maxGraphMemoryBytes remaining budget %d bytes out of effective limit %d bytes", - bytes, remaining, c.graphMemoryLimitBytes)) + "estimated graph memory request %d bytes exceeds maxGraphMemoryBytes remaining budget %d bytes", + bytes, remaining)) return false } diff --git a/go/fory/stream.go b/go/fory/stream.go index c98362ef67..6c6c91e306 100644 --- a/go/fory/stream.go +++ b/go/fory/stream.go @@ -97,9 +97,7 @@ func (f *Fory) DeserializeFromStream(is *InputStream, v any) error { origBuffer := f.readCtx.buffer f.readCtx.buffer = is.buffer target := reflect.ValueOf(v).Elem() - limit := f.config.MaxGraphMemoryBytes - f.readCtx.graphMemoryLimitBytes = limit - f.readCtx.remainingGraphMemoryBytes = limit + f.readCtx.remainingGraphMemoryBytes = f.config.MaxGraphMemoryBytes defer func() { f.readCtx.buffer = origBuffer f.resetReadState() @@ -127,9 +125,7 @@ func (f *Fory) DeserializeFromReader(r io.Reader, v any) error { // Always reset to enforce stateless semantics. f.readCtx.buffer.ResetWithReader(r, 0) target := reflect.ValueOf(v).Elem() - limit := f.config.MaxGraphMemoryBytes - f.readCtx.graphMemoryLimitBytes = limit - f.readCtx.remainingGraphMemoryBytes = limit + f.readCtx.remainingGraphMemoryBytes = f.config.MaxGraphMemoryBytes readHeader(f.readCtx) if f.readCtx.HasError() { diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index ff30a0950a..d3a1f0f0d9 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -63,7 +63,6 @@ public final class ReadContext { private final boolean compressInt; private final Int64Encoding longEncoding; private final int maxDepth; - private final long maxGraphMemoryBytes; private final boolean scopedMetaShareEnabled; private final boolean forVirtualThread; private final IdentityHashMap contextObjects = new IdentityHashMap<>(); @@ -72,7 +71,6 @@ public final class ReadContext { private MetaReadContext metaReadContext; private boolean peerOutOfBandEnabled; private int depth; - private long graphMemoryLimitBytes; private long remainingGraphMemoryBytes; /** @@ -99,7 +97,6 @@ public ReadContext( compressInt = config.compressInt(); longEncoding = config.longEncoding(); maxDepth = config.maxDepth(); - maxGraphMemoryBytes = config.maxGraphMemoryBytes(); forVirtualThread = config.forVirtualThread(); scopedMetaShareEnabled = config.isScopedMetaShareEnabled(); if (scopedMetaShareEnabled) { @@ -116,8 +113,7 @@ public void prepare( this.buffer = buffer; this.peerOutOfBandEnabled = peerOutOfBandEnabled; this.outOfBandBuffers = outOfBandBuffers == null ? null : outOfBandBuffers.iterator(); - graphMemoryLimitBytes = maxGraphMemoryBytes; - remainingGraphMemoryBytes = maxGraphMemoryBytes; + remainingGraphMemoryBytes = config.maxGraphMemoryBytes(); } /** @@ -313,7 +309,6 @@ public void reset() { outOfBandBuffers = null; peerOutOfBandEnabled = false; depth = 0; - graphMemoryLimitBytes = 0; remainingGraphMemoryBytes = 0; } @@ -345,7 +340,7 @@ private void throwGraphMemoryExceeded(long bytes, long remaining) { + " bytes exceeds maxGraphMemoryBytes remaining budget " + remaining + " bytes out of effective limit " - + graphMemoryLimitBytes + + config.maxGraphMemoryBytes() + " bytes. If the data is trusted, increase ForyBuilder#withMaxGraphMemoryBytes."); } diff --git a/python/pyfory/context.pxi b/python/pyfory/context.pxi index 03076f093e..e53bc20175 100644 --- a/python/pyfory/context.pxi +++ b/python/pyfory/context.pxi @@ -748,7 +748,6 @@ cdef class ReadContext: cdef readonly object policy cdef readonly int32_t max_depth cdef public int64_t max_graph_memory_bytes - cdef public int64_t graph_memory_limit_bytes cdef public int64_t remaining_graph_memory_bytes cdef readonly RefReader ref_reader cdef readonly MetaStringReader meta_string_reader @@ -771,7 +770,6 @@ cdef class ReadContext: self.policy = config.policy self.max_depth = config.max_depth self.max_graph_memory_bytes = config.max_graph_memory_bytes - self.graph_memory_limit_bytes = 0 self.remaining_graph_memory_bytes = 0 self.ref_reader = RefReader(self.track_ref) self.meta_string_reader = MetaStringReader(self.type_resolver.shared_registry) @@ -796,7 +794,6 @@ cdef class ReadContext: self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled - self.graph_memory_limit_bytes = self.max_graph_memory_bytes self.remaining_graph_memory_bytes = self.max_graph_memory_bytes self.depth = 0 @@ -812,7 +809,6 @@ cdef class ReadContext: self.buffers = None self.unsupported_objects = None self.peer_out_of_band_enabled = False - self.graph_memory_limit_bytes = 0 self.remaining_graph_memory_bytes = 0 self.depth = 0 @@ -823,10 +819,10 @@ cdef class ReadContext: if num_bytes > _MAX_GRAPH_MEMORY_BYTES: raise ValueError("Estimated graph memory overflow") if num_bytes > self.remaining_graph_memory_bytes: - used = self.graph_memory_limit_bytes - self.remaining_graph_memory_bytes + used = self.max_graph_memory_bytes - self.remaining_graph_memory_bytes raise ValueError( f"Estimated graph memory budget exceeded: requested {num_bytes} bytes, " - f"used {used} bytes, limit {self.graph_memory_limit_bytes} bytes. " + f"used {used} bytes, limit {self.max_graph_memory_bytes} bytes. " "Increase Fory(..., max_graph_memory_bytes=...) for trusted larger payloads." ) self.remaining_graph_memory_bytes -= num_bytes diff --git a/python/pyfory/context.py b/python/pyfory/context.py index edf30e31d6..35fa4beeb9 100644 --- a/python/pyfory/context.py +++ b/python/pyfory/context.py @@ -472,7 +472,6 @@ class ReadContext: "policy", "max_depth", "max_graph_memory_bytes", - "graph_memory_limit_bytes", "remaining_graph_memory_bytes", "ref_reader", "meta_string_reader", @@ -495,7 +494,6 @@ def __init__(self, config: Config, type_resolver): self.policy = config.policy self.max_depth = config.max_depth self.max_graph_memory_bytes = config.max_graph_memory_bytes - self.graph_memory_limit_bytes = 0 self.remaining_graph_memory_bytes = 0 self.ref_reader = MapRefReader() if self.track_ref else NoRefReader() self.meta_string_reader = MetaStringReader(type_resolver.shared_registry) @@ -531,7 +529,6 @@ def prepare( self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled - self.graph_memory_limit_bytes = self.max_graph_memory_bytes self.remaining_graph_memory_bytes = self.max_graph_memory_bytes self.depth = 0 @@ -546,7 +543,6 @@ def reset(self): self.buffers = None self.unsupported_objects = None self.peer_out_of_band_enabled = False - self.graph_memory_limit_bytes = 0 self.remaining_graph_memory_bytes = 0 self.depth = 0 @@ -557,10 +553,10 @@ def reserve_graph_memory(self, num_bytes): raise ValueError("Estimated graph memory overflow") remaining = self.remaining_graph_memory_bytes if num_bytes > remaining: - used = self.graph_memory_limit_bytes - remaining + used = self.max_graph_memory_bytes - remaining raise ValueError( f"Estimated graph memory budget exceeded: requested {num_bytes} bytes, " - f"used {used} bytes, limit {self.graph_memory_limit_bytes} bytes. " + f"used {used} bytes, limit {self.max_graph_memory_bytes} bytes. " "Increase Fory(..., max_graph_memory_bytes=...) for trusted larger payloads." ) self.remaining_graph_memory_bytes = remaining - num_bytes diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 9c1fe10643..6f00480793 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -1100,7 +1100,6 @@ cdef class Fory: iter(unsupported_objects) if unsupported_objects is not None else None ) read_context.peer_out_of_band_enabled = peer_out_of_band_enabled - read_context.graph_memory_limit_bytes = self.max_graph_memory_bytes read_context.remaining_graph_memory_bytes = self.max_graph_memory_bytes read_context.depth = 0 return read_context.read_ref() diff --git a/python/pyfory/tests/test_graph_memory_budget.py b/python/pyfory/tests/test_graph_memory_budget.py index aa83a406e5..1b45aeb7e9 100644 --- a/python/pyfory/tests/test_graph_memory_budget.py +++ b/python/pyfory/tests/test_graph_memory_budget.py @@ -130,7 +130,7 @@ def test_fixed_default_budget(): fory = new_fory(xlang=False) try: fory.read_context.prepare(Buffer(b"x" * 17)) - assert fory.read_context.graph_memory_limit_bytes == DEFAULT_GRAPH_MEMORY_BYTES + assert fory.read_context.remaining_graph_memory_bytes == DEFAULT_GRAPH_MEMORY_BYTES fory.read_context.reserve_graph_memory(DEFAULT_GRAPH_MEMORY_BYTES) with pytest.raises(ValueError, match="Estimated graph memory budget exceeded"): fory.read_context.reserve_graph_memory(1) @@ -143,7 +143,7 @@ def test_stream_default_budget(): try: buffer = Buffer.from_stream(OneByteStream(b"streamed")) fory.read_context.prepare(buffer) - assert fory.read_context.graph_memory_limit_bytes == DEFAULT_GRAPH_MEMORY_BYTES + assert fory.read_context.remaining_graph_memory_bytes == DEFAULT_GRAPH_MEMORY_BYTES finally: fory.reset_read() diff --git a/rust/fory-core/src/context.rs b/rust/fory-core/src/context.rs index 13c48a3e47..06a67a9a5c 100644 --- a/rust/fory-core/src/context.rs +++ b/rust/fory-core/src/context.rs @@ -360,7 +360,6 @@ pub struct ReadContext<'a> { check_struct_version: bool, check_string_read: bool, pub(crate) max_graph_memory_bytes: i64, - pub(crate) graph_memory_limit_bytes: usize, pub(crate) remaining_graph_memory_bytes: usize, // Context-specific fields @@ -392,7 +391,6 @@ impl<'a> ReadContext<'a> { check_struct_version: config.check_struct_version, check_string_read: config.check_string_read, max_graph_memory_bytes: config.max_graph_memory_bytes, - graph_memory_limit_bytes: 0, remaining_graph_memory_bytes: 0, reader: Reader::default(), meta_resolver: MetaReaderResolver::default(), @@ -457,7 +455,7 @@ impl<'a> ReadContext<'a> { return Err(graph_memory_exceeded( bytes, remaining, - self.graph_memory_limit_bytes, + self.max_graph_memory_bytes, )); } self.remaining_graph_memory_bytes = remaining - bytes; @@ -576,7 +574,7 @@ impl<'a> ReadContext<'a> { #[cold] #[inline(never)] -fn graph_memory_exceeded(bytes: usize, remaining: usize, limit: usize) -> Error { +fn graph_memory_exceeded(bytes: usize, remaining: usize, limit: i64) -> Error { Error::invalid_data(format!( "estimated graph memory request {} bytes exceeds max_graph_memory_bytes remaining budget {} bytes out of effective limit {} bytes", bytes, remaining, limit diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index c217726776..3c847b0229 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -1005,7 +1005,6 @@ impl Fory { .map_err(|_| Error::invalid_data("max_graph_memory_bytes does not fit usize")) { Ok(limit) => { - context.graph_memory_limit_bytes = limit; context.remaining_graph_memory_bytes = limit; self.deserialize_with_context(context) } @@ -1080,7 +1079,6 @@ impl Fory { .map_err(|_| Error::invalid_data("max_graph_memory_bytes does not fit usize")) { Ok(limit) => { - context.graph_memory_limit_bytes = limit; context.remaining_graph_memory_bytes = limit; self.deserialize_with_context(context) } From c2c4dec29336bf7d0314149dd7963125f5186f8d Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 08:40:22 +0800 Subject: [PATCH 44/54] refactor: clean graph budget read-state drift --- .../serialization/collection_serializer.h | 134 ++---------------- cpp/fory/serialization/context.cc | 2 - .../serialization/graph_memory_budget_test.cc | 21 ++- csharp/src/Fory/ReadContext.cs | 1 - .../Fory.Tests/GraphMemoryBudgetTests.cs | 13 +- docs/guide/cpp/configuration.md | 2 +- go/fory/graph_memory_budget_test.go | 19 ++- go/fory/reader.go | 9 +- go/fory/struct_test.go | 8 +- python/pyfory/context.pxi | 4 +- python/pyfory/context.py | 20 +-- .../pyfory/tests/test_graph_memory_budget.py | 20 +-- rust/fory-core/src/context.rs | 4 +- rust/fory-core/src/fory.rs | 4 +- swift/Sources/Fory/Fory.swift | 4 +- swift/Sources/Fory/ReadContext.swift | 2 - .../ForyTests/GraphMemoryBudgetTests.swift | 16 +-- 17 files changed, 74 insertions(+), 209 deletions(-) diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index c54ca9cbb9..433e407cc4 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -453,7 +453,9 @@ inline bool reserve_collection(std::vector &result, // Helper to insert element into container (vector or set) template inline void collection_insert(Container &result, T &&elem) { - if constexpr (has_push_back_v) { + if constexpr (is_forward_list_v) { + result.push_front(std::forward(elem)); + } else if constexpr (has_push_back_v) { result.push_back(std::forward(elem)); } else { result.insert(std::forward(elem)); @@ -536,8 +538,9 @@ inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { } bool has_value = read_null_only_flag(ctx, RefMode::NullOnly); if (!has_value) { - if constexpr (has_push_back_v) { - result.push_back(T{}); + if constexpr (has_push_back_v || + is_forward_list_v) { + collection_insert(result, T{}); } // For sets, skip null elements } else { @@ -562,8 +565,9 @@ inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { } bool has_value = read_null_only_flag(ctx, RefMode::NullOnly); if (!has_value) { - if constexpr (has_push_back_v) { - result.push_back(T{}); + if constexpr (has_push_back_v || + is_forward_list_v) { + collection_insert(result, T{}); } } else { // Read type info + data without ref flag @@ -584,122 +588,9 @@ inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { } } - return result; -} - -/// Read forward_list data without a temporary vector so budget accounting only -/// covers the destination container's portable lower-bound storage. -template -inline std::forward_list -read_forward_list_data_slow(ReadContext &ctx, uint32_t length) { - std::forward_list result; - if (length == 0) { - return result; - } - - constexpr bool elem_is_polymorphic = is_polymorphic_v; - - uint8_t bitmap = ctx.read_uint8(ctx.error()); - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return result; - } - - bool track_ref = (bitmap & COLL_TRACKING_REF) != 0; - bool has_null = (bitmap & COLL_HAS_NULL) != 0; - bool is_decl_type = (bitmap & COLL_DECL_ELEMENT_TYPE) != 0; - bool is_same_type = (bitmap & COLL_IS_SAME_TYPE) != 0; - - const TypeInfo *elem_type_info = nullptr; - if (is_same_type && !is_decl_type) { - elem_type_info = ctx.read_any_type_info(ctx.error()); - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return result; - } - } - - if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { - return result; - } - - auto tail = result.before_begin(); - auto append = [&](T &&elem) { - tail = result.insert_after(tail, std::move(elem)); - }; - auto append_default = [&]() { tail = result.emplace_after(tail); }; - - if (is_same_type) { - if (track_ref) { - for (uint32_t i = 0; i < length; ++i) { - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return result; - } - if constexpr (elem_is_polymorphic) { - auto elem = Serializer::read_with_type_info(ctx, RefMode::Tracking, - *elem_type_info); - append(std::move(elem)); - } else { - auto elem = Serializer::read(ctx, RefMode::Tracking, false); - append(std::move(elem)); - } - } - } else if (!has_null) { - for (uint32_t i = 0; i < length; ++i) { - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return result; - } - if constexpr (elem_is_polymorphic) { - auto elem = Serializer::read_with_type_info(ctx, RefMode::None, - *elem_type_info); - append(std::move(elem)); - } else { - auto elem = Serializer::read(ctx, RefMode::None, false); - append(std::move(elem)); - } - } - } else { - for (uint32_t i = 0; i < length; ++i) { - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return result; - } - bool has_value = read_null_only_flag(ctx, RefMode::NullOnly); - if (!has_value) { - append_default(); - } else if constexpr (elem_is_polymorphic) { - auto elem = Serializer::read_with_type_info(ctx, RefMode::None, - *elem_type_info); - append(std::move(elem)); - } else { - auto elem = Serializer::read(ctx, RefMode::None, false); - append(std::move(elem)); - } - } - } - } else { - if (has_null && !track_ref) { - for (uint32_t i = 0; i < length; ++i) { - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return result; - } - bool has_value = read_null_only_flag(ctx, RefMode::NullOnly); - if (!has_value) { - append_default(); - } else { - auto elem = Serializer::read(ctx, RefMode::None, true); - append(std::move(elem)); - } - } - } else { - for (uint32_t i = 0; i < length; ++i) { - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return result; - } - auto elem = Serializer::read( - ctx, track_ref ? RefMode::Tracking : RefMode::None, true); - append(std::move(elem)); - } - } + if constexpr (is_forward_list_v) { + result.reverse(); } - return result; } @@ -1817,7 +1708,8 @@ struct Serializer> { // Dispatch to slow path for polymorphic/shared-ref elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { - return read_forward_list_data_slow(ctx, length); + return read_collection_data_slow>(ctx, + length); } else { auto tail = result.before_begin(); // Fast path for non-polymorphic, non-shared-ref elements diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index c2837fc84b..7e8569d218 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -437,8 +437,6 @@ ReadContext::ReadContext(const Config &config, type_resolver_(std::move(type_resolver)), current_dyn_depth_(0) { FORY_CHECK(config.max_graph_memory_bytes > 0) << "max_graph_memory_bytes must be positive"; - remaining_graph_memory_bytes_ = - static_cast(config.max_graph_memory_bytes); } ReadContext::~ReadContext() = default; diff --git a/cpp/fory/serialization/graph_memory_budget_test.cc b/cpp/fory/serialization/graph_memory_budget_test.cc index 3410cc0dd9..5cbe55d8f1 100644 --- a/cpp/fory/serialization/graph_memory_budget_test.cc +++ b/cpp/fory/serialization/graph_memory_budget_test.cc @@ -414,17 +414,16 @@ TEST(GraphMemoryBudgetTest, DensePathsSkipped) { } TEST(GraphMemoryBudgetTest, ByteCheckStillRejectsLargeLength) { - Config config; - auto resolver = std::make_unique(); - ReadContext ctx(config, std::move(resolver)); - std::vector bytes{64}; - Buffer buffer(bytes.data(), static_cast(bytes.size()), false); - ctx.attach(buffer); - - auto result = Serializer>::read_data(ctx); - EXPECT_TRUE(result.empty()); - ASSERT_TRUE(ctx.has_error()); - EXPECT_EQ(ctx.error().code(), ErrorCode::BufferOutOfBound); + auto bytes = serialize_value(std::vector{}); + ASSERT_FALSE(bytes.empty()); + bytes.back() = 64; + bytes.push_back(0); + + auto result = with_fory(kDefaultGraphMemoryBytes, [&](Fory &fory) { + return fory.deserialize>(bytes); + }); + ASSERT_FALSE(result.ok()); + EXPECT_EQ(result.error().code(), ErrorCode::BufferOutOfBound); } } // namespace diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index 05d1bec5ca..13948ab1d7 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -62,7 +62,6 @@ public ReadContext( RefReader = new RefReader(); _maxDynamicReadDepth = config.MaxDepth; _config = config; - _remainingGraphMemoryBytes = config.MaxGraphMemoryBytes; } public ByteReader Reader { get; private set; } diff --git a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs index fb9773795b..7f00f2515e 100644 --- a/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs +++ b/csharp/tests/Fory.Tests/GraphMemoryBudgetTests.cs @@ -148,10 +148,8 @@ public void DefaultFixedBudgetAndValidation() Assert.Throws(() => NewFory(0)); Assert.Throws(() => NewFory(-2)); - ReadContext context = new(new ByteReader([]), new TypeResolver(), NewFory().Config); - context._remainingGraphMemoryBytes = DefaultGraphMemoryBytes; - context.ReserveGraphMemory(DefaultGraphMemoryBytes); - Assert.Throws(() => context.ReserveGraphMemory(ReferenceBytes)); + List> value = Enumerable.Range(0, 3).Select(_ => new List()).ToList(); + Assert.Equal(value.Count, NewFory().Deserialize>>(Serialize(value)).Count); } [Fact] @@ -336,9 +334,10 @@ public void CompatibleListToDenseArrayIsSkipped() [Fact] public void ByteChecksRejectLargeLength() { - byte[] bytes = [64, 0]; - ReadContext context = new(new ByteReader(bytes), new TypeResolver(), NewFory().Config); + byte[] bytes = Serialize(new List()); + bytes[^1] = 64; + Array.Resize(ref bytes, bytes.Length + 1); - Assert.Throws(() => new ListSerializer().ReadData(context)); + Assert.Throws(() => NewFory().Deserialize>(bytes)); } } diff --git a/docs/guide/cpp/configuration.md b/docs/guide/cpp/configuration.md index 4239fcaed9..906796431e 100644 --- a/docs/guide/cpp/configuration.md +++ b/docs/guide/cpp/configuration.md @@ -120,7 +120,7 @@ primitive dense-array payloads continue to rely on their byte-availability checks instead. `std::vector` is counted as packed standard-container storage. -**Default:** `-1` +**Default:** `128 MiB` ### max_dyn_depth(uint32_t) diff --git a/go/fory/graph_memory_budget_test.go b/go/fory/graph_memory_budget_test.go index f75546cf8e..81250c9693 100644 --- a/go/fory/graph_memory_budget_test.go +++ b/go/fory/graph_memory_budget_test.go @@ -51,16 +51,15 @@ func TestGraphMemoryBudgetConfig(t *testing.T) { } func TestGraphMemoryBudgetFixedDefault(t *testing.T) { - ctx := NewReadContext(false) - ctx.remainingGraphMemoryBytes = 128 * 1024 * 1024 - require.Equal(t, int64(128*1024*1024), ctx.remainingGraphMemoryBytes) - require.True(t, ctx.ReserveGraphMemory(ctx.remainingGraphMemoryBytes)) - require.False(t, ctx.ReserveGraphMemory(1)) - require.Contains(t, ctx.CheckError().Error(), "maxGraphMemoryBytes") - - ctx = NewReadContext(false) - ctx.remainingGraphMemoryBytes = 77 - require.Equal(t, int64(77), ctx.remainingGraphMemoryBytes) + writer := New(WithCompatible(false)) + value := []any{[]any{}, []any{}, []any{}} + data, err := writer.Serialize(value) + require.NoError(t, err) + + var out []any + err = New(WithCompatible(false)).Deserialize(data, &out) + require.NoError(t, err) + require.Len(t, out, len(value)) } func TestGraphBudgetRootKinds(t *testing.T) { diff --git a/go/fory/reader.go b/go/fory/reader.go index f550e427d3..80fca7319f 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -82,11 +82,10 @@ func (c *ReadContext) IsXlang() bool { // NewReadContext creates a new read context func NewReadContext(trackRef bool) *ReadContext { return &ReadContext{ - buffer: NewByteBuffer(nil), - refReader: NewRefReader(trackRef), - trackRef: trackRef, - maxDepth: 128, // Default maximum nesting depth - remainingGraphMemoryBytes: 128 * 1024 * 1024, + buffer: NewByteBuffer(nil), + refReader: NewRefReader(trackRef), + trackRef: trackRef, + maxDepth: 128, // Default maximum nesting depth } } diff --git a/go/fory/struct_test.go b/go/fory/struct_test.go index ac5d8ad6f0..262a404135 100644 --- a/go/fory/struct_test.go +++ b/go/fory/struct_test.go @@ -607,13 +607,13 @@ func TestSkipAnyValueReadsSharedTypeMeta(t *testing.T) { readHeader(f.readCtx) SkipAnyValue(f.readCtx, true) require.NoError(t, f.readCtx.CheckError()) + nextRootIndex := f.readCtx.Buffer().ReaderIndex() - f.resetReadState() - readHeader(f.readCtx) + rootBuffer := NewByteBuffer(buf.Bytes()) + rootBuffer.SetReaderIndex(nextRootIndex) var out any - f.readCtx.ReadValue(reflect.ValueOf(&out).Elem(), RefModeTracking, true) - require.NoError(t, f.readCtx.CheckError()) + require.NoError(t, f.DeserializeFrom(rootBuffer, &out)) result, ok := out.(*Second) require.True(t, ok) diff --git a/python/pyfory/context.pxi b/python/pyfory/context.pxi index e53bc20175..72e6c50077 100644 --- a/python/pyfory/context.pxi +++ b/python/pyfory/context.pxi @@ -747,8 +747,8 @@ cdef class ReadContext: cdef readonly bint field_nullable cdef readonly object policy cdef readonly int32_t max_depth - cdef public int64_t max_graph_memory_bytes - cdef public int64_t remaining_graph_memory_bytes + cdef int64_t max_graph_memory_bytes + cdef int64_t remaining_graph_memory_bytes cdef readonly RefReader ref_reader cdef readonly MetaStringReader meta_string_reader cdef readonly MetaShareReadContext meta_share_context diff --git a/python/pyfory/context.py b/python/pyfory/context.py index 35fa4beeb9..b4dbc0889e 100644 --- a/python/pyfory/context.py +++ b/python/pyfory/context.py @@ -471,8 +471,8 @@ class ReadContext: "field_nullable", "policy", "max_depth", - "max_graph_memory_bytes", - "remaining_graph_memory_bytes", + "_max_graph_memory_bytes", + "_remaining_graph_memory_bytes", "ref_reader", "meta_string_reader", "meta_share_context", @@ -493,8 +493,8 @@ def __init__(self, config: Config, type_resolver): self.field_nullable = config.field_nullable self.policy = config.policy self.max_depth = config.max_depth - self.max_graph_memory_bytes = config.max_graph_memory_bytes - self.remaining_graph_memory_bytes = 0 + self._max_graph_memory_bytes = config.max_graph_memory_bytes + self._remaining_graph_memory_bytes = 0 self.ref_reader = MapRefReader() if self.track_ref else NoRefReader() self.meta_string_reader = MetaStringReader(type_resolver.shared_registry) self.meta_share_context = MetaShareReadContext() if config.scoped_meta_share_enabled else None @@ -529,7 +529,7 @@ def prepare( self.buffers = iter(buffers) if buffers is not None else None self.unsupported_objects = iter(unsupported_objects) if unsupported_objects is not None else None self.peer_out_of_band_enabled = peer_out_of_band_enabled - self.remaining_graph_memory_bytes = self.max_graph_memory_bytes + self._remaining_graph_memory_bytes = self._max_graph_memory_bytes self.depth = 0 def reset(self): @@ -543,7 +543,7 @@ def reset(self): self.buffers = None self.unsupported_objects = None self.peer_out_of_band_enabled = False - self.remaining_graph_memory_bytes = 0 + self._remaining_graph_memory_bytes = 0 self.depth = 0 def reserve_graph_memory(self, num_bytes): @@ -551,15 +551,15 @@ def reserve_graph_memory(self, num_bytes): raise ValueError("Estimated graph memory is negative") if num_bytes > _MAX_GRAPH_MEMORY_BYTES: raise ValueError("Estimated graph memory overflow") - remaining = self.remaining_graph_memory_bytes + remaining = self._remaining_graph_memory_bytes if num_bytes > remaining: - used = self.max_graph_memory_bytes - remaining + used = self._max_graph_memory_bytes - remaining raise ValueError( f"Estimated graph memory budget exceeded: requested {num_bytes} bytes, " - f"used {used} bytes, limit {self.max_graph_memory_bytes} bytes. " + f"used {used} bytes, limit {self._max_graph_memory_bytes} bytes. " "Increase Fory(..., max_graph_memory_bytes=...) for trusted larger payloads." ) - self.remaining_graph_memory_bytes = remaining - num_bytes + self._remaining_graph_memory_bytes = remaining - num_bytes def add_context_object(self, key, obj): self.context_objects[id(key)] = obj diff --git a/python/pyfory/tests/test_graph_memory_budget.py b/python/pyfory/tests/test_graph_memory_budget.py index 1b45aeb7e9..6f95acb7c0 100644 --- a/python/pyfory/tests/test_graph_memory_budget.py +++ b/python/pyfory/tests/test_graph_memory_budget.py @@ -127,25 +127,17 @@ def varuint_payload(value): def test_fixed_default_budget(): + assert pyfory.Fory(xlang=False, ref=True).max_graph_memory_bytes == DEFAULT_GRAPH_MEMORY_BYTES fory = new_fory(xlang=False) - try: - fory.read_context.prepare(Buffer(b"x" * 17)) - assert fory.read_context.remaining_graph_memory_bytes == DEFAULT_GRAPH_MEMORY_BYTES - fory.read_context.reserve_graph_memory(DEFAULT_GRAPH_MEMORY_BYTES) - with pytest.raises(ValueError, match="Estimated graph memory budget exceeded"): - fory.read_context.reserve_graph_memory(1) - finally: - fory.reset_read() + value = [[], [], []] + assert fory.deserialize(fory.serialize(value)) == value def test_stream_default_budget(): fory = new_fory(xlang=False) - try: - buffer = Buffer.from_stream(OneByteStream(b"streamed")) - fory.read_context.prepare(buffer) - assert fory.read_context.remaining_graph_memory_bytes == DEFAULT_GRAPH_MEMORY_BYTES - finally: - fory.reset_read() + value = [[], [], []] + data = fory.serialize(value) + assert fory.deserialize(Buffer.from_stream(OneByteStream(data))) == value def test_explicit_budget(): diff --git a/rust/fory-core/src/context.rs b/rust/fory-core/src/context.rs index 06a67a9a5c..eab540e7dd 100644 --- a/rust/fory-core/src/context.rs +++ b/rust/fory-core/src/context.rs @@ -359,7 +359,6 @@ pub struct ReadContext<'a> { max_dyn_depth: u32, check_struct_version: bool, check_string_read: bool, - pub(crate) max_graph_memory_bytes: i64, pub(crate) remaining_graph_memory_bytes: usize, // Context-specific fields @@ -390,7 +389,6 @@ impl<'a> ReadContext<'a> { max_dyn_depth: config.max_dyn_depth, check_struct_version: config.check_struct_version, check_string_read: config.check_string_read, - max_graph_memory_bytes: config.max_graph_memory_bytes, remaining_graph_memory_bytes: 0, reader: Reader::default(), meta_resolver: MetaReaderResolver::default(), @@ -455,7 +453,7 @@ impl<'a> ReadContext<'a> { return Err(graph_memory_exceeded( bytes, remaining, - self.max_graph_memory_bytes, + self.config.max_graph_memory_bytes, )); } self.remaining_graph_memory_bytes = remaining - bytes; diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index 3c847b0229..ee09e865c4 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -1001,7 +1001,7 @@ impl Fory { self.with_read_context(|context| { let outlive_buffer = unsafe { mem::transmute::<&[u8], &[u8]>(bf) }; context.attach_reader(Reader::new(outlive_buffer)); - let result = match usize::try_from(context.max_graph_memory_bytes) + let result = match usize::try_from(self.config.max_graph_memory_bytes) .map_err(|_| Error::invalid_data("max_graph_memory_bytes does not fit usize")) { Ok(limit) => { @@ -1075,7 +1075,7 @@ impl Fory { let mut new_reader = Reader::new(outlive_buffer); new_reader.set_cursor(reader.cursor); context.attach_reader(new_reader); - let result = match usize::try_from(context.max_graph_memory_bytes) + let result = match usize::try_from(self.config.max_graph_memory_bytes) .map_err(|_| Error::invalid_data("max_graph_memory_bytes does not fit usize")) { Ok(limit) => { diff --git a/swift/Sources/Fory/Fory.swift b/swift/Sources/Fory/Fory.swift index 4b653bcba7..8abad7bfb9 100644 --- a/swift/Sources/Fory/Fory.swift +++ b/swift/Sources/Fory/Fory.swift @@ -496,7 +496,7 @@ public final class Fory { _ body: (ReadContext) throws -> R ) throws -> R { readContext.buffer.replace(with: data) - readContext.remainingGraphMemoryBytes = readContext.maxGraphMemoryBytes + readContext.remainingGraphMemoryBytes = Int(self.config.maxGraphMemoryBytes) defer { readContext.reset() } @@ -558,7 +558,7 @@ public final class Fory { ) throws -> R { try typeResolver.finishRegistration() readContext.buffer.swapState(with: buffer) - readContext.remainingGraphMemoryBytes = readContext.maxGraphMemoryBytes + readContext.remainingGraphMemoryBytes = Int(self.config.maxGraphMemoryBytes) defer { readContext.buffer.swapState(with: buffer) readContext.reset() diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index 0e15f5f481..ad90f10a40 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -35,7 +35,6 @@ public final class ReadContext { private var typeInfoScopeStack: [(typeKey: UInt64, previousTypeInfo: TypeInfo?)] = [] private var lastTypeInfo = TypeInfo.uncached private let config: Config - let maxGraphMemoryBytes: Int var remainingGraphMemoryBytes = 0 init( @@ -50,7 +49,6 @@ public final class ReadContext { self.checkClassVersion = config.checkClassVersion self.maxDepth = config.maxDepth self.config = config - self.maxGraphMemoryBytes = Int(config.maxGraphMemoryBytes) self.refReader = RefReader() } diff --git a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift index 0ecc4b857e..1d77b77dee 100644 --- a/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift +++ b/swift/Tests/ForyTests/GraphMemoryBudgetTests.swift @@ -147,18 +147,10 @@ private func expectInvalidData(_ body: () throws -> Void) { @Test func fixedDefaultBudget() throws { - let config = Config(trackRef: false, compatible: false) - let context = ReadContext( - buffer: ByteBuffer(), - typeResolver: TypeResolver(config: config), - config: config - ) - - context.remainingGraphMemoryBytes = context.maxGraphMemoryBytes - try context.reserveGraphMemory(Int(defaultGraphMemoryBytes)) - expectInvalidData { - try context.reserveGraphMemory(testReferenceBytes) - } + let fory = makeBudgetFory() + #expect(fory.config.maxGraphMemoryBytes == defaultGraphMemoryBytes) + let value = Array(repeating: [String](), count: 3) + #expect(try fory.deserialize(try fory.serialize(value)) == value) } @Test From 8c53f75df19f1c410e5f8af59f991185bb32f775 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 08:47:40 +0800 Subject: [PATCH 45/54] docs: remove stale container budget wording --- docs/guide/dart/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/guide/dart/configuration.md b/docs/guide/dart/configuration.md index 2fced7675f..4e3c005cf9 100644 --- a/docs/guide/dart/configuration.md +++ b/docs/guide/dart/configuration.md @@ -117,8 +117,8 @@ protected by byte-availability checks. The default is a fixed `128 MiB` and is not derived from input size. -Set a positive value when a trusted workload legitimately contains compact, container-heavy -payloads: +Set a positive value when a trusted workload legitimately contains compact, graph-heavy object or +collection payloads: ```dart final fory = Fory(maxGraphMemoryBytes: 256 * 1024 * 1024); From 12742078678f66fbd764e9b7e63b53f4edc90130 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 08:56:45 +0800 Subject: [PATCH 46/54] refactor: remove stale graph budget helpers --- cpp/fory/serialization/serializer_traits.h | 80 ---------------------- csharp/src/Fory/TypeInfo.cs | 11 +-- go/fory/codegen/generator.go | 23 ------- 3 files changed, 1 insertion(+), 113 deletions(-) diff --git a/cpp/fory/serialization/serializer_traits.h b/cpp/fory/serialization/serializer_traits.h index 38ff4db440..854a083a4a 100644 --- a/cpp/fory/serialization/serializer_traits.h +++ b/cpp/fory/serialization/serializer_traits.h @@ -19,7 +19,6 @@ #pragma once -#include "fory/meta/field.h" #include "fory/meta/field_info.h" #include "fory/meta/type_index.h" #include "fory/meta/type_traits.h" @@ -328,85 +327,6 @@ template constexpr size_t graph_value_owner_self_bytes() { } } -template -struct has_graph_budget_children : std::false_type {}; - -template -struct has_graph_budget_children, void> - : std::bool_constant>> { -}; - -template -struct has_graph_budget_children, void> - : std::true_type {}; - -template -struct has_graph_budget_children< - T, std::enable_if_t || is_deque_v || is_forward_list_v || - is_set_like_v || is_map_like_v>> - : std::true_type {}; - -template -struct has_graph_budget_children, void> - : has_graph_budget_children>> { -}; - -template -struct has_graph_budget_children, void> - : has_graph_budget_children>> { -}; - -template -struct has_graph_budget_children, void> - : std::bool_constant<(graph_value_owner_self_bytes() != 0) || - has_graph_budget_children>>::value> {}; - -template -struct has_graph_budget_children, void> - : std::bool_constant<(graph_value_owner_self_bytes() != 0) || - has_graph_budget_children>>::value> {}; - -template -struct has_graph_budget_children, void> - : std::bool_constant<(has_graph_budget_children>>::value || - ...)> {}; - -template -struct has_graph_budget_children, void> - : std::bool_constant<(has_graph_budget_children>>::value || - ...)> {}; - -template -constexpr bool struct_has_graph_children_impl(std::index_sequence) { - return ( - has_graph_budget_children< - std::remove_cv_t>>>>::value || - ...); -} - -template -struct has_graph_budget_children>> { -private: - using Value = std::remove_cv_t>; - using Ptrs = - decltype(::fory::meta::fory_field_info(std::declval()) - .ptrs()); - -public: - static constexpr bool value = struct_has_graph_children_impl( - std::make_index_sequence>{}); -}; - -template -inline constexpr bool has_graph_budget_children_v = has_graph_budget_children< - std::remove_cv_t>>::value; - template FORY_ALWAYS_INLINE bool reserve_allocated_value_owner(Context &ctx) { constexpr size_t bytes = graph_value_owner_self_bytes(); diff --git a/csharp/src/Fory/TypeInfo.cs b/csharp/src/Fory/TypeInfo.cs index 6969eb3573..d0c43e082c 100644 --- a/csharp/src/Fory/TypeInfo.cs +++ b/csharp/src/Fory/TypeInfo.cs @@ -122,7 +122,7 @@ internal static TypeInfo Create(Type type, Serializer serializer) context => ReadDataObject(serializer, context, boxedValueBytes), (context, value, refMode, writeTypeInfo, hasGenerics) => WriteObject(serializer, context, value, refMode, writeTypeInfo, hasGenerics), - (context, refMode, readTypeInfo) => ReadObject(serializer, context, refMode, readTypeInfo), + (context, refMode, readTypeInfo) => serializer.Read(context, refMode, readTypeInfo), typeMetaFields, builtInTypeId, null); @@ -200,15 +200,6 @@ private static void WriteObject( serializer.Write(context, CoerceRuntimeValue(serializer, value), refMode, writeTypeInfo, hasGenerics); } - private static object? ReadObject( - Serializer serializer, - ReadContext context, - RefMode refMode, - bool readTypeInfo) - { - return serializer.Read(context, refMode, readTypeInfo); - } - private static T CoerceRuntimeValue(Serializer serializer, object? value) { if (value is T typed) diff --git a/go/fory/codegen/generator.go b/go/fory/codegen/generator.go index c054ea462b..56f615163d 100644 --- a/go/fory/codegen/generator.go +++ b/go/fory/codegen/generator.go @@ -21,7 +21,6 @@ import ( "bytes" "fmt" "go/format" - "go/types" "io/ioutil" "log" "os" @@ -34,16 +33,6 @@ import ( var logger = log.New(os.Stdout, "", 0) -func typeNeedsGraphReservation(t types.Type) bool { - if _, ok := t.(*types.Slice); ok { - return true - } - if _, ok := t.(*types.Map); ok { - return true - } - return false -} - // GeneratorOptions contains configuration for the code generator type GeneratorOptions struct { TypeList string // comma-separated list of types to generate code for @@ -307,12 +296,6 @@ func generateCodeForFile(pkg *packages.Package, structs []*StructInfo, sourceFil } if field.IsOptional { needsOptional = true - if field.OptionalElem != nil && typeNeedsGraphReservation(field.OptionalElem) { - needsUnsafe = true - } - } - if typeNeedsGraphReservation(field.Type) { - needsUnsafe = true } } } @@ -580,12 +563,6 @@ func generateCode(pkg *packages.Package, structs []*StructInfo) error { } if field.IsOptional { needsOptional = true - if field.OptionalElem != nil && typeNeedsGraphReservation(field.OptionalElem) { - needsUnsafe = true - } - } - if typeNeedsGraphReservation(field.Type) { - needsUnsafe = true } } } From 3ad18ecceb1376951c15bdd7b1ce2b1a43be75dc Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 09:26:13 +0800 Subject: [PATCH 47/54] fix(swift): restore read context any readers --- swift/Sources/Fory/AnySerializer.swift | 147 ++++++++++++++++--------- swift/Sources/Fory/FieldSkipper.swift | 3 +- swift/Sources/Fory/ReadContext.swift | 2 +- swift/Tests/ForyTests/AnyTests.swift | 84 ++++++++++++++ 4 files changed, 179 insertions(+), 57 deletions(-) diff --git a/swift/Sources/Fory/AnySerializer.swift b/swift/Sources/Fory/AnySerializer.swift index a3926eed53..0c33cfa06c 100644 --- a/swift/Sources/Fory/AnySerializer.swift +++ b/swift/Sources/Fory/AnySerializer.swift @@ -588,7 +588,7 @@ public func readAny( refMode: RefMode, readTypeInfo: Bool = true ) throws -> Any? { - try SerializableAny.foryRead(context, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() + try context.readAny(refMode: refMode, readTypeInfo: readTypeInfo) } public func writeListOfAny( @@ -612,16 +612,7 @@ public func readListOfAny( refMode: RefMode, readTypeInfo: Bool = false ) throws -> [Any]? { - let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead( - context, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceArrayMemory(context, count: wrapped.count) - return wrapped.map { $0.anyValueForCollection() } + try context.readListOfAny(refMode: refMode, readTypeInfo: readTypeInfo) } public func writeMapStringToAny( @@ -647,21 +638,7 @@ public func readMapStringToAny( refMode: RefMode, readTypeInfo: Bool = false ) throws -> [String: Any]? { - let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead( - context, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceMapMemory(context, [String: Any].self, count: wrapped.count) - var map: [String: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + try context.readMapStringToAny(refMode: refMode, readTypeInfo: readTypeInfo) } public func writeMapInt32ToAny( @@ -687,21 +664,7 @@ public func readMapInt32ToAny( refMode: RefMode, readTypeInfo: Bool = false ) throws -> [Int32: Any]? { - let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead( - context, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceMapMemory(context, [Int32: Any].self, count: wrapped.count) - var map: [Int32: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map + try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: readTypeInfo) } public func writeMapAnyHashableToAny( @@ -727,21 +690,97 @@ public func readMapAnyHashableToAny( refMode: RefMode, readTypeInfo: Bool = false ) throws -> [AnyHashable: Any]? { - let wrapped: [AnyHashable: SerializableAny]? = try [AnyHashable: SerializableAny]?.foryRead( - context, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil + try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: readTypeInfo) +} + +public extension ReadContext { + // Swift `Any` cannot conform to `Serializer`, so generated and hand-written dynamic-Any + // readers must enter through these context methods instead of trying `Any.foryRead(...)`. + func readAny( + refMode: RefMode, + readTypeInfo: Bool = true + ) throws -> Any? { + try SerializableAny.foryRead(self, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() } - try reserveAnyReferenceMapMemory(context, [AnyHashable: Any].self, count: wrapped.count) - var map: [AnyHashable: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() + + func readListOfAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [Any]? { + let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceArrayMemory(self, count: wrapped.count) + return wrapped.map { $0.anyValueForCollection() } + } + + func readMapStringToAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [String: Any]? { + let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceMapMemory(self, [String: Any].self, count: wrapped.count) + var map: [String: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map + } + + func readMapInt32ToAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [Int32: Any]? { + let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceMapMemory(self, [Int32: Any].self, count: wrapped.count) + var map: [Int32: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map + } + + func readMapAnyHashableToAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [AnyHashable: Any]? { + let wrapped: [AnyHashable: SerializableAny]? = try [AnyHashable: SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceMapMemory(self, [AnyHashable: Any].self, count: wrapped.count) + var map: [AnyHashable: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map } - return map } func readDynamicAnyMapValue(context: ReadContext) throws -> Any { diff --git a/swift/Sources/Fory/FieldSkipper.swift b/swift/Sources/Fory/FieldSkipper.swift index b9e67ade84..9183d2f58f 100644 --- a/swift/Sources/Fory/FieldSkipper.swift +++ b/swift/Sources/Fory/FieldSkipper.swift @@ -395,7 +395,6 @@ extension ReadContext { private func readSkippedUnion() throws -> Any { _ = try buffer.readVarUInt32() - return try readAny(context: self, refMode: .tracking, readTypeInfo: true) - ?? ForyAnyNullValue() + return try readAny(refMode: .tracking, readTypeInfo: true) ?? ForyAnyNullValue() } } diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index ad90f10a40..9a92d7f3ed 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -665,7 +665,7 @@ public final class ReadContext { case .float64Array: value = try readPrimitiveArray(self) as [Double] case .array, .list: - value = try readListOfAny(context: self, refMode: .none) ?? [] + value = try readListOfAny(refMode: .none) ?? [] case .set: value = try Set.foryRead(self, refMode: .none, readTypeInfo: false) case .map: diff --git a/swift/Tests/ForyTests/AnyTests.swift b/swift/Tests/ForyTests/AnyTests.swift index e7eb1c3e3f..0dec3f3a92 100644 --- a/swift/Tests/ForyTests/AnyTests.swift +++ b/swift/Tests/ForyTests/AnyTests.swift @@ -92,6 +92,90 @@ private func nestedDynamicAnyList(depth: Int) -> Any { return value } +private func readAnyRoot( + _ fory: Fory, + data: Data, + _ body: (ReadContext) throws -> R +) throws -> R { + try fory.withReusableReadContext(data: data) { context in + try fory.readHead(buffer: context.buffer) + let value = try body(context) + #expect(context.buffer.remaining == 0) + return value + } +} + +@Test +func contextAnyReadersRoundTrip() throws { + let fory = Fory(config: .init(trackRef: false, compatible: false)) + fory.register(AnyHashableDynamicKey.self, id: 510) + fory.register(AnyHashableDynamicValue.self, id: 511) + + let anyValue: Any = AnyHashableDynamicValue(label: "context-any", score: 1) + let anyDecoded = try readAnyRoot(fory, data: try fory.serialize(anyValue)) { context in + try context.readAny(refMode: .nullOnly, readTypeInfo: true) + } + #expect( + anyDecoded as? AnyHashableDynamicValue + == AnyHashableDynamicValue(label: "context-any", score: 1) + ) + + let listValue: [Any] = [ + Int32(2), + "context-list", + AnyHashableDynamicValue(label: "context-list-obj", score: 3) + ] + let listDecoded = try readAnyRoot(fory, data: try fory.serialize(listValue)) { context in + try context.readListOfAny(refMode: .nullOnly, readTypeInfo: true) + } + #expect(listDecoded?[0] as? Int32 == 2) + #expect(listDecoded?[1] as? String == "context-list") + #expect( + listDecoded?[2] as? AnyHashableDynamicValue + == AnyHashableDynamicValue(label: "context-list-obj", score: 3) + ) + + let stringMapValue: [String: Any] = [ + "a": Int32(4), + "b": AnyHashableDynamicValue(label: "context-string-map", score: 5) + ] + let stringMapDecoded = try readAnyRoot(fory, data: try fory.serialize(stringMapValue)) { context in + try context.readMapStringToAny(refMode: .nullOnly, readTypeInfo: true) + } + #expect(stringMapDecoded?["a"] as? Int32 == 4) + #expect( + stringMapDecoded?["b"] as? AnyHashableDynamicValue + == AnyHashableDynamicValue(label: "context-string-map", score: 5) + ) + + let int32MapValue: [Int32: Any] = [ + 6: "context-int-map", + 7: AnyHashableDynamicValue(label: "context-int-map-obj", score: 8) + ] + let int32MapDecoded = try readAnyRoot(fory, data: try fory.serialize(int32MapValue)) { context in + try context.readMapInt32ToAny(refMode: .nullOnly, readTypeInfo: true) + } + #expect(int32MapDecoded?[6] as? String == "context-int-map") + #expect( + int32MapDecoded?[7] as? AnyHashableDynamicValue + == AnyHashableDynamicValue(label: "context-int-map-obj", score: 8) + ) + + let anyHashableMapValue: [AnyHashable: Any] = [ + AnyHashable("x"): Int32(9), + AnyHashable(AnyHashableDynamicKey(id: 10)): + AnyHashableDynamicValue(label: "context-any-map", score: 11) + ] + let anyHashableMapDecoded = try readAnyRoot(fory, data: try fory.serialize(anyHashableMapValue)) { context in + try context.readMapAnyHashableToAny(refMode: .nullOnly, readTypeInfo: true) + } + #expect(anyHashableMapDecoded?[AnyHashable("x")] as? Int32 == 9) + #expect( + anyHashableMapDecoded?[AnyHashable(AnyHashableDynamicKey(id: 10))] as? AnyHashableDynamicValue + == AnyHashableDynamicValue(label: "context-any-map", score: 11) + ) +} + @Test func topLevelAnyHashableRoundTrip() throws { let fory = Fory() From 2a6e22e0aa9500fa8aa3801a7e061f45ccafc469 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 09:56:11 +0800 Subject: [PATCH 48/54] refactor: simplify graph memory collection readers --- .../serialization/collection_serializer.h | 36 +- csharp/src/Fory/CollectionSerializers.cs | 455 +++++------------- 2 files changed, 145 insertions(+), 346 deletions(-) diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index 433e407cc4..417b958727 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -22,7 +22,6 @@ #include "fory/serialization/array_serializer.h" #include "fory/serialization/serializer.h" #include -#include #include #include #include @@ -383,37 +382,26 @@ template inline constexpr bool has_reserve_v = has_reserve::value; template -constexpr size_t collection_element_memory_bytes() { +inline bool reserve_collection(Container &result, ReadContext &ctx, + uint32_t length) { + // Lazy error propagation may continue into later readers; do not let that + // path retain attacker-controlled capacity after an earlier read failure. + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return false; + } using Elem = typename Container::value_type; + constexpr size_t elem_bytes = sizeof(Elem); // Portable lower-bound estimate only: STL node/header/debug-layout details // differ across implementations, so generic collections charge value storage. - return sizeof(Elem); -} - -template -inline bool reserve_collection_storage(ReadContext &ctx, uint32_t length) { - if (FORY_PREDICT_FALSE(elem_bytes != 0 && - static_cast(length) > - std::numeric_limits::max() / elem_bytes)) { + if (FORY_PREDICT_FALSE(static_cast(length) > + std::numeric_limits::max() / elem_bytes)) { ctx.set_error(Error::invalid_data( "graph memory estimate overflows: length=" + std::to_string(length) + " elementBytes=" + std::to_string(elem_bytes))); return false; } - return ctx.reserve_graph_memory(static_cast(length) * elem_bytes); -} - -template -inline bool reserve_collection(Container &result, ReadContext &ctx, - uint32_t length) { - // Lazy error propagation may continue into later readers; do not let that - // path retain attacker-controlled capacity after an earlier read failure. - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return false; - } - constexpr size_t elem_bytes = collection_element_memory_bytes(); - if (FORY_PREDICT_FALSE( - (!reserve_collection_storage(ctx, length)))) { + if (FORY_PREDICT_FALSE(!ctx.reserve_graph_memory(static_cast(length) * + elem_bytes))) { return false; } if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { diff --git a/csharp/src/Fory/CollectionSerializers.cs b/csharp/src/Fory/CollectionSerializers.cs index 9054d91880..d66061a0cf 100644 --- a/csharp/src/Fory/CollectionSerializers.cs +++ b/csharp/src/Fory/CollectionSerializers.cs @@ -213,220 +213,68 @@ private static class ElementStorage internal static readonly int Bytes = typeof(T).IsValueType ? Unsafe.SizeOf() : ReferenceBytes; } - public static List ReadCollectionData(Serializer elementSerializer, ReadContext context) + private interface IValueSink { - int length = checked((int)context.Reader.ReadVarUInt32()); - if (length == 0) - { - ReserveElementStorage(context, length); - List empty = []; - context.StoreRef(empty); - - return empty; - } - - byte header = context.Reader.ReadUInt8(); - // IMPORTANT: collection readers must obey the ref/null bits written on - // the wire, not the local generic metadata that may imply a different - // ref policy. Shared xlang tests intentionally deserialize one ref - // policy and then serialize another local payload. DO NOT REMOVE this comment. - bool trackRef = (header & CollectionBits.TrackingRef) != 0; - bool hasNull = (header & CollectionBits.HasNull) != 0; - bool declared = (header & CollectionBits.DeclaredElementType) != 0; - bool sameType = (header & CollectionBits.SameType) != 0; - ReserveElementStorage(context, length); - context.Reader.CheckBound(length); - List values = new(length); - context.StoreRef(values); - - if (!sameType) - { - if (trackRef) - { - for (int i = 0; i < length; i++) - { - values.Add(elementSerializer.Read(context, RefMode.Tracking, true)); - } - - return values; - } - - if (hasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values.Add((T)elementSerializer.DefaultObject!); - } - else if (refFlag == (sbyte)RefFlag.NotNullValue) - { - values.Add(elementSerializer.Read(context, RefMode.None, true)); - } - else - { - throw new RefException($"invalid nullability flag {refFlag}"); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values.Add(elementSerializer.Read(context, RefMode.None, true)); - } - } - - return values; - } - - if (!declared) - { - context.TypeResolver.ReadTypeInfo(elementSerializer, context); - } - - if (trackRef) - { - for (int i = 0; i < length; i++) - { - values.Add(elementSerializer.Read(context, RefMode.Tracking, false)); - } - - if (!declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - - return values; - } - - if (hasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values.Add((T)elementSerializer.DefaultObject!); - } - else - { - values.Add(elementSerializer.ReadData(context)); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values.Add(elementSerializer.ReadData(context)); - } - } - - if (!declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - - return values; + void Add(T value); } - private interface ICollectionBuilder + private readonly struct CollectionSink(TCollection values) : IValueSink + where TCollection : ICollection { - bool StoresOwnerRef { get; } - - TCollection Create(int length); - - void Add(TCollection values, T value); + public void Add(T value) => values.Add(value); } - private readonly struct HashSetBuilder : ICollectionBuilder, T> where T : notnull + private struct ArraySink(T[] values) : IValueSink { - public bool StoresOwnerRef => true; - - public HashSet Create(int length) => new(length); - - public void Add(HashSet values, T value) => values.Add(value); - } + private int _index; - private readonly struct SortedSetBuilder : ICollectionBuilder, T> where T : notnull - { - public bool StoresOwnerRef => true; - - public SortedSet Create(int length) => new(); - - public void Add(SortedSet values, T value) => values.Add(value); + public void Add(T value) + { + values[_index] = value; + _index++; + } } - private readonly struct ImmutableHashSetBuilder : ICollectionBuilder.Builder, T> + private readonly struct QueueSink(Queue values) : IValueSink { - public bool StoresOwnerRef => false; - - public ImmutableHashSet.Builder Create(int length) => ImmutableHashSet.CreateBuilder(); - - public void Add(ImmutableHashSet.Builder values, T value) => values.Add(value); + public void Add(T value) => values.Enqueue(value); } - private readonly struct LinkedListBuilder : ICollectionBuilder, T> + private readonly struct StackSink(Stack values) : IValueSink { - public bool StoresOwnerRef => true; - - public LinkedList Create(int length) => new(); - - public void Add(LinkedList values, T value) => values.AddLast(value); + public void Add(T value) => values.Push(value); } - private readonly struct QueueBuilder : ICollectionBuilder, T> + private static int ReadLength(ReadContext context) { - public bool StoresOwnerRef => true; - - public Queue Create(int length) => new(length); - - public void Add(Queue values, T value) => values.Enqueue(value); + int length = checked((int)context.Reader.ReadVarUInt32()); + ReserveElementStorage(context, length); + return length; } - private readonly struct StackBuilder : ICollectionBuilder, T> + private static byte ReadHeader(ReadContext context, int length) { - public bool StoresOwnerRef => true; - - public Stack Create(int length) => new(length); - - public void Add(Stack values, T value) => values.Push(value); + byte header = context.Reader.ReadUInt8(); + context.Reader.CheckBound(length); + return header; } - private static TCollection ReadCollectionOwner( + private static void ReadElements( Serializer elementSerializer, ReadContext context, - TBuilder builder) - where TBuilder : struct, ICollectionBuilder + int length, + byte header, + TSink sink) + where TSink : struct, IValueSink { - int length = checked((int)context.Reader.ReadVarUInt32()); - ReserveElementStorage(context, length); - if (length == 0) - { - TCollection empty = builder.Create(length); - if (builder.StoresOwnerRef) - { - context.StoreRef(empty); - } - - return empty; - } - - byte header = context.Reader.ReadUInt8(); + // IMPORTANT: collection readers must obey the ref/null bits written on + // the wire, not the local generic metadata that may imply a different + // ref policy. Shared xlang tests intentionally deserialize one ref + // policy and then serialize another local payload. DO NOT REMOVE this comment. bool trackRef = (header & CollectionBits.TrackingRef) != 0; bool hasNull = (header & CollectionBits.HasNull) != 0; bool declared = (header & CollectionBits.DeclaredElementType) != 0; bool sameType = (header & CollectionBits.SameType) != 0; - // Some builders allocate backing capacity from length, so prove proportional payload bytes - // before materializing non-empty owners. - context.Reader.CheckBound(length); - TCollection values = builder.Create(length); - if (builder.StoresOwnerRef) - { - context.StoreRef(values); - } if (!sameType) { @@ -434,10 +282,10 @@ private static TCollection ReadCollectionOwner( { for (int i = 0; i < length; i++) { - builder.Add(values, elementSerializer.Read(context, RefMode.Tracking, true)); + sink.Add(elementSerializer.Read(context, RefMode.Tracking, true)); } - return values; + return; } if (hasNull) @@ -447,11 +295,11 @@ private static TCollection ReadCollectionOwner( sbyte refFlag = context.Reader.ReadInt8(); if (refFlag == (sbyte)RefFlag.Null) { - builder.Add(values, (T)elementSerializer.DefaultObject!); + sink.Add((T)elementSerializer.DefaultObject!); } else if (refFlag == (sbyte)RefFlag.NotNullValue) { - builder.Add(values, elementSerializer.Read(context, RefMode.None, true)); + sink.Add(elementSerializer.Read(context, RefMode.None, true)); } else { @@ -463,11 +311,11 @@ private static TCollection ReadCollectionOwner( { for (int i = 0; i < length; i++) { - builder.Add(values, elementSerializer.Read(context, RefMode.None, true)); + sink.Add(elementSerializer.Read(context, RefMode.None, true)); } } - return values; + return; } if (!declared) @@ -479,7 +327,7 @@ private static TCollection ReadCollectionOwner( { for (int i = 0; i < length; i++) { - builder.Add(values, elementSerializer.Read(context, RefMode.Tracking, false)); + sink.Add(elementSerializer.Read(context, RefMode.Tracking, false)); } if (!declared) @@ -487,7 +335,7 @@ private static TCollection ReadCollectionOwner( context.ClearReadTypeInfo(typeof(T)); } - return values; + return; } if (hasNull) @@ -497,11 +345,11 @@ private static TCollection ReadCollectionOwner( sbyte refFlag = context.Reader.ReadInt8(); if (refFlag == (sbyte)RefFlag.Null) { - builder.Add(values, (T)elementSerializer.DefaultObject!); + sink.Add((T)elementSerializer.DefaultObject!); } else { - builder.Add(values, elementSerializer.ReadData(context)); + sink.Add(elementSerializer.ReadData(context)); } } } @@ -509,7 +357,7 @@ private static TCollection ReadCollectionOwner( { for (int i = 0; i < length; i++) { - builder.Add(values, elementSerializer.ReadData(context)); + sink.Add(elementSerializer.ReadData(context)); } } @@ -517,26 +365,57 @@ private static TCollection ReadCollectionOwner( { context.ClearReadTypeInfo(typeof(T)); } + } + public static List ReadCollectionData(Serializer elementSerializer, ReadContext context) + { + int length = ReadLength(context); + if (length == 0) + { + List empty = []; + context.StoreRef(empty); + return empty; + } + + byte header = ReadHeader(context, length); + List values = new(length); + context.StoreRef(values); + ReadElements(elementSerializer, context, length, header, new CollectionSink, T>(values)); return values; } internal static HashSet ReadHashSetData(Serializer elementSerializer, ReadContext context) where T : notnull { - return ReadCollectionOwner, HashSetBuilder>( - elementSerializer, - context, - new HashSetBuilder()); + int length = ReadLength(context); + if (length == 0) + { + HashSet empty = new(length); + context.StoreRef(empty); + return empty; + } + + byte header = ReadHeader(context, length); + HashSet values = new(length); + context.StoreRef(values); + ReadElements(elementSerializer, context, length, header, new CollectionSink, T>(values)); + return values; } internal static SortedSet ReadSortedSetData(Serializer elementSerializer, ReadContext context) where T : notnull { - return ReadCollectionOwner, SortedSetBuilder>( - elementSerializer, - context, - new SortedSetBuilder()); + int length = ReadLength(context); + SortedSet values = new(); + context.StoreRef(values); + if (length == 0) + { + return values; + } + + byte header = ReadHeader(context, length); + ReadElements(elementSerializer, context, length, header, new CollectionSink, T>(values)); + return values; } internal static ImmutableHashSet ReadImmutableHashSetData( @@ -544,150 +423,82 @@ internal static ImmutableHashSet ReadImmutableHashSetData( ReadContext context) where T : notnull { - ImmutableHashSet.Builder values = - ReadCollectionOwner.Builder, ImmutableHashSetBuilder>( - elementSerializer, - context, - new ImmutableHashSetBuilder()); + int length = ReadLength(context); + ImmutableHashSet.Builder values = ImmutableHashSet.CreateBuilder(); + if (length == 0) + { + return values.ToImmutable(); + } + + byte header = ReadHeader(context, length); + ReadElements( + elementSerializer, + context, + length, + header, + new CollectionSink.Builder, T>(values)); return values.ToImmutable(); } internal static LinkedList ReadLinkedListData(Serializer elementSerializer, ReadContext context) { - return ReadCollectionOwner, LinkedListBuilder>( - elementSerializer, - context, - new LinkedListBuilder()); + int length = ReadLength(context); + LinkedList values = new(); + context.StoreRef(values); + if (length == 0) + { + return values; + } + + byte header = ReadHeader(context, length); + ReadElements(elementSerializer, context, length, header, new CollectionSink, T>(values)); + return values; } internal static Queue ReadQueueData(Serializer elementSerializer, ReadContext context) { - return ReadCollectionOwner, QueueBuilder>( - elementSerializer, - context, - new QueueBuilder()); + int length = ReadLength(context); + Queue values = new(length); + context.StoreRef(values); + if (length == 0) + { + return values; + } + + byte header = ReadHeader(context, length); + ReadElements(elementSerializer, context, length, header, new QueueSink(values)); + return values; } internal static Stack ReadStackData(Serializer elementSerializer, ReadContext context) { - return ReadCollectionOwner, StackBuilder>( - elementSerializer, - context, - new StackBuilder()); + int length = ReadLength(context); + Stack values = new(length); + context.StoreRef(values); + if (length == 0) + { + return values; + } + + byte header = ReadHeader(context, length); + ReadElements(elementSerializer, context, length, header, new StackSink(values)); + return values; } public static T[] ReadArrayData(Serializer elementSerializer, ReadContext context) { - int length = checked((int)context.Reader.ReadVarUInt32()); + int length = ReadLength(context); if (length == 0) { - ReserveElementStorage(context, length); T[] empty = []; context.StoreRef(empty); - return empty; } - byte header = context.Reader.ReadUInt8(); - bool trackRef = (header & CollectionBits.TrackingRef) != 0; - bool hasNull = (header & CollectionBits.HasNull) != 0; - bool declared = (header & CollectionBits.DeclaredElementType) != 0; - bool sameType = (header & CollectionBits.SameType) != 0; - ReserveElementStorage(context, length); - context.Reader.CheckBound(length); + byte header = ReadHeader(context, length); T[] values = new T[length]; context.StoreRef(values); - - if (!sameType) - { - if (trackRef) - { - for (int i = 0; i < length; i++) - { - values[i] = elementSerializer.Read(context, RefMode.Tracking, true); - } - - return values; - } - - if (hasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values[i] = (T)elementSerializer.DefaultObject!; - } - else if (refFlag == (sbyte)RefFlag.NotNullValue) - { - values[i] = elementSerializer.Read(context, RefMode.None, true); - } - else - { - throw new RefException($"invalid nullability flag {refFlag}"); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values[i] = elementSerializer.Read(context, RefMode.None, true); - } - } - - return values; - } - - if (!declared) - { - context.TypeResolver.ReadTypeInfo(elementSerializer, context); - } - - if (trackRef) - { - for (int i = 0; i < length; i++) - { - values[i] = elementSerializer.Read(context, RefMode.Tracking, false); - } - - if (!declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - - return values; - } - - if (hasNull) - { - for (int i = 0; i < length; i++) - { - sbyte refFlag = context.Reader.ReadInt8(); - if (refFlag == (sbyte)RefFlag.Null) - { - values[i] = (T)elementSerializer.DefaultObject!; - } - else - { - values[i] = elementSerializer.ReadData(context); - } - } - } - else - { - for (int i = 0; i < length; i++) - { - values[i] = elementSerializer.ReadData(context); - } - } - - if (!declared) - { - context.ClearReadTypeInfo(typeof(T)); - } - + ReadElements(elementSerializer, context, length, header, new ArraySink(values)); return values; } } From 1c860216f2494c0cb39fb609ce0e7c7b427c5112 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 10:04:44 +0800 Subject: [PATCH 49/54] refactor(swift): use read context any helpers --- swift/Sources/Fory/Fory.swift | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/swift/Sources/Fory/Fory.swift b/swift/Sources/Fory/Fory.swift index 8abad7bfb9..567aedc74f 100644 --- a/swift/Sources/Fory/Fory.swift +++ b/swift/Sources/Fory/Fory.swift @@ -167,7 +167,7 @@ public final class Fory { data: data ) { context in try castAnyDynamicValue( - readAny(context: context, refMode: refMode, readTypeInfo: true), + context.readAny(refMode: refMode, readTypeInfo: true), to: Any.self ) } @@ -186,7 +186,7 @@ public final class Fory { data: data ) { context in try castAnyDynamicValue( - readAny(context: context, refMode: refMode, readTypeInfo: true), + context.readAny(refMode: refMode, readTypeInfo: true), to: AnyObject.self ) } @@ -209,7 +209,7 @@ public final class Fory { data: data ) { context in try castAnyDynamicValue( - readAny(context: context, refMode: refMode, readTypeInfo: true), + context.readAny(refMode: refMode, readTypeInfo: true), to: (any Serializer).self ) } @@ -227,7 +227,7 @@ public final class Fory { try deserializeRoot( data: data ) { context in - try readListOfAny(context: context, refMode: refMode, readTypeInfo: true) ?? [] + try context.readListOfAny(refMode: refMode, readTypeInfo: true) ?? [] } } @@ -248,7 +248,7 @@ public final class Fory { try deserializeRoot( data: data ) { context in - try readMapStringToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] + try context.readMapStringToAny(refMode: refMode, readTypeInfo: true) ?? [:] } } @@ -269,7 +269,7 @@ public final class Fory { try deserializeRoot( data: data ) { context in - try readMapInt32ToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] + try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: true) ?? [:] } } @@ -290,7 +290,7 @@ public final class Fory { try deserializeRoot( data: data ) { context in - try readMapAnyHashableToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] + try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: true) ?? [:] } } @@ -314,7 +314,7 @@ public final class Fory { from: buffer ) { context in try castAnyDynamicValue( - readAny(context: context, refMode: refMode, readTypeInfo: true), + context.readAny(refMode: refMode, readTypeInfo: true), to: Any.self ) } @@ -337,7 +337,7 @@ public final class Fory { from: buffer ) { context in try castAnyDynamicValue( - readAny(context: context, refMode: refMode, readTypeInfo: true), + context.readAny(refMode: refMode, readTypeInfo: true), to: AnyObject.self ) } @@ -359,7 +359,7 @@ public final class Fory { from: buffer ) { context in try castAnyDynamicValue( - readAny(context: context, refMode: refMode, readTypeInfo: true), + context.readAny(refMode: refMode, readTypeInfo: true), to: (any Serializer).self ) } @@ -370,7 +370,7 @@ public final class Fory { try deserializeRoot( from: buffer ) { context in - try readListOfAny(context: context, refMode: refMode, readTypeInfo: true) ?? [] + try context.readListOfAny(refMode: refMode, readTypeInfo: true) ?? [] } } @@ -391,7 +391,7 @@ public final class Fory { try deserializeRoot( from: buffer ) { context in - try readMapStringToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] + try context.readMapStringToAny(refMode: refMode, readTypeInfo: true) ?? [:] } } @@ -420,7 +420,7 @@ public final class Fory { try deserializeRoot( from: buffer ) { context in - try readMapInt32ToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] + try context.readMapInt32ToAny(refMode: refMode, readTypeInfo: true) ?? [:] } } @@ -431,7 +431,7 @@ public final class Fory { try deserializeRoot( from: buffer ) { context in - try readMapAnyHashableToAny(context: context, refMode: refMode, readTypeInfo: true) ?? [:] + try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: true) ?? [:] } } From 16aa1fc65511b1e971eea79995dc93df24eda103 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 10:17:32 +0800 Subject: [PATCH 50/54] refactor(swift): keep any readers in read context --- swift/Sources/Fory/AnySerializer.swift | 165 ------------------------- swift/Sources/Fory/ReadContext.swift | 163 ++++++++++++++++++++++++ 2 files changed, 163 insertions(+), 165 deletions(-) diff --git a/swift/Sources/Fory/AnySerializer.swift b/swift/Sources/Fory/AnySerializer.swift index 0c33cfa06c..4d2692d651 100644 --- a/swift/Sources/Fory/AnySerializer.swift +++ b/swift/Sources/Fory/AnySerializer.swift @@ -17,45 +17,6 @@ import Foundation -private let anyReferenceBytes = 4 -private let anyArrayOwnerBytes = max(1, MemoryLayout<[Any]>.stride) - -@inline(never) -private func throwAnyGraphMemoryOverflow() throws -> Never { - throw ForyError.invalidData("graph memory estimate overflows") -} - -@inline(__always) -private func reserveAnyReferenceArrayMemory(_ context: ReadContext, count: Int) throws { - let (slotBytes, overflow) = count.multipliedReportingOverflow(by: anyReferenceBytes) - if overflow { - try throwAnyGraphMemoryOverflow() - } - let (bytes, addOverflow) = anyArrayOwnerBytes.addingReportingOverflow(slotBytes) - if addOverflow { - try throwAnyGraphMemoryOverflow() - } - try context.reserveGraphMemory(bytes) -} - -@inline(__always) -private func reserveAnyReferenceMapMemory( - _ context: ReadContext, _ type: Map.Type, count: Int -) - throws -{ - let (slotBytes, overflow) = count.multipliedReportingOverflow(by: 2 * anyReferenceBytes) - if overflow { - try throwAnyGraphMemoryOverflow() - } - let ownerBytes = max(1, MemoryLayout.stride) - let (bytes, addOverflow) = ownerBytes.addingReportingOverflow(slotBytes) - if addOverflow { - try throwAnyGraphMemoryOverflow() - } - try context.reserveGraphMemory(bytes) -} - public struct ForyAnyNullValue: Serializer { public init() {} @@ -692,129 +653,3 @@ public func readMapAnyHashableToAny( ) throws -> [AnyHashable: Any]? { try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: readTypeInfo) } - -public extension ReadContext { - // Swift `Any` cannot conform to `Serializer`, so generated and hand-written dynamic-Any - // readers must enter through these context methods instead of trying `Any.foryRead(...)`. - func readAny( - refMode: RefMode, - readTypeInfo: Bool = true - ) throws -> Any? { - try SerializableAny.foryRead(self, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() - } - - func readListOfAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [Any]? { - let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceArrayMemory(self, count: wrapped.count) - return wrapped.map { $0.anyValueForCollection() } - } - - func readMapStringToAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [String: Any]? { - let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceMapMemory(self, [String: Any].self, count: wrapped.count) - var map: [String: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map - } - - func readMapInt32ToAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [Int32: Any]? { - let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceMapMemory(self, [Int32: Any].self, count: wrapped.count) - var map: [Int32: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map - } - - func readMapAnyHashableToAny( - refMode: RefMode, - readTypeInfo: Bool = false - ) throws -> [AnyHashable: Any]? { - let wrapped: [AnyHashable: SerializableAny]? = try [AnyHashable: SerializableAny]?.foryRead( - self, - refMode: refMode, - readTypeInfo: readTypeInfo - ) - guard let wrapped else { - return nil - } - try reserveAnyReferenceMapMemory(self, [AnyHashable: Any].self, count: wrapped.count) - var map: [AnyHashable: Any] = [:] - map.reserveCapacity(wrapped.count) - for pair in wrapped { - map[pair.key] = pair.value.anyValueForCollection() - } - return map - } -} - -func readDynamicAnyMapValue(context: ReadContext) throws -> Any { - let map = try readMapAnyHashableToAny(context: context, refMode: .none) ?? [:] - if map.isEmpty { - try reserveAnyReferenceMapMemory(context, [String: Any].self, count: 0) - return [String: Any]() - } - try reserveAnyReferenceMapMemory(context, [String: Any].self, count: map.count) - var stringMap: [String: Any] = [:] - stringMap.reserveCapacity(map.count) - for pair in map { - guard let key = pair.key.base as? String else { - stringMap.removeAll(keepingCapacity: false) - break - } - stringMap[key] = pair.value - } - if stringMap.count == map.count { - return stringMap - } - - try reserveAnyReferenceMapMemory(context, [Int32: Any].self, count: map.count) - var int32Map: [Int32: Any] = [:] - int32Map.reserveCapacity(map.count) - for pair in map { - guard let key = pair.key.base as? Int32 else { - return map - } - int32Map[key] = pair.value - } - if int32Map.count == map.count { - return int32Map - } - - return map -} diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index 9a92d7f3ed..3f5cba39e4 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -739,3 +739,166 @@ public final class ReadContext { metaStrings.reset() } } + +private let anyReferenceBytes = 4 +private let anyArrayOwnerBytes = max(1, MemoryLayout<[Any]>.stride) + +@inline(never) +private func throwAnyGraphMemoryOverflow() throws -> Never { + throw ForyError.invalidData("graph memory estimate overflows") +} + +@inline(__always) +private func reserveAnyReferenceArrayMemory(_ context: ReadContext, count: Int) throws { + let (slotBytes, overflow) = count.multipliedReportingOverflow(by: anyReferenceBytes) + if overflow { + try throwAnyGraphMemoryOverflow() + } + let (bytes, addOverflow) = anyArrayOwnerBytes.addingReportingOverflow(slotBytes) + if addOverflow { + try throwAnyGraphMemoryOverflow() + } + try context.reserveGraphMemory(bytes) +} + +@inline(__always) +private func reserveAnyReferenceMapMemory( + _ context: ReadContext, _ type: Map.Type, count: Int +) throws { + let (slotBytes, overflow) = count.multipliedReportingOverflow(by: 2 * anyReferenceBytes) + if overflow { + try throwAnyGraphMemoryOverflow() + } + let ownerBytes = max(1, MemoryLayout.stride) + let (bytes, addOverflow) = ownerBytes.addingReportingOverflow(slotBytes) + if addOverflow { + try throwAnyGraphMemoryOverflow() + } + try context.reserveGraphMemory(bytes) +} + +extension ReadContext { + // Swift `Any` cannot conform to `Serializer`, so generated and hand-written dynamic-Any + // readers must enter through these context methods instead of trying `Any.foryRead(...)`. + public func readAny( + refMode: RefMode, + readTypeInfo: Bool = true + ) throws -> Any? { + try SerializableAny.foryRead(self, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() + } + + public func readListOfAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [Any]? { + let wrapped: [SerializableAny]? = try [SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceArrayMemory(self, count: wrapped.count) + return wrapped.map { $0.anyValueForCollection() } + } + + public func readMapStringToAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [String: Any]? { + let wrapped: [String: SerializableAny]? = try [String: SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceMapMemory(self, [String: Any].self, count: wrapped.count) + var map: [String: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map + } + + public func readMapInt32ToAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [Int32: Any]? { + let wrapped: [Int32: SerializableAny]? = try [Int32: SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceMapMemory(self, [Int32: Any].self, count: wrapped.count) + var map: [Int32: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map + } + + public func readMapAnyHashableToAny( + refMode: RefMode, + readTypeInfo: Bool = false + ) throws -> [AnyHashable: Any]? { + let wrapped: [AnyHashable: SerializableAny]? = try [AnyHashable: SerializableAny]?.foryRead( + self, + refMode: refMode, + readTypeInfo: readTypeInfo + ) + guard let wrapped else { + return nil + } + try reserveAnyReferenceMapMemory(self, [AnyHashable: Any].self, count: wrapped.count) + var map: [AnyHashable: Any] = [:] + map.reserveCapacity(wrapped.count) + for pair in wrapped { + map[pair.key] = pair.value.anyValueForCollection() + } + return map + } +} + +private func readDynamicAnyMapValue(context: ReadContext) throws -> Any { + let map = try context.readMapAnyHashableToAny(refMode: .none) ?? [:] + if map.isEmpty { + try reserveAnyReferenceMapMemory(context, [String: Any].self, count: 0) + return [String: Any]() + } + try reserveAnyReferenceMapMemory(context, [String: Any].self, count: map.count) + var stringMap: [String: Any] = [:] + stringMap.reserveCapacity(map.count) + for pair in map { + guard let key = pair.key.base as? String else { + stringMap.removeAll(keepingCapacity: false) + break + } + stringMap[key] = pair.value + } + if stringMap.count == map.count { + return stringMap + } + + try reserveAnyReferenceMapMemory(context, [Int32: Any].self, count: map.count) + var int32Map: [Int32: Any] = [:] + int32Map.reserveCapacity(map.count) + for pair in map { + guard let key = pair.key.base as? Int32 else { + return map + } + int32Map[key] = pair.value + } + if int32Map.count == map.count { + return int32Map + } + + return map +} From 82ed7be2f9a6e5d689f24388f42200add9822859 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 10:56:48 +0800 Subject: [PATCH 51/54] perf(swift): trim graph budget read overhead --- swift/Sources/Fory/ReadContext.swift | 4 ++-- .../ForyMacro/ForyObjectMacroReadGeneration.swift | 12 ++++-------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index 3f5cba39e4..eb48b55b60 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -54,10 +54,10 @@ public final class ReadContext { @inline(__always) public func reserveGraphMemory(_ bytes: Int) throws { - if bytes < 0 { + if _slowPath(bytes < 0) { try throwGraphMemoryOverflow() } - if bytes > remainingGraphMemoryBytes { + if _slowPath(bytes > remainingGraphMemoryBytes) { try throwGraphMemoryExceeded(bytes: bytes) } remainingGraphMemoryBytes -= bytes diff --git a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift index 307d3ba918..495743f133 100644 --- a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift +++ b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift @@ -65,10 +65,6 @@ private func reserveClassGraphOwnerLine(fields: [ParsedField], indent: String) - "\(indent)try context.reserveGraphMemory(\(classGraphOwnerBytesExpr(fields)))" } -private func reserveValueGraphOwnerLine(indent: String) -> String { - "\(indent)try context.reserveGraphMemory(max(1, MemoryLayout.stride))" -} - func buildClassReadWrapperDecl(accessPrefix: String) -> String { """ @inline(__always) @@ -153,7 +149,7 @@ private func buildEmptyStructReadDataDecl(accessPrefix: String) -> String { \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { let __buffer = context.buffer \(schemaHashCheckExpr()) - \(reserveValueGraphOwnerLine(indent: " ")) + try context.reserveGraphMemory(1) return Self() } """ @@ -177,7 +173,7 @@ private func buildStructReadDataDecl( \(accessPrefix)static func foryReadData(_ context: ReadContext) throws -> Self { let __buffer = context.buffer \(schemaHashCheckExpr()) - \(reserveValueGraphOwnerLine(indent: " ")) + try context.reserveGraphMemory(MemoryLayout.stride) \(schemaReadBody) return Self( \(ctorArgs) @@ -263,7 +259,7 @@ private func buildEmptyStructReadCompatibleDataDecl(accessPrefix: String) -> Str guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { throw ForyError.invalidData("compatible type metadata is required") } - \(reserveValueGraphOwnerLine(indent: " ")) + try context.reserveGraphMemory(1) if let localTypeMeta = remoteTypeInfo.typeMeta, let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, typeMeta.headerHash == localHeaderHash, @@ -318,7 +314,7 @@ private func buildStructReadCompatibleDataDecl( \(bufferBinding)guard let typeMeta = remoteTypeInfo.compatibleTypeMeta else { throw ForyError.invalidData("compatible type metadata is required") } - \(reserveValueGraphOwnerLine(indent: " ")) + try context.reserveGraphMemory(MemoryLayout.stride) if let localTypeMeta = remoteTypeInfo.typeMeta, let localHeaderHash = remoteTypeInfo.typeDefHeaderHash, typeMeta.headerHash == localHeaderHash, From 12d842dc5a8031e0b761757c935994ebc7d44fc8 Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 11:58:31 +0800 Subject: [PATCH 52/54] perf(go): trim graph budget root reads --- go/fory/fory.go | 44 ++++++++++++++++++++--------------------- go/fory/reader.go | 26 ++++++++++-------------- go/fory/struct.go | 50 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 39 deletions(-) diff --git a/go/fory/fory.go b/go/fory/fory.go index e0986bbebb..55b2a18b21 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -1185,36 +1185,34 @@ func Deserialize[T any](f *Fory, data []byte, target *T) error { targetVal = reflect.ValueOf(target).Elem() targetType = targetVal.Type() } - // Get serializer for the target type - if f.rootUsesReadStruct(targetType) { - f.readCtx.ReadStruct(targetVal) - } else { - serializer, err := f.typeResolver.getSerializerByType(targetType, false) - if err != nil { - return fmt.Errorf("failed to get serializer for type %v: %w", targetType, err) + if targetType.Kind() == reflect.Struct { + if typeInfo := f.readCtx.getTypeInfoByType(targetType); typeInfo != nil { + if structSer, ok := typeInfo.Serializer.(*structSerializer); ok { + structSer.readRoot(f.readCtx, targetVal) + return f.readCtx.CheckError() + } } - - // Use Read to deserialize directly into target - serializer.Read(f.readCtx, RefModeTracking, true, false, targetVal) } + serializer, err := f.typeResolver.getSerializerByType(targetType, false) + if err != nil { + return fmt.Errorf("failed to get serializer for type %v: %w", targetType, err) + } + serializer.Read(f.readCtx, RefModeTracking, true, false, targetVal) return f.readCtx.CheckError() } } func (f *Fory) readRootValue(target reflect.Value) { - if f.rootUsesReadStruct(target.Type()) { - f.readCtx.ReadStruct(target) - return + targetType := target.Type() + if targetType.Kind() == reflect.Struct { + if typeInfo := f.readCtx.getTypeInfoByType(targetType); typeInfo != nil { + if structSer, ok := typeInfo.Serializer.(*structSerializer); ok { + structSer.readRoot(f.readCtx, target) + } else { + typeInfo.Serializer.Read(f.readCtx, RefModeTracking, true, false, target) + } + return + } } - // Root writes include type metadata, so generic roots must keep ReadValue. - // Calling a cached serializer directly would read that metadata byte as payload. f.readCtx.ReadValue(target, RefModeTracking, true) } - -func (f *Fory) rootUsesReadStruct(targetType reflect.Type) bool { - return targetType.Kind() == reflect.Struct && - targetType != dateReflectType && - targetType != timeReflectType && - targetType != decimalType && - !f.typeResolver.IsUnionType(targetType) -} diff --git a/go/fory/reader.go b/go/fory/reader.go index 80fca7319f..c7a87aa4d1 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -107,28 +107,22 @@ func (c *ReadContext) Reset() { // ReserveGraphMemory reserves raw estimated graph-owner bytes. func (c *ReadContext) ReserveGraphMemory(bytes int64) bool { - if bytes >= 0 { - remaining := c.remainingGraphMemoryBytes - if bytes <= remaining { - c.remainingGraphMemoryBytes = remaining - bytes - return true - } - return c.rejectGraphMemoryExceeded(bytes, remaining) + if uint64(bytes) <= uint64(c.remainingGraphMemoryBytes) { + c.remainingGraphMemoryBytes -= bytes + return true } - return c.rejectGraphMemoryBytes(bytes) + return c.rejectGraphMemoryReservation(bytes) } //go:noinline -func (c *ReadContext) rejectGraphMemoryBytes(bytes int64) bool { - c.SetError(DeserializationErrorf("estimated graph memory must be non-negative, got %d bytes", bytes)) - return false -} - -//go:noinline -func (c *ReadContext) rejectGraphMemoryExceeded(bytes int64, remaining int64) bool { +func (c *ReadContext) rejectGraphMemoryReservation(bytes int64) bool { + if bytes < 0 { + c.SetError(DeserializationErrorf("estimated graph memory must be non-negative, got %d bytes", bytes)) + return false + } c.SetError(DeserializationErrorf( "estimated graph memory request %d bytes exceeds maxGraphMemoryBytes remaining budget %d bytes", - bytes, remaining)) + bytes, c.remainingGraphMemoryBytes)) return false } diff --git a/go/fory/struct.go b/go/fory/struct.go index 890ca21082..39fc7ce43e 100644 --- a/go/fory/struct.go +++ b/go/fory/struct.go @@ -32,6 +32,7 @@ type structSerializer struct { structHash int32 typeID uint32 userTypeID uint32 + graphBytes int64 // Pre-sorted and categorized fields (embedded for cache locality) fieldGroup FieldGroup @@ -62,6 +63,7 @@ func newStructSerializerFromTypeDef(type_ reflect.Type, name string, fieldDefs [ type_: type_, name: name, userTypeID: invalidUserTypeID, + graphBytes: structGraphBytes(type_), fieldDefs: fieldDefs, } } @@ -77,6 +79,7 @@ func newStructSerializer(type_ reflect.Type, name string) *structSerializer { type_: type_, name: name, userTypeID: invalidUserTypeID, + graphBytes: structGraphBytes(type_), } } @@ -1367,6 +1370,53 @@ func (s *structSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool s.ReadData(ctx, value) } +func (s *structSerializer) readRoot(ctx *ReadContext, value reflect.Value) { + buf := ctx.buffer + ctxErr := ctx.Err() + if ctx.refResolver.refTracking { + refID, refErr := ctx.refResolver.TryPreserveRefId(buf) + if refErr != nil { + ctx.SetError(FromError(refErr)) + return + } + if refID < int32(NotNullValueFlag) { + obj := ctx.refResolver.GetReadObject(refID) + if obj.IsValid() { + value.Set(obj) + } + return + } + } else { + // No-ref roots only need the marker byte; avoid the tracking helper on this hot path. + refFlag := buf.ReadInt8(ctxErr) + if refFlag == NullFlag { + return + } + if refFlag == RefFlag { + buf.ReadVarUint32(ctxErr) + return + } + } + if !ctx.ReserveGraphMemory(s.graphBytes) { + return + } + if s.type_ != nil { + serializer := ctx.typeResolver.ReadTypeInfoForType(buf, s.type_, ctxErr) + if ctxErr.HasError() { + return + } + if serializer == nil { + ctx.SetError(DeserializationError("unexpected type id for struct")) + return + } + if structSer, ok := serializer.(*structSerializer); ok && len(structSer.fieldDefs) > 0 { + structSer.ReadData(ctx, value) + return + } + } + s.ReadData(ctx, value) +} + func (s *structSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { // typeInfo is already read, don't read it again s.Read(ctx, refMode, false, false, value) From cca57743472683f7779569374a4bc5af2f11c20d Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 15:45:49 +0800 Subject: [PATCH 53/54] fix: clean graph memory budget ownership --- csharp/src/Fory/Fory.cs | 43 ++++++------- .../processing/ForyStructProcessor.java | 38 ++++++++++++ .../annotation/processing/SourceStruct.java | 3 + .../StaticSerializerSourceWriter.java | 19 ++++-- .../fory/builder/ObjectCodecBuilder.java | 40 ++++++++++++- .../builder/StaticCompatibleCodecBuilder.java | 4 +- .../org/apache/fory/context/ReadContext.java | 36 ++++++----- .../serializer/AbstractObjectSerializer.java | 12 ++-- .../CompatibleLayerSerializerBase.java | 2 +- .../fory/serializer/CompatibleSerializer.java | 2 +- .../fory/serializer/ExceptionSerializers.java | 2 +- .../fory/serializer/ObjectSerializer.java | 2 +- .../serializer/ObjectStreamSerializer.java | 2 +- .../kotlin/ksp/ForyKotlinSymbolProcessor.kt | 21 +++++++ .../ksp/KotlinSerializerSourceWriter.kt | 18 ++++-- .../org/apache/fory/kotlin/ksp/Model.kt | 1 + .../serializer/kotlin/GenericDataClassTest.kt | 6 +- swift/Sources/Fory/AnySerializer.swift | 60 +++++++++++++++++++ swift/Sources/Fory/ReadContext.swift | 36 ----------- 19 files changed, 251 insertions(+), 96 deletions(-) diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs index e2698d8b0e..6c456e543a 100644 --- a/csharp/src/Fory/Fory.cs +++ b/csharp/src/Fory/Fory.cs @@ -192,7 +192,6 @@ public T Deserialize(ReadOnlySpan payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); - _readContext._remainingGraphMemoryBytes = Config.MaxGraphMemoryBytes; T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -213,7 +212,6 @@ public T Deserialize(byte[] payload) { ByteReader reader = _readContext.Reader; reader.Reset(payload); - _readContext._remainingGraphMemoryBytes = Config.MaxGraphMemoryBytes; T value = DeserializeFromReader(reader); if (reader.Remaining != 0) { @@ -234,7 +232,6 @@ public T Deserialize(ref ReadOnlySequence payload) byte[] bytes = payload.ToArray(); ByteReader reader = _readContext.Reader; reader.Reset(bytes); - _readContext._remainingGraphMemoryBytes = Config.MaxGraphMemoryBytes; T value = DeserializeFromReader(reader); payload = payload.Slice(reader.Cursor); return value; @@ -280,29 +277,35 @@ private static void ThrowInvalidRootHeader(byte bitmap) => [MethodImpl(MethodImplOptions.AggressiveInlining)] private T DeserializeFromReader(ByteReader reader) { - ReadHead(reader); - Serializer serializer = _typeResolver.GetSerializer(); ReadContext readContext = _readContext; readContext.ResetFor(reader); - T value = _trackRef - ? serializer.Read(readContext, RefMode.Tracking, true) - : ReadRootNoRef(serializer, readContext); - if (_trackRef || readContext.RefReader.HasRefs) + readContext._remainingGraphMemoryBytes = Config.MaxGraphMemoryBytes; + try { - readContext.RefReader.Reset(); + ReadHead(reader); + Serializer serializer = _typeResolver.GetSerializer(); + return _trackRef + ? serializer.Read(readContext, RefMode.Tracking, true) + : ReadRootNoRef(serializer, readContext); } - if (readContext._reservedRefIds.Count != 0) + finally { - readContext._reservedRefIds.Clear(); + if (_trackRef || readContext.RefReader.HasRefs) + { + readContext.RefReader.Reset(); + } + if (readContext._reservedRefIds.Count != 0) + { + readContext._reservedRefIds.Clear(); + } + readContext._typeMetaType = null; + readContext._typeMeta = null; + readContext._typeMetaByType?.ClearKeys(); + readContext._readTypeInfoByType.ClearKeys(); + readContext._cachedTypeMetaType = null; + readContext._cachedTypeMeta = null; + readContext._currentDynamicReadDepth = 0; } - readContext._typeMetaType = null; - readContext._typeMeta = null; - readContext._typeMetaByType?.ClearKeys(); - readContext._readTypeInfoByType.ClearKeys(); - readContext._cachedTypeMetaType = null; - readContext._cachedTypeMeta = null; - readContext._currentDynamicReadDepth = 0; - return value; } [MethodImpl(MethodImplOptions.AggressiveInlining)] diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java index d0ccf91f16..f8348c94db 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java +++ b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/ForyStructProcessor.java @@ -64,6 +64,8 @@ "org.apache.fory.annotation.ForyDebug" }) public final class ForyStructProcessor extends AbstractProcessor { + private static final int OBJECT_SELF_BYTES = 1; + private static final int REFERENCE_BYTES = 4; private static final String ARRAY_TYPE = "org.apache.fory.annotation.ArrayType"; private static final String BFLOAT16_TYPE = "org.apache.fory.annotation.BFloat16Type"; private static final String EXPOSE = "org.apache.fory.annotation.Expose"; @@ -243,10 +245,46 @@ boolean record = isRecord(type); serializerName, record, isForyDebugEnabled(type), + graphMemoryBytes(type), sourceFields, recordConstructorFields); } + private int graphMemoryBytes(TypeElement type) { + int bytes = OBJECT_SELF_BYTES; + for (TypeElement current : hierarchy(type)) { + for (VariableElement field : ElementFilter.fieldsIn(current.getEnclosedElements())) { + if (!field.getModifiers().contains(Modifier.STATIC)) { + bytes = Math.addExact(bytes, fieldGraphMemoryBytes(field.asType())); + } + } + } + return bytes; + } + + private int fieldGraphMemoryBytes(TypeMirror type) { + TypeKind kind = type.getKind(); + if (!kind.isPrimitive()) { + return REFERENCE_BYTES; + } + switch (kind) { + case BOOLEAN: + case BYTE: + return 1; + case CHAR: + case SHORT: + return 2; + case INT: + case FLOAT: + return 4; + case LONG: + case DOUBLE: + return 8; + default: + return 0; + } + } + private boolean isForyDebugEnabled(TypeElement type) { return annotationMirror(type, FORY_DEBUG) != null; } diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceStruct.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceStruct.java index 859ec8b2c2..d451d5220c 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceStruct.java +++ b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/SourceStruct.java @@ -31,6 +31,7 @@ final class SourceStruct { final boolean record; final boolean debug; final boolean hasNestedCompatibleStructFields; + final int graphMemoryBytes; final List fields; final List recordConstructorFields; @@ -41,6 +42,7 @@ final class SourceStruct { String serializerName, boolean record, boolean debug, + int graphMemoryBytes, List fields, List recordConstructorFields) { this.packageName = packageName; @@ -49,6 +51,7 @@ final class SourceStruct { this.serializerName = serializerName; this.record = record; this.debug = debug; + this.graphMemoryBytes = graphMemoryBytes; this.fields = Collections.unmodifiableList(new ArrayList<>(fields)); this.recordConstructorFields = Collections.unmodifiableList(new ArrayList<>(recordConstructorFields)); diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java index 54adf66696..4c1eab4f0e 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java +++ b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java @@ -247,7 +247,7 @@ private void writeSchemaConsistentRead() { .append(struct.typeName) .append(" readSchemaConsistent(ReadContext readContext) {\n"); builder.append(" MemoryBuffer buffer = readContext.getBuffer();\n"); - builder.append(" reserveObjectGraphMemory(readContext);\n"); + appendGraphMemoryReserve(); builder.append(" if (typeResolver.checkClassVersion()) {\n"); builder.append(" checkClassVersion(buffer.readInt32(), classVersionHash);\n"); builder.append(" }\n"); @@ -272,7 +272,9 @@ private void writeSchemaConsistentRead() { .append(" value = ") .append(newGeneratedBeanExpression()) .append(";\n"); - builder.append(" readContext.reference(value);\n"); + builder.append(" if (needToWriteRef) {\n"); + builder.append(" readContext.reference(value);\n"); + builder.append(" }\n"); builder.append(" readFields(readContext, value);\n"); builder.append(" return value;\n"); } @@ -795,6 +797,13 @@ private String exactPrimitiveTypeId(SourceField field) { return meta.substring(start + prefix.length(), end); } + private void appendGraphMemoryReserve() { + builder + .append(" readContext.reserveGraphMemory(") + .append(struct.graphMemoryBytes) + .append(");\n"); + } + private void writeCompatibleRead() { builder.append(" @Override\n"); builder @@ -804,7 +813,7 @@ private void writeCompatibleRead() { builder.append(" if (sameSchemaCompatible) {\n"); builder.append(" return readSchemaConsistent(readContext);\n"); builder.append(" }\n"); - builder.append(" reserveObjectGraphMemory(readContext);\n"); + appendGraphMemoryReserve(); if (struct.record) { for (SourceField field : struct.fields) { builder @@ -835,7 +844,9 @@ private void writeCompatibleRead() { .append(" value = ") .append(newGeneratedBeanExpression()) .append(";\n"); - builder.append(" readContext.reference(value);\n"); + builder.append(" if (needToWriteRef) {\n"); + builder.append(" readContext.reference(value);\n"); + builder.append(" }\n"); builder.append(" for (int i = 0; i < remoteFields.size(); i++) {\n"); builder.append(" RemoteFieldInfo remoteField = remoteFields.get(i);\n"); builder.append(" readCompatibleField(readContext, value, remoteField);\n"); diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java index 2896d8b3c4..7bb1e213a7 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java @@ -39,6 +39,8 @@ import static org.apache.fory.type.TypeUtils.SHORT_TYPE; import static org.apache.fory.type.TypeUtils.getRawType; +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -65,6 +67,7 @@ import org.apache.fory.logging.LoggerFactory; import org.apache.fory.meta.TypeDef; import org.apache.fory.platform.JdkVersion; +import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.reflect.TypeRef; import org.apache.fory.serializer.ObjectSerializer; import org.apache.fory.type.BFloat16; @@ -94,6 +97,8 @@ */ public class ObjectCodecBuilder extends BaseObjectCodecBuilder { private static final Logger LOG = LoggerFactory.getLogger(ObjectCodecBuilder.class); + private static final int OBJECT_SELF_BYTES = 1; + private static final int REFERENCE_BYTES = 4; private final Literal classVersionHash; protected ObjectCodecOptimizer objectCodecOptimizer; @@ -793,7 +798,7 @@ private Expression getWriterPos(Expression writerPos, long acc) { public Expression buildDecodeExpression() { Reference buffer = new Reference(BUFFER_NAME, bufferTypeRef, false); ListExpression expressions = new ListExpression(); - expressions.add(new Expression.Block("reserveObjectGraphMemory(" + READ_CONTEXT_NAME + ");")); + expressions.add(new Expression.Block(graphMemoryReserveCode())); if (typeResolver.checkClassVersion()) { expressions.add(checkClassVersion(buffer)); } @@ -833,6 +838,39 @@ public Expression buildDecodeExpression() { return expressions; } + protected String graphMemoryReserveCode() { + return READ_CONTEXT_NAME + ".reserveGraphMemory(" + objectGraphMemoryBytes() + ");"; + } + + private int objectGraphMemoryBytes() { + int bytes = OBJECT_SELF_BYTES; + for (Field field : ReflectionUtils.getFields(beanClass, true)) { + if (!Modifier.isStatic(field.getModifiers())) { + bytes = Math.addExact(bytes, fieldGraphMemoryBytes(field.getType())); + } + } + return bytes; + } + + private int fieldGraphMemoryBytes(Class fieldType) { + if (!fieldType.isPrimitive()) { + return REFERENCE_BYTES; + } + if (fieldType == boolean.class || fieldType == byte.class) { + return 1; + } + if (fieldType == char.class || fieldType == short.class) { + return 2; + } + if (fieldType == int.class || fieldType == float.class) { + return 4; + } + if (fieldType == long.class || fieldType == double.class) { + return 8; + } + return 0; + } + protected void deserializeReadGroup( List> readGroups, int numGroups, diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java index 2feeea6daf..9b699d4cdb 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/StaticCompatibleCodecBuilder.java @@ -160,7 +160,7 @@ private String genObjectCompatibleRead() { ? "((" + ctx.type(beanClass) + ") " + beanCode.value() + ")" : beanCode.value().toString(); StringBuilder code = new StringBuilder(); - code.append("reserveObjectGraphMemory(").append(READ_CONTEXT_NAME).append(");\n"); + code.append(graphMemoryReserveCode()).append('\n'); if (StringUtils.isNotBlank(beanCode.code())) { code.append(beanCode.code()).append('\n'); } @@ -190,7 +190,7 @@ private String genObjectCompatibleRead() { private String genRecordCompatibleRead() { RecordComponent[] components = RecordUtils.getRecordComponents(beanClass); StringBuilder code = new StringBuilder(); - code.append("reserveObjectGraphMemory(").append(READ_CONTEXT_NAME).append(");\n"); + code.append(graphMemoryReserveCode()).append('\n'); for (int i = 0; i < components.length; i++) { Class componentType = components[i].getType(); code.append(recordLocalType(componentType)) diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index d3a1f0f0d9..cdb06122b6 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -317,23 +317,29 @@ public Config getConfig() { return config; } - public void reserveGraphMemory(long bytes) { - if (bytes < 0) { - throwNegativeGraphMemory(bytes); - } - long remaining = remainingGraphMemoryBytes; - if (bytes > remaining) { - throwGraphMemoryExceeded(bytes, remaining); + // Failure may leave the remaining counter dirty; root cleanup resets read state for the + // operation. + public final void reserveGraphMemory(long bytes) { + long remaining = remainingGraphMemoryBytes - bytes; + remainingGraphMemoryBytes = remaining; + if ((bytes | remaining) < 0) { + throwInvalidGraphMemory(bytes, remaining + bytes); } - remainingGraphMemoryBytes = remaining - bytes; } - private void throwNegativeGraphMemory(long bytes) { - throw new InsecureException( - "Estimated graph memory must be non-negative, but got " + bytes + " bytes."); + public final void reserveGraphMemory(int bytes) { + long remaining = remainingGraphMemoryBytes - bytes; + remainingGraphMemoryBytes = remaining; + if ((bytes | remaining) < 0) { + throwInvalidGraphMemory(bytes, remaining + bytes); + } } - private void throwGraphMemoryExceeded(long bytes, long remaining) { + private void throwInvalidGraphMemory(long bytes, long remaining) { + if (bytes < 0) { + throw new InsecureException( + "Estimated graph memory must be non-negative, but got " + bytes + " bytes."); + } throw new InsecureException( "Estimated graph memory request " + bytes @@ -390,8 +396,10 @@ public boolean hasPreservedRefId() { } /** Binds the most recently preserved read ref id to {@code object}. */ - public void reference(Object object) { - refReader.reference(object); + public final void reference(Object object) { + if (trackingRef) { + refReader.reference(object); + } } /** Returns a previously read object by ref id. */ diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java index 49d91ab657..f147288f4e 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java @@ -81,7 +81,7 @@ public abstract class AbstractObjectSerializer extends Serializer { protected final TypeResolver typeResolver; protected final boolean isRecord; protected final ObjectInstantiator objectInstantiator; - private final long objectGraphMemoryBytes; + protected final int objectGraphMemoryBytes; private SerializationFieldInfo[] fieldInfos; private RecordInfo copyRecordInfo; @@ -108,16 +108,12 @@ public AbstractObjectSerializer( this.objectGraphMemoryBytes = computeObjectGraphMemoryBytes(type); } - protected final void reserveObjectGraphMemory(ReadContext readContext) { - readContext.reserveGraphMemory(objectGraphMemoryBytes); - } - - static long computeObjectGraphMemoryBytes(Class type) { + static int computeObjectGraphMemoryBytes(Class type) { // One byte is a stable nonzero self cost, not an attempt to model JVM object headers. - long bytes = OBJECT_SELF_BYTES; + int bytes = OBJECT_SELF_BYTES; for (Field field : ReflectionUtils.getFields(type, true)) { if (!Modifier.isStatic(field.getModifiers())) { - bytes += fieldGraphMemoryBytes(field.getType()); + bytes = Math.addExact(bytes, fieldGraphMemoryBytes(field.getType())); } } return bytes; diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleLayerSerializerBase.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleLayerSerializerBase.java index 68aeeaf94b..67fab5f10a 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleLayerSerializerBase.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleLayerSerializerBase.java @@ -135,7 +135,7 @@ public Object[] readFieldValues(ReadContext readContext) { @Override public T read(ReadContext readContext) { checkLayerSerializerMeta(); - reserveObjectGraphMemory(readContext); + readContext.reserveGraphMemory(objectGraphMemoryBytes); T obj = newBean(); readContext.reference(obj); return readAndSetFields(readContext, obj); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleSerializer.java index c7c024e451..1af0432dcc 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleSerializer.java @@ -237,7 +237,7 @@ private T newInstance() { @Override public T read(ReadContext readContext) { - reserveObjectGraphMemory(readContext); + readContext.reserveGraphMemory(objectGraphMemoryBytes); if (isRecord) { Object[] fieldValues = new Object[allFields.length]; if (hasCompatibleCollectionArrayRead) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java index d9d348966e..345b40ee9f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java @@ -62,7 +62,7 @@ public static final class ExceptionSerializer extends Seria private final TypeResolver typeResolver; private final ObjectInstantiator objectInstantiator; private final Constructor messageConstructor; - private final long graphMemoryBytes; + private final int graphMemoryBytes; private volatile Serializer[] slotsSerializers; private volatile boolean rebuildSlotsSerializersAtRuntime; diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java index 9e50544860..6a43affa64 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java @@ -208,7 +208,7 @@ private void writeFieldByCodecCategory( @Override public T read(ReadContext readContext) { - reserveObjectGraphMemory(readContext); + readContext.reserveGraphMemory(objectGraphMemoryBytes); MemoryBuffer buffer = readContext.getBuffer(); if (isRecord) { Object[] fields = readFields(readContext); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java index 4ea7532e61..8f310efc7f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java @@ -269,7 +269,7 @@ public void write(WriteContext writeContext, Object value) { @Override public Object read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); - reserveObjectGraphMemory(readContext); + readContext.reserveGraphMemory(objectGraphMemoryBytes); Object obj = objectInstantiator.newInstance(); readContext.reference(obj); int numClasses = buffer.readInt16(); diff --git a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt index 827e4bec27..f582200185 100644 --- a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt +++ b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/ForyKotlinSymbolProcessor.kt @@ -249,11 +249,32 @@ internal class ForyKotlinSymbolProcessor(private val environment: SymbolProcesso KotlinSerializerVisibility.PUBLIC }, construction = parsed.construction, + graphMemoryBytes = graphMemoryBytes(fields), fields = fields, originatingFiles = listOfNotNull(declaration.containingFile), ) } + private fun graphMemoryBytes(fields: List): Int { + var bytes = 1 + for (field in fields) { + bytes = Math.addExact(bytes, fieldGraphMemoryBytes(field.type)) + } + return bytes + } + + private fun fieldGraphMemoryBytes(type: KotlinSourceTypeNode): Int = + when (type.typeName) { + "boolean", + "byte" -> 1 + "short" -> 2 + "int", + "float" -> 4 + "long", + "double" -> 8 + else -> 4 + } + private fun parseStructFields( declaration: KSClassDeclaration, primaryConstructor: KSFunctionDeclaration, diff --git a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt index f93c17b860..08c1ff0799 100644 --- a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt +++ b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt @@ -386,12 +386,20 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru writeConstructorRead() } + private fun appendGraphMemoryReserve(indent: String) { + builder + .append(indent) + .append("readContext.reserveGraphMemory(") + .append(struct.graphMemoryBytes) + .append(")\n") + } + private fun writeConstructorRead() { builder .append(" private fun readSchemaConstructor(readContext: ReadContext): ") .append(struct.typeName) .append(" {\n") - builder.append(" reserveObjectGraphMemory(readContext)\n") + appendGraphMemoryReserve(" ") builder.append(" val fieldValues = arrayOfNulls(DESCRIPTORS.size)\n") builder.append(" val bufferedFields = newFieldBits(DESCRIPTORS.size)\n") builder.append(" beginConstructorRef(readContext)\n") @@ -655,7 +663,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru } private fun writeMutableReadBody() { - builder.append(" reserveObjectGraphMemory(readContext)\n") + appendGraphMemoryReserve(" ") builder.append(" val value = ").append(struct.typeName).append("()\n") builder.append(" if (readContext.hasPreservedRefId()) {\n") builder.append(" readContext.reference(value)\n") @@ -702,7 +710,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru builder.append(" }\n\n") return } - builder.append(" reserveObjectGraphMemory(readContext)\n") + appendGraphMemoryReserve(" ") writeCompatibleValueReadBody(" ", constructorRefs = false) builder.append(" }\n\n") } @@ -712,7 +720,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru .append(" private fun readCompatibleConstructor(readContext: ReadContext): ") .append(struct.typeName) .append(" {\n") - builder.append(" reserveObjectGraphMemory(readContext)\n") + appendGraphMemoryReserve(" ") builder.append(" beginConstructorRef(readContext)\n") builder.append(" try {\n") writeCompatibleValueReadBody(" ", constructorRefs = true) @@ -833,7 +841,7 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru private fun writeMutableCompatibleReadBody() { writePresenceVars() - builder.append(" reserveObjectGraphMemory(readContext)\n") + appendGraphMemoryReserve(" ") builder.append(" val value = ").append(struct.typeName).append("()\n") builder.append(" if (readContext.hasPreservedRefId()) {\n") builder.append(" readContext.reference(value)\n") diff --git a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/Model.kt b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/Model.kt index a50b0ca62f..547f4e7dc6 100644 --- a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/Model.kt +++ b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/Model.kt @@ -28,6 +28,7 @@ internal data class KotlinSourceStruct( val construction: KotlinStructConstruction = KotlinStructConstruction.CONSTRUCTOR, val fields: List, val originatingFiles: List, + val graphMemoryBytes: Int = 1, ) { val qualifiedSerializerName: String = if (packageName.isEmpty()) serializerName else "$packageName.$serializerName" diff --git a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/GenericDataClassTest.kt b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/GenericDataClassTest.kt index 4172768b3e..47313bc953 100644 --- a/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/GenericDataClassTest.kt +++ b/kotlin/fory-kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/GenericDataClassTest.kt @@ -113,7 +113,11 @@ class GenericDataClassTest { assertEquals(typeInfo.decodeTypeName(), "Change") assertThrows(IllegalArgumentException::class.java) { KotlinSerializers.registerType( - Fory.builder().withXlang(false).requireClassRegistration(true).withCompatible(false).build(), + Fory.builder() + .withXlang(false) + .requireClassRegistration(true) + .withCompatible(false) + .build(), Change::class.java, "kotlin", "Bad.Name", diff --git a/swift/Sources/Fory/AnySerializer.swift b/swift/Sources/Fory/AnySerializer.swift index 4d2692d651..0b582a0d86 100644 --- a/swift/Sources/Fory/AnySerializer.swift +++ b/swift/Sources/Fory/AnySerializer.swift @@ -653,3 +653,63 @@ public func readMapAnyHashableToAny( ) throws -> [AnyHashable: Any]? { try context.readMapAnyHashableToAny(refMode: refMode, readTypeInfo: readTypeInfo) } + +private let anyMapReferenceBytes = 4 + +@inline(never) +private func throwAnyMapGraphMemoryOverflow() throws -> Never { + throw ForyError.invalidData("graph memory estimate overflows") +} + +@inline(__always) +private func reserveAnyMapMemory( + _ context: ReadContext, _ type: Map.Type, count: Int +) throws { + let (slotBytes, overflow) = count.multipliedReportingOverflow( + by: 2 * anyMapReferenceBytes) + if overflow { + try throwAnyMapGraphMemoryOverflow() + } + let ownerBytes = max(1, MemoryLayout.stride) + let (bytes, addOverflow) = ownerBytes.addingReportingOverflow(slotBytes) + if addOverflow { + try throwAnyMapGraphMemoryOverflow() + } + try context.reserveGraphMemory(bytes) +} + +func readDynamicAnyMapValue(context: ReadContext) throws -> Any { + let map = try context.readMapAnyHashableToAny(refMode: .none) ?? [:] + if map.isEmpty { + try reserveAnyMapMemory(context, [String: Any].self, count: 0) + return [String: Any]() + } + try reserveAnyMapMemory(context, [String: Any].self, count: map.count) + var stringMap: [String: Any] = [:] + stringMap.reserveCapacity(map.count) + for pair in map { + guard let key = pair.key.base as? String else { + stringMap.removeAll(keepingCapacity: false) + break + } + stringMap[key] = pair.value + } + if stringMap.count == map.count { + return stringMap + } + + try reserveAnyMapMemory(context, [Int32: Any].self, count: map.count) + var int32Map: [Int32: Any] = [:] + int32Map.reserveCapacity(map.count) + for pair in map { + guard let key = pair.key.base as? Int32 else { + return map + } + int32Map[key] = pair.value + } + if int32Map.count == map.count { + return int32Map + } + + return map +} diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index eb48b55b60..50f4aca171 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -866,39 +866,3 @@ extension ReadContext { return map } } - -private func readDynamicAnyMapValue(context: ReadContext) throws -> Any { - let map = try context.readMapAnyHashableToAny(refMode: .none) ?? [:] - if map.isEmpty { - try reserveAnyReferenceMapMemory(context, [String: Any].self, count: 0) - return [String: Any]() - } - try reserveAnyReferenceMapMemory(context, [String: Any].self, count: map.count) - var stringMap: [String: Any] = [:] - stringMap.reserveCapacity(map.count) - for pair in map { - guard let key = pair.key.base as? String else { - stringMap.removeAll(keepingCapacity: false) - break - } - stringMap[key] = pair.value - } - if stringMap.count == map.count { - return stringMap - } - - try reserveAnyReferenceMapMemory(context, [Int32: Any].self, count: map.count) - var int32Map: [Int32: Any] = [:] - int32Map.reserveCapacity(map.count) - for pair in map { - guard let key = pair.key.base as? Int32 else { - return map - } - int32Map[key] = pair.value - } - if int32Map.count == map.count { - return int32Map - } - - return map -} From b7f1a25faf27a963374d9447aaf7231bd42f708d Mon Sep 17 00:00:00 2001 From: chaokunyang Date: Fri, 3 Jul 2026 18:48:48 +0800 Subject: [PATCH 54/54] perf(java): trim graph budget read overhead --- .../processing/StaticSerializerSourceWriter.java | 8 ++------ .../java/org/apache/fory/context/ReadContext.java | 9 +++++++++ .../apache/fory/serializer/ObjectSerializer.java | 14 ++++++++++---- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java index 4c1eab4f0e..fb61507c63 100644 --- a/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java +++ b/java/fory-annotation-processor/src/main/java/org/apache/fory/annotation/processing/StaticSerializerSourceWriter.java @@ -272,9 +272,7 @@ private void writeSchemaConsistentRead() { .append(" value = ") .append(newGeneratedBeanExpression()) .append(";\n"); - builder.append(" if (needToWriteRef) {\n"); - builder.append(" readContext.reference(value);\n"); - builder.append(" }\n"); + builder.append(" readContext.reference(value);\n"); builder.append(" readFields(readContext, value);\n"); builder.append(" return value;\n"); } @@ -844,9 +842,7 @@ private void writeCompatibleRead() { .append(" value = ") .append(newGeneratedBeanExpression()) .append(";\n"); - builder.append(" if (needToWriteRef) {\n"); - builder.append(" readContext.reference(value);\n"); - builder.append(" }\n"); + builder.append(" readContext.reference(value);\n"); builder.append(" for (int i = 0; i < remoteFields.size(); i++) {\n"); builder.append(" RemoteFieldInfo remoteField = remoteFields.get(i);\n"); builder.append(" readCompatibleField(readContext, value, remoteField);\n"); diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index cdb06122b6..1df6fead40 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -648,6 +648,15 @@ public Object readNonRef(TypeInfoHolder classInfoHolder) { /** Variant of {@link #readNonRef()} that uses already resolved {@link TypeInfo}. */ public Object readNonRef(TypeInfo typeInfo) { + int typeId = typeInfo.getTypeId(); + // User-defined xlang type IDs are contiguous; skip the primitive/string switch on this hot + // path. + if (typeId >= Types.ENUM && typeId <= Types.NAMED_UNION) { + increaseDepth(); + Object read = typeInfo.getSerializer().read(this); + decreaseDepth(); + return read; + } return readDataInternal(typeInfo); } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java index 6a43affa64..e3f8300398 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectSerializer.java @@ -65,6 +65,8 @@ public final class ObjectSerializer extends AbstractObjectSerializer { private final RecordInfo recordInfo; private final SerializationFieldInfo[] allFields; private final int classVersionHash; + private final boolean trackingRef; + private final boolean checkClassVersion; public ObjectSerializer(TypeResolver typeResolver, Class cls) { this(typeResolver, cls, true); @@ -80,6 +82,8 @@ public ObjectSerializer( boolean resolveParent, ObjectInstantiator objectInstantiator) { super(typeResolver, cls, objectInstantiator); + trackingRef = config.trackingRef(); + checkClassVersion = typeResolver.checkClassVersion(); // avoid recursive building serializers. // Use `setSerializerIfAbsent` to avoid overwriting existing serializer for class when used // as data serializer. @@ -208,7 +212,7 @@ private void writeFieldByCodecCategory( @Override public T read(ReadContext readContext) { - readContext.reserveGraphMemory(objectGraphMemoryBytes); + readContext.reserveGraphMemory((long) objectGraphMemoryBytes); MemoryBuffer buffer = readContext.getBuffer(); if (isRecord) { Object[] fields = readFields(readContext); @@ -218,14 +222,16 @@ public T read(ReadContext readContext) { return obj; } T obj = newBean(); - readContext.reference(obj); + if (trackingRef) { + readContext.reference(obj); + } return readAndSetFields(readContext, obj); } public Object[] readFields(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); RefReader refReader = readContext.getRefReader(); - if (typeResolver.checkClassVersion()) { + if (checkClassVersion) { int hash = buffer.readInt32(); checkClassVersion(type, hash, classVersionHash); } @@ -242,7 +248,7 @@ public Object[] readFields(ReadContext readContext) { public T readAndSetFields(ReadContext readContext, T obj) { MemoryBuffer buffer = readContext.getBuffer(); RefReader refReader = readContext.getRefReader(); - if (typeResolver.checkClassVersion()) { + if (checkClassVersion) { int hash = buffer.readInt32(); checkClassVersion(type, hash, classVersionHash); }