diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index 25dbe5fd46..8803c5a4e5 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -146,8 +146,23 @@ function getWeightBytelength(spec: WeightsManifestEntry, // Can not statically determine string length. let byteLength = 0; for (let i = 0; i < size; i++) { - byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array( - slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH))[0]; + const lengthBuffer = slice( + byteLength, byteLength + NUM_BYTES_STRING_LENGTH); + if (lengthBuffer.byteLength !== NUM_BYTES_STRING_LENGTH) { + throw new Error( + `Invalid string tensor ${spec.name}: expected ${NUM_BYTES_STRING_LENGTH} ` + + `bytes for string ${i} length, found ${lengthBuffer.byteLength}.`); + } + const stringByteLength = new Uint32Array(lengthBuffer)[0]; + const stringBuffer = slice( + byteLength + NUM_BYTES_STRING_LENGTH, + byteLength + NUM_BYTES_STRING_LENGTH + stringByteLength); + if (stringBuffer.byteLength !== stringByteLength) { + throw new Error( + `Invalid string tensor ${spec.name}: expected ${stringByteLength} ` + + `bytes for string ${i}, found ${stringBuffer.byteLength}.`); + } + byteLength += NUM_BYTES_STRING_LENGTH + stringByteLength; } return byteLength; } else { @@ -171,8 +186,33 @@ async function getWeightBytelengthAsync( // Can not statically determine string length. let byteLength = 0; for (let i = 0; i < size; i++) { - byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array( - await slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH))[0]; + let lengthBuffer: ArrayBuffer; + try { + lengthBuffer = await slice( + byteLength, byteLength + NUM_BYTES_STRING_LENGTH); + } catch { + lengthBuffer = new ArrayBuffer(0); + } + if (lengthBuffer.byteLength !== NUM_BYTES_STRING_LENGTH) { + throw new Error( + `Invalid string tensor ${spec.name}: expected ${NUM_BYTES_STRING_LENGTH} ` + + `bytes for string ${i} length, found ${lengthBuffer.byteLength}.`); + } + const stringByteLength = new Uint32Array(lengthBuffer)[0]; + let stringBuffer: ArrayBuffer; + try { + stringBuffer = await slice( + byteLength + NUM_BYTES_STRING_LENGTH, + byteLength + NUM_BYTES_STRING_LENGTH + stringByteLength); + } catch { + stringBuffer = new ArrayBuffer(0); + } + if (stringBuffer.byteLength !== stringByteLength) { + throw new Error( + `Invalid string tensor ${spec.name}: expected ${stringByteLength} ` + + `bytes for string ${i}, found ${stringBuffer.byteLength}.`); + } + byteLength += NUM_BYTES_STRING_LENGTH + stringByteLength; } return byteLength; } else { @@ -253,11 +293,22 @@ function decodeWeight( const size = sizeFromShape(spec.shape); values = []; for (let i = 0; i < size; i++) { - const byteLength = new Uint32Array( - byteBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; + const lengthBuffer = byteBuffer.slice( + offset, offset + NUM_BYTES_STRING_LENGTH); + if (lengthBuffer.byteLength !== NUM_BYTES_STRING_LENGTH) { + throw new Error( + `Invalid string tensor ${name}: expected ${NUM_BYTES_STRING_LENGTH} ` + + `bytes for string ${i} length, found ${lengthBuffer.byteLength}.`); + } + const byteLength = new Uint32Array(lengthBuffer)[0]; offset += NUM_BYTES_STRING_LENGTH; - const bytes = new Uint8Array( - byteBuffer.slice(offset, offset + byteLength)); + const bytesBuffer = byteBuffer.slice(offset, offset + byteLength); + if (bytesBuffer.byteLength !== byteLength) { + throw new Error( + `Invalid string tensor ${name}: expected ${byteLength} ` + + `bytes for string ${i}, found ${bytesBuffer.byteLength}.`); + } + const bytes = new Uint8Array(bytesBuffer); (values as Uint8Array[]).push(bytes); offset += byteLength; } diff --git a/tfjs-core/src/io/io_utils_test.ts b/tfjs-core/src/io/io_utils_test.ts index 6d71028853..c2df39e970 100644 --- a/tfjs-core/src/io/io_utils_test.ts +++ b/tfjs-core/src/io/io_utils_test.ts @@ -543,6 +543,19 @@ describeWithFlags('decodeWeights', {}, () => { .toBeRejectedWithError(/Unsupported dtype in weight \'x\': int16/); }); + it('truncated string tensor raises Error', async () => { + const buffer = new ArrayBuffer(5); + const view = new DataView(buffer); + view.setUint32(0, 10, true); + new Uint8Array(buffer)[4] = 65; + const specs: WeightsManifestEntry[] = [ + {name: 'x', dtype: 'string', shape: [1]} + ]; + + await expectAsync(decode(buffer, specs)) + .toBeRejectedWithError(/expected 10 bytes for string 0/); + }); + it('support quantization uint8 weights', async () => { const manifestSpecs: WeightsManifestEntry[] = [ {