Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 59 additions & 8 deletions tfjs-core/src/io/io_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,23 @@ function getWeightBytelength(spec: WeightsManifestEntry,
// Can not statically determine string length.
let byteLength = 0;
for (let i = 0; i < size; i++) {
byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array(
slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH))[0];
const lengthBuffer = slice(
byteLength, byteLength + NUM_BYTES_STRING_LENGTH);
if (lengthBuffer.byteLength !== NUM_BYTES_STRING_LENGTH) {
throw new Error(
`Invalid string tensor ${spec.name}: expected ${NUM_BYTES_STRING_LENGTH} ` +
`bytes for string ${i} length, found ${lengthBuffer.byteLength}.`);
}
const stringByteLength = new Uint32Array(lengthBuffer)[0];
const stringBuffer = slice(
byteLength + NUM_BYTES_STRING_LENGTH,
byteLength + NUM_BYTES_STRING_LENGTH + stringByteLength);
if (stringBuffer.byteLength !== stringByteLength) {
throw new Error(
`Invalid string tensor ${spec.name}: expected ${stringByteLength} ` +
`bytes for string ${i}, found ${stringBuffer.byteLength}.`);
}
byteLength += NUM_BYTES_STRING_LENGTH + stringByteLength;
}
return byteLength;
} else {
Expand All @@ -171,8 +186,33 @@ async function getWeightBytelengthAsync(
// Can not statically determine string length.
let byteLength = 0;
for (let i = 0; i < size; i++) {
byteLength += NUM_BYTES_STRING_LENGTH + new Uint32Array(
await slice(byteLength, byteLength + NUM_BYTES_STRING_LENGTH))[0];
let lengthBuffer: ArrayBuffer;
try {
lengthBuffer = await slice(
byteLength, byteLength + NUM_BYTES_STRING_LENGTH);
} catch {
lengthBuffer = new ArrayBuffer(0);
}
if (lengthBuffer.byteLength !== NUM_BYTES_STRING_LENGTH) {
throw new Error(
`Invalid string tensor ${spec.name}: expected ${NUM_BYTES_STRING_LENGTH} ` +
`bytes for string ${i} length, found ${lengthBuffer.byteLength}.`);
}
const stringByteLength = new Uint32Array(lengthBuffer)[0];
let stringBuffer: ArrayBuffer;
try {
stringBuffer = await slice(
byteLength + NUM_BYTES_STRING_LENGTH,
byteLength + NUM_BYTES_STRING_LENGTH + stringByteLength);
} catch {
stringBuffer = new ArrayBuffer(0);
}
if (stringBuffer.byteLength !== stringByteLength) {
throw new Error(
`Invalid string tensor ${spec.name}: expected ${stringByteLength} ` +
`bytes for string ${i}, found ${stringBuffer.byteLength}.`);
}
byteLength += NUM_BYTES_STRING_LENGTH + stringByteLength;
}
return byteLength;
} else {
Expand Down Expand Up @@ -253,11 +293,22 @@ function decodeWeight(
const size = sizeFromShape(spec.shape);
values = [];
for (let i = 0; i < size; i++) {
const byteLength = new Uint32Array(
byteBuffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0];
const lengthBuffer = byteBuffer.slice(
offset, offset + NUM_BYTES_STRING_LENGTH);
if (lengthBuffer.byteLength !== NUM_BYTES_STRING_LENGTH) {
throw new Error(
`Invalid string tensor ${name}: expected ${NUM_BYTES_STRING_LENGTH} ` +
`bytes for string ${i} length, found ${lengthBuffer.byteLength}.`);
}
const byteLength = new Uint32Array(lengthBuffer)[0];
offset += NUM_BYTES_STRING_LENGTH;
const bytes = new Uint8Array(
byteBuffer.slice(offset, offset + byteLength));
const bytesBuffer = byteBuffer.slice(offset, offset + byteLength);
if (bytesBuffer.byteLength !== byteLength) {
throw new Error(
`Invalid string tensor ${name}: expected ${byteLength} ` +
`bytes for string ${i}, found ${bytesBuffer.byteLength}.`);
}
const bytes = new Uint8Array(bytesBuffer);
(values as Uint8Array[]).push(bytes);
offset += byteLength;
}
Expand Down
13 changes: 13 additions & 0 deletions tfjs-core/src/io/io_utils_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,19 @@ describeWithFlags('decodeWeights', {}, () => {
.toBeRejectedWithError(/Unsupported dtype in weight \'x\': int16/);
});

it('truncated string tensor raises Error', async () => {
const buffer = new ArrayBuffer(5);
const view = new DataView(buffer);
view.setUint32(0, 10, true);
new Uint8Array(buffer)[4] = 65;
const specs: WeightsManifestEntry[] = [
{name: 'x', dtype: 'string', shape: [1]}
];

await expectAsync(decode(buffer, specs))
.toBeRejectedWithError(/expected 10 bytes for string 0/);
});

it('support quantization uint8 weights', async () => {
const manifestSpecs: WeightsManifestEntry[] = [
{
Expand Down