Skip to content

Commit 99e2db4

Browse files
committed
onnx-converter: new npm workspace to convert GPT2 from ONNX to TFJS
1 parent c3301b6 commit 99e2db4

17 files changed

Lines changed: 10580 additions & 117 deletions

cli/src/hellaswag_gpt.ts

Lines changed: 79 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,104 @@
1+
// import fs from 'fs';
2+
import fsPromise from 'node:fs/promises';
3+
4+
import { dirname } from 'path';
5+
import { fileURLToPath } from 'url';
6+
import { parse } from 'ts-command-line-args'
7+
18
import '@tensorflow/tfjs-node';
29
import fs from 'node:fs';
310
import path from 'node:path';
4-
import { Tokenizer, models } from '@epfml/discojs';
11+
import { models, serialization, Tokenizer } from '@epfml/discojs';
512
import { loadHellaSwag } from '@epfml/discojs-node';
13+
// import { AutoTokenizer } from '@xenova/transformers';
614

7-
const logFile = path.join('..', 'datasets', 'LogFile_hellaswag.txt');
8-
const logLines: string[] = [];
15+
const __dirname = dirname(fileURLToPath(import.meta.url));
916

17+
const logLines: string[] = [];
1018
function log(message: string) {
1119
console.log(message);
1220
logLines.push(message);
1321
}
1422

15-
const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(-1)
16-
17-
async function evaluateTFJS(tokenizer: Tokenizer) {
18-
const model = new models.GPT({ seed: 42 });
19-
log('Evaluating TFJS GPT on HellaSwag...');
23+
async function evaluateModel(model: models.GPT | models.ONNXModel, numDataPoints = -1) {
24+
const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(numDataPoints)
25+
const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2');
26+
log('Starting the HellaSwag benchmark...');
2027

2128
const start = Date.now();
22-
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
29+
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, true);
2330
const duration = ((Date.now() - start) / 1000).toFixed(2);
2431

25-
log(`TFJS GPT Accuracy: ${(accuracy * 100).toFixed(2)}%`);
26-
log(`TFJS GPT Evaluation Time: ${duration} seconds`);
32+
log(`Final accuracy: ${(accuracy * 100).toFixed(2)}%`);
33+
log(`Evaluation Time: ${duration} seconds`);
2734
}
2835

29-
async function evaluateXenova(tokenizer: Tokenizer) {
30-
const model = await models.ONNXModel.init_pretrained('Xenova/gpt2');
31-
log('Evaluating Xenova GPT-2 (ONNX) on HellaSwag...');
36+
const ModelTypes = ['onnx', 'gpt-tfjs-random', 'gpt-tfjs-pretrained'] as const;
37+
type ModelType = typeof ModelTypes[number];
3238

33-
const start = Date.now();
34-
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
35-
const duration = ((Date.now() - start) / 1000).toFixed(2);
36-
37-
log(`Xenova GPT-2 Accuracy: ${(accuracy * 100).toFixed(2)}%`);
38-
log(`Xenova GPT-2 Evaluation Time: ${duration} seconds`);
39+
interface HellaSwagArgs {
40+
model: ModelType
41+
numDataPoints: number
42+
logFile: string
43+
pretrainedModelPath: string
44+
help?: boolean
3945
}
4046

4147
async function main(): Promise<void> {
42-
fs.writeFileSync(logFile, '', 'utf-8'); // Clear old log file
48+
const defaultPretrainedModelPath = path.join(__dirname, "..", "..", "onnx-converter", "assets", "model.json")
49+
const args = parse<HellaSwagArgs>({
50+
model: {
51+
type: (raw: string) => raw as ModelType,
52+
description: `Model type, one of ${ModelTypes}`,
53+
defaultValue: 'onnx'
54+
},
55+
numDataPoints: {
56+
type: Number,
57+
description: 'Number of HellaSwag datapoints to evaluate, set -1 for the whole benchmark',
58+
defaultValue: -1
59+
},
60+
logFile: {
61+
type: String,
62+
description: 'Relative path to the log file, default to ./hellaswag.log', defaultValue: 'hellaswag.log'
63+
},
64+
pretrainedModelPath: {
65+
type: String,
66+
description: 'If specifying gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model',
67+
defaultValue: defaultPretrainedModelPath
68+
},
69+
help: {
70+
type: Boolean,
71+
optional: true,
72+
alias: 'h',
73+
description: 'Prints this usage guide'
74+
}
75+
}, { helpArg: 'help' })
4376

44-
const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2');
45-
await evaluateTFJS(tokenizer);
46-
log('\n---\n');
47-
await evaluateXenova(tokenizer);
77+
const logFile = path.join(__dirname, args.logFile);
78+
fs.writeFileSync(logFile, '', 'utf-8'); // Clear the log file
79+
80+
let model: | models.GPT | models.ONNXModel | undefined;
81+
switch (args.model) {
82+
case 'onnx':
83+
log("Using ONNX pretrained model Xenova/gpt2")
84+
model = await models.ONNXModel.init_pretrained('Xenova/gpt2');
85+
break;
86+
case 'gpt-tfjs-random':
87+
log("Using GPT-TFJS with random initialization")
88+
model = new models.GPT({ seed: 42 });
89+
break;
90+
case 'gpt-tfjs-pretrained':
91+
log("Using GPT-TFJS with pretrained weights")
92+
if (args.pretrainedModelPath === undefined) {
93+
throw new Error("If choosing gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model `pretrainedModelPath`")
94+
}
95+
const encodedModel = await fsPromise.readFile(args.pretrainedModelPath);
96+
model = await serialization.model.decode(encodedModel) as models.GPT;
97+
break;
98+
default:
99+
throw new Error(`Unrecognized model type: ${args.model}`);
100+
}
101+
await evaluateModel(model, args.numDataPoints);
48102

49103
fs.writeFileSync(logFile, logLines.join('\n'), 'utf-8');
50104
console.log(`\nResults written to ${logFile}`);

datasets/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,6 @@
2020

2121
# GDHF demo
2222
/tinder_dog/
23+
24+
# HellaSwag benchmark
25+
hellaswag*

discojs/src/models/gpt/layers.spec.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,9 @@ describe('GPT Layers', () => {
174174
name: 'testCSA',
175175
contextLength: 5,
176176
nHead: 2,
177-
nEmbd: 8, // divisible by nHead, so head size = 4
178-
dropout: 0.0, // no dropout for deterministic tests
177+
nEmbd: 8, // divisible by nHead, so head size = 4
178+
attnDrop: 0.0, // no dropout for deterministic tests
179+
residDrop: 0.0,
179180
nLayer: 2,
180181
seed: 42
181182
};

discojs/src/models/hellaswag.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ type ModelType = GPT | ONNXModel;
126126
export async function evaluate(
127127
model: ModelType,
128128
tokenizer: Tokenizer,
129-
dataset: HellaSwagExample[],
129+
dataset: HellaSwagDataset,
130130
print = true
131131
): Promise<number> {
132132
let correct = 0;

onnx-converter/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
node_modules
2+
assets
3+
dist

onnx-converter/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Structure:
2+
1. Read the ONNX model from Xenova's repository
3+
2. Use onnx.js protobuf to read the file and iterate through the layers: https://github.com/microsoft/onnxruntime/blob/main/js/web/lib/onnxjs/
4+
3. Create a map from layer to weight and convert each weight to TF.js tensor
5+
4. Init a TF.js model with the loaded weights and export the model
6+
7+
Run `npm run convert_onnx` to create the GPT-tfjs `model.json` file in the `./assets/` folder.

onnx-converter/package.json

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
{
2+
"name": "onnx-converter",
3+
"private": true,
4+
"type": "module",
5+
"main": "dist/gpt2_from_onnx.js",
6+
"scripts": {
7+
"convert_onnx": "npm run build && node dist/convert_onnx.js",
8+
"build": "tsc && cp -r src/protobuf dist/protobuf",
9+
"lint": "npx eslint .",
10+
"test": ": nothing"
11+
},
12+
"author": "",
13+
"license": "ISC",
14+
"dependencies": {
15+
"@epfml/discojs-node": "*",
16+
"@eslint/compat": "^1.4.0",
17+
"@eslint/eslintrc": "^3.3.1",
18+
"@eslint/js": "^9.36.0",
19+
"globals": "^16.4.0",
20+
"onnxruntime-web": "^1.22.0",
21+
"server": "*",
22+
"tslib": "2"
23+
},
24+
"devDependencies": {
25+
"nodemon": "3",
26+
"ts-command-line-args": "2"
27+
}
28+
}

0 commit comments

Comments
 (0)