From fe80c0647f7eb32e8ad9f1b8231dd41247714428 Mon Sep 17 00:00:00 2001 From: Zakir Date: Thu, 26 Feb 2026 01:36:01 +0530 Subject: [PATCH 1/2] feat(go): add MaxStringBytes/MaxCollectionSize/MaxMapSize to Config Add three opt-in guardrail fields to Config and three corresponding WithXxx option functions. All default to 0 (unlimited), preserving full backward compatibility. Enforcement in read paths follows in subsequent commits. --- go/fory/fory.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/go/fory/fory.go b/go/fory/fory.go index 09a0e3c6d2..797ed7c7db 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -54,6 +54,9 @@ type Config struct { MaxDepth int IsXlang bool Compatible bool // Schema evolution compatibility mode + MaxStringBytes int + MaxCollectionSize int + MaxMapSize int } // defaultConfig returns the default configuration @@ -101,6 +104,31 @@ func WithCompatible(enabled bool) Option { } } +// WithMaxStringBytes sets the maximum allowed byte length for a single +// deserialized string. 0 (default) means no limit. +func WithMaxStringBytes(n int) Option { + return func(f *Fory) { + f.config.MaxStringBytes = n + } +} + +// WithMaxCollectionSize sets the maximum allowed element count for a single +// deserialized slice or list. 0 (default) means no limit. +func WithMaxCollectionSize(n int) Option { + return func(f *Fory) { + f.config.MaxCollectionSize = n + } +} + +// WithMaxMapSize sets the maximum allowed entry count for a single +// deserialized map. 0 (default) means no limit. +func WithMaxMapSize(n int) Option { + return func(f *Fory) { + f.config.MaxMapSize = n + } +} + + // ============================================================================ // Fory - Main serialization instance // ============================================================================ From fc05b92633453bfbf352155a58c17820cf633b2a Mon Sep 17 00:00:00 2001 From: Zakir Date: Thu, 26 Feb 2026 02:53:18 +0530 Subject: [PATCH 2/2] complete changes + test file --- go/fory/fory.go | 4 + go/fory/limits_test.go | 238 +++++++++++++++++++++++++++++++++++++ go/fory/map.go | 1 + go/fory/map_primitive.go | 94 +++++++++++---- go/fory/reader.go | 30 ++++- go/fory/slice.go | 2 + go/fory/slice_primitive.go | 4 +- go/fory/string.go | 12 +- 8 files changed, 359 insertions(+), 26 deletions(-) create mode 100644 go/fory/limits_test.go diff --git a/go/fory/fory.go b/go/fory/fory.go index 797ed7c7db..efb63ae783 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -184,6 +184,10 @@ func New(opts ...Option) *Fory { f.readCtx.refResolver = f.refResolver f.readCtx.compatible = f.config.Compatible f.readCtx.xlang = f.config.IsXlang + f.readCtx.maxStringBytes = f.config.MaxStringBytes + f.readCtx.maxCollectionSize = f.config.MaxCollectionSize + f.readCtx.maxMapSize = f.config.MaxMapSize + return f } diff --git a/go/fory/limits_test.go b/go/fory/limits_test.go new file mode 100644 index 0000000000..a4214431c3 --- /dev/null +++ b/go/fory/limits_test.go @@ -0,0 +1,238 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// ============================================================================ +// MaxStringBytes +// ============================================================================ + +func TestMaxStringBytesBlocksOversizedString(t *testing.T) { + f := NewFory(WithXlang(true), WithMaxStringBytes(5)) + + long := strings.Repeat("a", 20) // 20 bytes > limit 5 + data, err := f.Marshal(long) + require.NoError(t, err) // write path has no limit + + var result string + err = f.Unmarshal(data, &result) + require.Error(t, err, "expected error: string exceeds MaxStringBytes") +} + +func TestMaxStringBytesAllowsExactLimit(t *testing.T) { + f := NewFory(WithXlang(true), WithMaxStringBytes(5)) + + s := "hello" // exactly 5 bytes — must NOT be rejected (> not >=) + data, err := f.Marshal(s) + require.NoError(t, err) + + var result string + require.NoError(t, f.Unmarshal(data, &result)) + require.Equal(t, s, result) +} + +func TestMaxStringBytesAllowsWithinLimit(t *testing.T) { + f := NewFory(WithXlang(true), WithMaxStringBytes(10)) + + s := "hi" // 2 bytes, well within limit + data, err := f.Marshal(s) + require.NoError(t, err) + + var result string + require.NoError(t, f.Unmarshal(data, &result)) + require.Equal(t, s, result) +} + +func TestMaxStringBytesZeroMeansNoLimit(t *testing.T) { + f := NewFory(WithXlang(true)) // default 0 = no limit + + long := strings.Repeat("x", 100_000) + data, err := f.Marshal(long) + require.NoError(t, err) + + var result string + require.NoError(t, f.Unmarshal(data, &result)) + require.Equal(t, long, result) +} + +// ============================================================================ +// MaxCollectionSize +// ============================================================================ + +func TestMaxCollectionSizeBlocksOversizedSlice(t *testing.T) { + f := NewFory(WithXlang(true), WithMaxCollectionSize(3)) + + s := []string{"a", "b", "c", "d", "e"} // 5 elements > limit 3 + data, err := f.Marshal(s) + require.NoError(t, err) + + var result []string + err = f.Unmarshal(data, &result) + require.Error(t, err, "expected error: slice exceeds MaxCollectionSize") +} + +func TestMaxCollectionSizeAllowsWithinLimit(t *testing.T) { + f := NewFory(WithXlang(true), WithMaxCollectionSize(5)) + + s := []string{"a", "b", "c"} // 3 elements, within limit 5 + data, err := f.Marshal(s) + require.NoError(t, err) + + var result []string + require.NoError(t, f.Unmarshal(data, &result)) + require.Equal(t, s, result) +} + +func TestMaxCollectionSizeAllowsExactLimit(t *testing.T) { + f := NewFory(WithXlang(true), WithMaxCollectionSize(3)) + + s := []string{"a", "b", "c"} // exactly 3 — must NOT be rejected + data, err := f.Marshal(s) + require.NoError(t, err) + + var result []string + require.NoError(t, f.Unmarshal(data, &result)) + require.Equal(t, s, result) +} + +func TestMaxCollectionSizeZeroMeansNoLimit(t *testing.T) { + f := NewFory(WithXlang(true)) // default 0 = no limit + + s := make([]int32, 10_000) + data, err := f.Marshal(s) + require.NoError(t, err) + + var result []int32 + require.NoError(t, f.Unmarshal(data, &result)) + require.Equal(t, len(s), len(result)) +} + +// ============================================================================ +// MaxMapSize +// ============================================================================ + +func TestMaxMapSizeBlocksOversizedMap(t *testing.T) { + f := NewFory(WithXlang(true), WithMaxMapSize(2)) + + m := map[string]string{"k1": "v1", "k2": "v2", "k3": "v3"} // 3 entries > limit 2 + data, err := f.Marshal(m) + require.NoError(t, err) + + var result map[string]string + err = f.Unmarshal(data, &result) + require.Error(t, err, "expected error: map exceeds MaxMapSize") +} + +func TestMaxMapSizeAllowsWithinLimit(t *testing.T) { + f := NewFory(WithXlang(true), WithMaxMapSize(5)) + + m := map[string]string{"k1": "v1", "k2": "v2"} + data, err := f.Marshal(m) + require.NoError(t, err) + + var result map[string]string + require.NoError(t, f.Unmarshal(data, &result)) + require.Equal(t, m, result) +} + +func TestMaxMapSizeAllowsExactLimit(t *testing.T) { + f := NewFory(WithXlang(true), WithMaxMapSize(2)) + + m := map[string]string{"k1": "v1", "k2": "v2"} // exactly 2 — must NOT be rejected + data, err := f.Marshal(m) + require.NoError(t, err) + + var result map[string]string + require.NoError(t, f.Unmarshal(data, &result)) + require.Equal(t, m, result) +} + +func TestMaxMapSizeZeroMeansNoLimit(t *testing.T) { + f := NewFory(WithXlang(true)) // default 0 = no limit + + m := make(map[string]string, 1000) + for i := 0; i < 1000; i++ { + m[fmt.Sprintf("k%d", i)] = "v" + } + data, err := f.Marshal(m) + require.NoError(t, err) + + var result map[string]string + require.NoError(t, f.Unmarshal(data, &result)) + require.Equal(t, 1000, len(result)) +} + +// ============================================================================ +// Combined limits +// ============================================================================ + +func TestCombinedLimitsStringInsideSlice(t *testing.T) { + // Slice size is within limit, but one element string is too long + f := NewFory(WithXlang(true), WithMaxCollectionSize(10), WithMaxStringBytes(3)) + + s := []string{"ab", "cd", "this-is-too-long"} // third element 16 bytes > limit 3 + data, err := f.Marshal(s) + require.NoError(t, err) + + var result []string + err = f.Unmarshal(data, &result) + require.Error(t, err) +} + +func TestCombinedLimitsCollectionFiresBeforeString(t *testing.T) { + // Collection limit fires before any string element is read + f := NewFory(WithXlang(true), WithMaxCollectionSize(2), WithMaxStringBytes(1000)) + + s := []string{"a", "b", "c", "d"} // 4 elements > collection limit 2 + data, err := f.Marshal(s) + require.NoError(t, err) + + var result []string + err = f.Unmarshal(data, &result) + require.Error(t, err) +} + +func TestCombinedLimitsAllWithinBounds(t *testing.T) { + // All limits set, all values within bounds — must succeed end-to-end + f := NewFory(WithXlang(true), + WithMaxStringBytes(20), + WithMaxCollectionSize(10), + WithMaxMapSize(10), + ) + + s := []string{"hello", "world"} + data, err := f.Marshal(s) + require.NoError(t, err) + var sliceResult []string + require.NoError(t, f.Unmarshal(data, &sliceResult)) + require.Equal(t, s, sliceResult) + + m := map[string]string{"k1": "v1", "k2": "v2"} + data, err = f.Marshal(m) + require.NoError(t, err) + var mapResult map[string]string + require.NoError(t, f.Unmarshal(data, &mapResult)) + require.Equal(t, m, mapResult) +} diff --git a/go/fory/map.go b/go/fory/map.go index f2489601f3..91dc0bcd24 100644 --- a/go/fory/map.go +++ b/go/fory/map.go @@ -306,6 +306,7 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { refResolver.Reference(value) size := int(buf.ReadVarUint32(ctxErr)) + ctx.checkMapSize(size) if size == 0 || ctx.HasError() { return } diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go index 21a4bd7b5d..0cfbbd076f 100644 --- a/go/fory/map_primitive.go +++ b/go/fory/map_primitive.go @@ -69,8 +69,13 @@ func writeMapStringString(buf *ByteBuffer, m map[string]string, hasGenerics bool } // readMapStringString reads map[string]string using chunk protocol -func readMapStringString(buf *ByteBuffer, err *Error) map[string]string { +func readMapStringString(buf *ByteBuffer, err *Error, maxStringBytes, maxMapSize int) map[string]string { size := int(buf.ReadVarUint32(err)) +if maxMapSize > 0 && size > maxMapSize { + err.SetError(DeserializationErrorf( + "fory: map size %d exceeds limit %d", size, maxMapSize)) + return nil +} result := make(map[string]string, size) if size == 0 { return result @@ -94,7 +99,8 @@ func readMapStringString(buf *ByteBuffer, err *Error) map[string]string { if !valueDeclared { buf.ReadUint8(err) // skip value type } - v := readString(buf, err) + v := readString(buf, err, maxStringBytes) + result[""] = v // empty string as null key size-- continue @@ -104,7 +110,8 @@ func readMapStringString(buf *ByteBuffer, err *Error) map[string]string { if !keyDeclared { buf.ReadUint8(err) // skip key type } - k := readString(buf, err) + k := readString(buf, err, maxStringBytes) + result[k] = "" // empty string as null value size-- continue @@ -123,8 +130,10 @@ func readMapStringString(buf *ByteBuffer, err *Error) map[string]string { // ReadData chunk entries for i := 0; i < chunkSize && size > 0; i++ { - k := readString(buf, err) - v := readString(buf, err) + k := readString(buf, err, maxStringBytes) + + v := readString(buf, err, maxStringBytes) + result[k] = v size-- } @@ -172,8 +181,13 @@ func writeMapStringInt64(buf *ByteBuffer, m map[string]int64, hasGenerics bool) } // readMapStringInt64 reads map[string]int64 using chunk protocol -func readMapStringInt64(buf *ByteBuffer, err *Error) map[string]int64 { +func readMapStringInt64(buf *ByteBuffer, err *Error, maxStringBytes, maxMapSize int) map[string]int64 { size := int(buf.ReadVarUint32(err)) +if maxMapSize > 0 && size > maxMapSize { + err.SetError(DeserializationErrorf( + "fory: map size %d exceeds limit %d", size, maxMapSize)) + return nil +} result := make(map[string]int64, size) if size == 0 { return result @@ -197,7 +211,8 @@ func readMapStringInt64(buf *ByteBuffer, err *Error) map[string]int64 { buf.ReadUint8(err) } for i := 0; i < chunkSize && size > 0; i++ { - k := readString(buf, err) + k := readString(buf, err, maxStringBytes) + v := buf.ReadVarint64(err) result[k] = v size-- @@ -246,8 +261,13 @@ func writeMapStringInt32(buf *ByteBuffer, m map[string]int32, hasGenerics bool) } // readMapStringInt32 reads map[string]int32 using chunk protocol -func readMapStringInt32(buf *ByteBuffer, err *Error) map[string]int32 { +func readMapStringInt32(buf *ByteBuffer, err *Error, maxStringBytes, maxMapSize int) map[string]int32 { size := int(buf.ReadVarUint32(err)) +if maxMapSize > 0 && size > maxMapSize { + err.SetError(DeserializationErrorf( + "fory: map size %d exceeds limit %d", size, maxMapSize)) + return nil +} result := make(map[string]int32, size) if size == 0 { return result @@ -271,7 +291,8 @@ func readMapStringInt32(buf *ByteBuffer, err *Error) map[string]int32 { buf.ReadUint8(err) } for i := 0; i < chunkSize && size > 0; i++ { - k := readString(buf, err) + k := readString(buf, err, maxStringBytes) + v := buf.ReadVarint32(err) result[k] = v size-- @@ -320,8 +341,13 @@ func writeMapStringInt(buf *ByteBuffer, m map[string]int, hasGenerics bool) { } // readMapStringInt reads map[string]int using chunk protocol -func readMapStringInt(buf *ByteBuffer, err *Error) map[string]int { +func readMapStringInt(buf *ByteBuffer, err *Error, maxStringBytes, maxMapSize int) map[string]int { size := int(buf.ReadVarUint32(err)) +if maxMapSize > 0 && size > maxMapSize { + err.SetError(DeserializationErrorf( + "fory: map size %d exceeds limit %d", size, maxMapSize)) + return nil +} result := make(map[string]int, size) if size == 0 { return result @@ -345,7 +371,8 @@ func readMapStringInt(buf *ByteBuffer, err *Error) map[string]int { buf.ReadUint8(err) } for i := 0; i < chunkSize && size > 0; i++ { - k := readString(buf, err) + k := readString(buf, err, maxStringBytes) + v := buf.ReadVarint64(err) result[k] = int(v) size-- @@ -394,8 +421,13 @@ func writeMapStringFloat64(buf *ByteBuffer, m map[string]float64, hasGenerics bo } // readMapStringFloat64 reads map[string]float64 using chunk protocol -func readMapStringFloat64(buf *ByteBuffer, err *Error) map[string]float64 { +func readMapStringFloat64(buf *ByteBuffer, err *Error, maxStringBytes, maxMapSize int) map[string]float64 { size := int(buf.ReadVarUint32(err)) +if maxMapSize > 0 && size > maxMapSize { + err.SetError(DeserializationErrorf( + "fory: map size %d exceeds limit %d", size, maxMapSize)) + return nil +} result := make(map[string]float64, size) if size == 0 { return result @@ -419,7 +451,8 @@ func readMapStringFloat64(buf *ByteBuffer, err *Error) map[string]float64 { buf.ReadUint8(err) } for i := 0; i < chunkSize && size > 0; i++ { - k := readString(buf, err) + k := readString(buf, err, maxStringBytes) + v := buf.ReadFloat64(err) result[k] = v size-- @@ -468,8 +501,13 @@ func writeMapStringBool(buf *ByteBuffer, m map[string]bool, hasGenerics bool) { } // readMapStringBool reads map[string]bool using chunk protocol -func readMapStringBool(buf *ByteBuffer, err *Error) map[string]bool { +func readMapStringBool(buf *ByteBuffer, err *Error, maxStringBytes, maxMapSize int) map[string]bool { size := int(buf.ReadVarUint32(err)) + if maxMapSize > 0 && size > maxMapSize { + err.SetError(DeserializationErrorf( + "fory: map size %d exceeds limit %d", size, maxMapSize)) + return nil + } result := make(map[string]bool, size) if size == 0 { return result @@ -498,7 +536,8 @@ func readMapStringBool(buf *ByteBuffer, err *Error) map[string]bool { } for i := 0; i < chunkSize && size > 0; i++ { - k := readString(buf, err) + k := readString(buf, err, maxStringBytes) + v := buf.ReadBool(err) result[k] = v size-- @@ -549,6 +588,11 @@ func writeMapInt32Int32(buf *ByteBuffer, m map[int32]int32, hasGenerics bool) { // readMapInt32Int32 reads map[int32]int32 using chunk protocol func readMapInt32Int32(buf *ByteBuffer, err *Error) map[int32]int32 { size := int(buf.ReadVarUint32(err)) +if maxMapSize > 0 && size > maxMapSize { + err.SetError(DeserializationErrorf( + "fory: map size %d exceeds limit %d", size, maxMapSize)) + return nil +} result := make(map[int32]int32, size) if size == 0 { return result @@ -623,6 +667,11 @@ func writeMapInt64Int64(buf *ByteBuffer, m map[int64]int64, hasGenerics bool) { // readMapInt64Int64 reads map[int64]int64 using chunk protocol func readMapInt64Int64(buf *ByteBuffer, err *Error) map[int64]int64 { size := int(buf.ReadVarUint32(err)) +if maxMapSize > 0 && size > maxMapSize { + err.SetError(DeserializationErrorf( + "fory: map size %d exceeds limit %d", size, maxMapSize)) + return nil +} result := make(map[int64]int64, size) if size == 0 { return result @@ -697,6 +746,11 @@ func writeMapIntInt(buf *ByteBuffer, m map[int]int, hasGenerics bool) { // readMapIntInt reads map[int]int using chunk protocol func readMapIntInt(buf *ByteBuffer, err *Error) map[int]int { size := int(buf.ReadVarUint32(err)) +if maxMapSize > 0 && size > maxMapSize { + err.SetError(DeserializationErrorf( + "fory: map size %d exceeds limit %d", size, maxMapSize)) + return nil +} result := make(map[int]int, size) if size == 0 { return result @@ -752,7 +806,7 @@ func (s stringStringMapSerializer) ReadData(ctx *ReadContext, value reflect.Valu value.Set(reflect.MakeMap(value.Type())) } ctx.RefResolver().Reference(value) - result := readMapStringString(ctx.buffer, ctx.Err()) + result := readMapStringString(ctx.buffer, ctx.Err(), ctx.maxStringBytes, ctx.maxMapSize) value.Set(reflect.ValueOf(result)) } @@ -787,7 +841,7 @@ func (s stringInt64MapSerializer) ReadData(ctx *ReadContext, value reflect.Value value.Set(reflect.MakeMap(value.Type())) } ctx.RefResolver().Reference(value) - result := readMapStringInt64(ctx.buffer, ctx.Err()) + result := readMapStringInt64(ctx.buffer, ctx.Err(), ctx.maxStringBytes, ctx.maxMapSize) value.Set(reflect.ValueOf(result)) } @@ -822,7 +876,7 @@ func (s stringIntMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) value.Set(reflect.MakeMap(value.Type())) } ctx.RefResolver().Reference(value) - result := readMapStringInt(ctx.buffer, ctx.Err()) + result := readMapStringInt(ctx.buffer, ctx.Err(), ctx.maxStringBytes, ctx.maxMapSize) value.Set(reflect.ValueOf(result)) } @@ -857,7 +911,7 @@ func (s stringFloat64MapSerializer) ReadData(ctx *ReadContext, value reflect.Val value.Set(reflect.MakeMap(value.Type())) } ctx.RefResolver().Reference(value) - result := readMapStringFloat64(ctx.buffer, ctx.Err()) + result := readMapStringFloat64(ctx.buffer, ctx.Err(), ctx.maxStringBytes, ctx.maxMapSize) value.Set(reflect.ValueOf(result)) } @@ -892,7 +946,7 @@ func (s stringBoolMapSerializer) ReadData(ctx *ReadContext, value reflect.Value) value.Set(reflect.MakeMap(value.Type())) } ctx.RefResolver().Reference(value) - result := readMapStringBool(ctx.buffer, ctx.Err()) + result := readMapStringBool(ctx.buffer, ctx.Err(), ctx.maxStringBytes, ctx.maxMapSize) value.Set(reflect.ValueOf(result)) } diff --git a/go/fory/reader.go b/go/fory/reader.go index e7a1df1710..c455fffa0d 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -43,6 +43,9 @@ type ReadContext struct { err Error // Accumulated error state for deferred checking lastTypePtr uintptr lastTypeInfo *TypeInfo + maxStringBytes int + maxCollectionSize int + maxMapSize int } // IsXlang returns whether cross-language serialization mode is enabled @@ -224,7 +227,7 @@ func (c *ReadContext) readFast(ptr unsafe.Pointer, ct DispatchId) { case PrimitiveFloat16DispatchId: *(*uint16)(ptr) = c.buffer.ReadUint16(err) case StringDispatchId: - *(*string)(ptr) = readString(c.buffer, err) + *(*string)(ptr) = readString(c.buffer, err, c.maxStringBytes) } } @@ -251,7 +254,7 @@ func (c *ReadContext) ReadLength() int { // ReadString reads a string value (caller handles nullable/type meta) func (c *ReadContext) ReadString() string { - return readString(c.buffer, c.Err()) + return readString(c.buffer, c.Err(), c.maxStringBytes) } // ReadBoolSlice reads []bool with ref/type info @@ -914,3 +917,26 @@ func (c *ReadContext) ReadArrayValue(target reflect.Value, refMode RefMode, read c.RefResolver().SetReadObject(refID, target) } } + +func (ctx *ReadContext) checkStringBytes(n int) { + if ctx.maxStringBytes > 0 && n > ctx.maxStringBytes { + ctx.SetError(DeserializationErrorf( + "fory: string byte length %d exceeds limit %d", n, ctx.maxStringBytes)) + } +} + +func (ctx *ReadContext) checkCollectionSize(n int) { + if ctx.maxCollectionSize > 0 && n > ctx.maxCollectionSize { + ctx.SetError(DeserializationErrorf( + "fory: collection size %d exceeds limit %d", n, ctx.maxCollectionSize)) + } +} + +func (ctx *ReadContext) checkMapSize(n int) { + if ctx.maxMapSize > 0 && n > ctx.maxMapSize { + ctx.SetError(DeserializationErrorf( + "fory: map size %d exceeds limit %d", n, ctx.maxMapSize)) + } +} + + diff --git a/go/fory/slice.go b/go/fory/slice.go index bd3a9aa7ee..2e8c1efd1a 100644 --- a/go/fory/slice.go +++ b/go/fory/slice.go @@ -265,6 +265,8 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() length := int(buf.ReadVarUint32(ctxErr)) + ctx.checkCollectionSize(length) + if ctx.HasError() { return } isArrayType := value.Type().Kind() == reflect.Array if length == 0 { diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index e4daf990be..fce61d7d25 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -643,6 +643,8 @@ func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() length := int(buf.ReadVarUint32(ctxErr)) + ctx.checkCollectionSize(length) + if ctx.HasError() { return } ptr := (*[]string)(value.Addr().UnsafePointer()) if length == 0 { *ptr = make([]string, 0) @@ -670,7 +672,7 @@ func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { continue // null string, leave as zero value } } - result[i] = readString(buf, ctxErr) + result[i] = readString(buf, ctxErr, ctx.maxStringBytes) } *ptr = result } diff --git a/go/fory/string.go b/go/fory/string.go index 10586a27cd..94275cb875 100644 --- a/go/fory/string.go +++ b/go/fory/string.go @@ -49,11 +49,17 @@ func writeString(buf *ByteBuffer, value string) { } // readString reads a string from buffer using xlang encoding -func readString(buf *ByteBuffer, err *Error) string { +func readString(buf *ByteBuffer, err *Error, maxBytes int) string { header := buf.ReadVaruint36Small(err) size := header >> 2 // Extract byte count encoding := header & 0b11 // Extract encoding type + if maxBytes > 0 && int(size) > maxBytes { + err.SetError(DeserializationErrorf( + "fory: string byte length %d exceeds limit %d", int(size), maxBytes)) + return "" + } + switch encoding { case encodingLatin1: return readLatin1(buf, int(size), err) @@ -132,7 +138,7 @@ func (s stringSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bo func (s stringSerializer) ReadData(ctx *ReadContext, value reflect.Value) { err := ctx.Err() - str := readString(ctx.buffer, err) + str := readString(ctx.buffer, err, ctx.maxStringBytes) if ctx.HasError() { return } @@ -202,7 +208,7 @@ func (s ptrToStringSerializer) Read(ctx *ReadContext, refMode RefMode, readType func (s ptrToStringSerializer) ReadData(ctx *ReadContext, value reflect.Value) { err := ctx.Err() - str := readString(ctx.buffer, err) + str := readString(ctx.buffer, err, ctx.maxStringBytes) if ctx.HasError() { return }