From f7013bd25ab522b032c71c012b643ef318589725 Mon Sep 17 00:00:00 2001 From: hugo Date: Fri, 12 Jun 2026 09:51:43 +0200 Subject: [PATCH 1/6] ci: add prettier dev dependency --- package-lock.json | 10 +++---- package.json | 69 ++++++++++++++++++++++++----------------------- 2 files changed, 41 insertions(+), 38 deletions(-) diff --git a/package-lock.json b/package-lock.json index bc477ea6e..12342d425 100644 --- a/package-lock.json +++ b/package-lock.json @@ -30,6 +30,7 @@ "eslint": "10", "eslint-plugin-cypress": "6", "eslint-plugin-vue": "10", + "prettier": "^3.8.4", "typescript": "5", "typescript-eslint": "8" } @@ -10712,12 +10713,11 @@ } }, "node_modules/prettier": { - "version": "3.6.2", - "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.6.2.tgz", - "integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==", + "version": "3.8.4", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.8.4.tgz", + "integrity": "sha512-N2MylSdi48+5N/6S5j+maeHbUSIzzZ5uOcX5Hm4QpV8Dkb1HFjfAKTKX6yNPJQD9AhcT3ifHNB66tWTTJDi11Q==", "dev": true, "license": "MIT", - "peer": true, "bin": { "prettier": "bin/prettier.cjs" }, @@ -13815,7 +13815,7 @@ "vee-validate": "5.0.0-beta.0", "vue": "3", "vue-chartjs": "5", - "vue-router": "^5.0.7", + "vue-router": "5", "vue-tippy": "6", "vue-toast-notification": "3", "yup": "1", diff --git a/package.json b/package.json index cd6730fd9..5bcb4cfc6 100644 --- a/package.json +++ b/package.json @@ -1,35 +1,38 @@ { - "name": "disco", - "type": "module", - "scripts": { - "lint": "eslint" - }, - "workspaces": [ - "cli", - "discojs", - "discojs-node", - "discojs-web", - "server", - "webapp", - "onnx-converter" - ], - "dependencies": { - "@epfml/isomorphic-wrtc": "file:isomorphic-wrtc", - "debug": "4", - "immutable": "5", - "vitest": "4", - "zod": "4" - }, - "devDependencies": { - "@types/debug": "4", - "@types/node": "22", - "@vitest/eslint-plugin": "1", - "@vue/eslint-config-prettier": "10", - "@vue/eslint-config-typescript": "14", - "eslint": "10", - "eslint-plugin-cypress": "6", - "eslint-plugin-vue": "10", - "typescript": "5", - "typescript-eslint": "8" - } + "name": "disco", + "type": "module", + "scripts": { + "lint": "eslint", + "format:check": "prettier -c .", + "format:fix": "prettier ." + }, + "workspaces": [ + "cli", + "discojs", + "discojs-node", + "discojs-web", + "server", + "webapp", + "onnx-converter" + ], + "dependencies": { + "@epfml/isomorphic-wrtc": "file:isomorphic-wrtc", + "debug": "4", + "immutable": "5", + "vitest": "4", + "zod": "4" + }, + "devDependencies": { + "@types/debug": "4", + "@types/node": "22", + "@vitest/eslint-plugin": "1", + "@vue/eslint-config-prettier": "10", + "@vue/eslint-config-typescript": "14", + "eslint": "10", + "eslint-plugin-cypress": "6", + "eslint-plugin-vue": "10", + "prettier": "^3.8.4", + "typescript": "5", + "typescript-eslint": "8" + } } From 2be61501c1d431f0ebea0ed19f400d71a30fe8e3 Mon Sep 17 00:00:00 2001 From: hugo Date: Fri, 12 Jun 2026 09:52:07 +0200 Subject: [PATCH 2/6] ci: prettier ignore irrelevant or erroring files --- .prettierignore | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .prettierignore diff --git a/.prettierignore b/.prettierignore new file mode 100644 index 000000000..81ba31000 --- /dev/null +++ b/.prettierignore @@ -0,0 +1,3 @@ +package-lock.json +# Prettier errors on this file, not sure why +datasets/wikitext/wiki.train.tokens From 0fb4119362bfe0848bf33f5464c03f3a5ab19d29 Mon Sep 17 00:00:00 2001 From: hugo Date: Fri, 12 Jun 2026 10:04:41 +0200 Subject: [PATCH 3/6] ci: prettier ignore all datasets speeds up formatting significantly --- .prettierignore | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.prettierignore b/.prettierignore index 81ba31000..9345a1319 100644 --- a/.prettierignore +++ b/.prettierignore @@ -1,3 +1,2 @@ package-lock.json -# Prettier errors on this file, not sure why -datasets/wikitext/wiki.train.tokens +datasets From d73ec26538cf17a69ffd14ca8a105c4b1b1ccb88 Mon Sep 17 00:00:00 2001 From: hugo Date: Fri, 12 Jun 2026 10:05:06 +0200 Subject: [PATCH 4/6] ci: add format check to ci --- .github/workflows/lint-test-build.yml | 35 +++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/.github/workflows/lint-test-build.yml b/.github/workflows/lint-test-build.yml index e12f20f71..43ae60f74 100644 --- a/.github/workflows/lint-test-build.yml +++ b/.github/workflows/lint-test-build.yml @@ -18,8 +18,32 @@ jobs: key: datasets-${{ hashFiles('datasets/**') }} - run: datasets/populate + format-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-node@v6 + with: + node-version-file: .nvmrc + - uses: actions/cache@v5 + with: + path: | + ~/.npm + ~/.cache/Cypress + key: npm-${{ runner.os }}-${{ hashFiles('package-lock.json') }} + - run: npm ci + - run: npm run format:check + lint-most: - needs: [build-cli, build-lib, build-lib-node, build-lib-web, build-server, build-webapp] + needs: + [ + build-cli, + build-lib, + build-lib-node, + build-lib-web, + build-server, + build-webapp, + ] runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 @@ -175,7 +199,14 @@ jobs: working-directory: docs/examples test-most: - needs: [build-lib, build-lib-node, build-lib-web, build-server, download-datasets] + needs: + [ + build-lib, + build-lib-node, + build-lib-web, + build-server, + download-datasets, + ] runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 From b58129965b22e54e85767927791749c26d3f25ab Mon Sep 17 00:00:00 2001 From: hugo Date: Fri, 12 Jun 2026 10:07:37 +0200 Subject: [PATCH 5/6] ci: fix prettier format command --- package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/package.json b/package.json index 5bcb4cfc6..7d33024b9 100644 --- a/package.json +++ b/package.json @@ -4,7 +4,7 @@ "scripts": { "lint": "eslint", "format:check": "prettier -c .", - "format:fix": "prettier ." + "format:fix": "prettier -w ." }, "workspaces": [ "cli", From 93bee1d79ef1d1e820aebf20d1931568369f931d Mon Sep 17 00:00:00 2001 From: hugo Date: Fri, 12 Jun 2026 10:08:17 +0200 Subject: [PATCH 6/6] chore: format project --- .github/ISSUE_TEMPLATE/bug_report.md | 11 +- .github/ISSUE_TEMPLATE/feature_request.md | 5 +- .github/workflows/record-cypress.yml | 2 +- README.md | 43 +- app.yaml | 4 +- cli/README.md | 12 +- cli/package.json | 48 +- cli/src/args.ts | 215 +- cli/src/benchmark_gpt.ts | 161 +- cli/src/cli.ts | 100 +- cli/src/data.ts | 68 +- cli/src/hellaswag_gpt.ts | 215 +- cli/src/python/README.md | 30 +- cli/src/python/experiments/basic_tests.json | 6 +- cli/src/train_gpt.ts | 38 +- cli/src/user_log.ts | 66 +- cli/tsconfig.json | 6 +- discojs-node/package.json | 62 +- discojs-node/src/hellaswag.spec.ts | 40 +- discojs-node/src/hellaswag.ts | 116 +- discojs-node/src/index.ts | 6 +- discojs-node/src/loaders.spec.ts | 74 +- discojs-node/src/loaders/index.ts | 5 +- discojs-node/src/loaders/text.ts | 12 +- discojs-node/tsconfig.lib.json | 8 +- discojs-node/tsconfig.vitest.json | 6 +- discojs-web/package.json | 56 +- discojs-web/src/hellaswag.spec.ts | 41 +- discojs-web/src/hellaswag.ts | 72 +- discojs-web/src/index.ts | 2 +- discojs-web/src/loaders.spec.ts | 8 +- discojs-web/src/loaders/csv.ts | 2 +- discojs-web/src/loaders/index.ts | 6 +- discojs-web/tsconfig.lib.json | 14 +- discojs-web/tsconfig.vitest.json | 12 +- discojs/README.md | 5 +- discojs/package.json | 74 +- discojs/src/aggregator.spec.ts | 35 +- discojs/src/aggregator/aggregator.ts | 154 +- discojs/src/aggregator/byzantine.spec.ts | 28 +- discojs/src/aggregator/byzantine.ts | 76 +- discojs/src/aggregator/get.ts | 84 +- discojs/src/aggregator/index.ts | 10 +- discojs/src/aggregator/mean.spec.ts | 71 +- discojs/src/aggregator/mean.ts | 13 +- discojs/src/aggregator/multiround.ts | 52 +- discojs/src/aggregator/secure.spec.ts | 36 +- discojs/src/aggregator/secure.ts | 5 +- discojs/src/aggregator/secure_history.spec.ts | 255 +- discojs/src/aggregator/secure_history.ts | 9 +- discojs/src/client/client.ts | 158 +- .../decentralized/decentralized_client.ts | 342 +- discojs/src/client/decentralized/index.ts | 4 +- discojs/src/client/decentralized/messages.ts | 116 +- discojs/src/client/decentralized/peer.spec.ts | 79 +- discojs/src/client/decentralized/peer.ts | 278 +- .../client/decentralized/peer_pool.spec.ts | 205 +- discojs/src/client/decentralized/peer_pool.ts | 65 +- discojs/src/client/event_connection.ts | 161 +- .../src/client/federated/federated_client.ts | 85 +- discojs/src/client/federated/index.ts | 4 +- discojs/src/client/federated/messages.ts | 52 +- discojs/src/client/index.ts | 16 +- discojs/src/client/local_client.ts | 7 +- discojs/src/client/messages.ts | 52 +- discojs/src/client/types.ts | 6 +- discojs/src/client/utils.ts | 73 +- discojs/src/dataset/dataset.spec.ts | 41 +- discojs/src/dataset/dataset.ts | 69 +- discojs/src/dataset/types.ts | 2 +- discojs/src/default_tasks/cifar10.ts | 75 +- discojs/src/default_tasks/index.ts | 14 +- discojs/src/default_tasks/lus_covid.ts | 110 +- discojs/src/default_tasks/mnist.ts | 84 +- discojs/src/default_tasks/simple_face.ts | 42 +- discojs/src/default_tasks/tinder_dog.ts | 103 +- discojs/src/default_tasks/titanic.ts | 70 +- discojs/src/default_tasks/wikitext.ts | 41 +- discojs/src/index.ts | 35 +- discojs/src/logging/index.ts | 4 +- discojs/src/privacy.spec.ts | 38 +- discojs/src/privacy.ts | 37 +- discojs/src/processing/index.ts | 62 +- discojs/src/serialization/coder.ts | 6 +- discojs/src/serialization/index.ts | 18 +- discojs/src/serialization/model.spec.ts | 70 +- discojs/src/serialization/model.ts | 50 +- discojs/src/serialization/task.spec.ts | 8 +- discojs/src/serialization/task.ts | 58 +- discojs/src/serialization/weights.spec.ts | 38 +- discojs/src/task/display_information.ts | 64 +- discojs/src/task/task.ts | 104 +- discojs/src/task/task_handler.ts | 70 +- discojs/src/task/task_provider.ts | 6 +- discojs/src/task/training_information.ts | 216 +- discojs/src/training/disco.ts | 100 +- discojs/src/training/index.ts | 4 +- discojs/src/training/trainer.ts | 181 +- discojs/src/types/data_format.ts | 8 +- discojs/src/utils/async_iterator.spec.ts | 3 +- discojs/src/utils/event_emitter.ts | 32 +- discojs/src/validator.ts | 56 +- discojs/src/weights/aggregation.spec.ts | 48 +- discojs/src/weights/aggregation.ts | 60 +- discojs/src/weights/index.ts | 4 +- discojs/src/weights/weights_container.ts | 65 +- discojs/tsconfig.json | 2 +- discojs/tsconfig.lib.json | 8 +- discojs/tsconfig.vitest.json | 6 +- docs/CONTRIBUTING.md | 27 +- docs/FAQ.md | 23 +- docs/PRIVACY.md | 53 +- docs/examples/README.md | 3 +- docs/examples/custom_task.ts | 64 +- docs/examples/package.json | 42 +- docs/examples/training.ts | 38 +- docs/examples/wikitext.ts | 60 +- eslint.config.js | 116 +- isomorphic-wrtc/README.md | 4 +- isomorphic-wrtc/browser.js | 4 +- isomorphic-wrtc/node.js | 4 +- isomorphic-wrtc/package.json | 34 +- onnx-converter/README.md | 4 +- onnx-converter/package.json | 2 +- onnx-converter/src/convert_onnx.ts | 143 +- onnx-converter/src/protobuf/onnx-proto.d.ts | 6 +- onnx-converter/src/protobuf/onnx-proto.js | 6 +- onnx-converter/src/protobuf/onnx.cjs | 3350 +++++++++++------ onnx-converter/src/protobuf/onnx.d.ts | 1035 +++-- onnx-converter/tsconfig.json | 2 +- onnx-converter/tsconfig.lib.json | 6 +- server/package.json | 76 +- .../controllers/decentralized_controller.ts | 157 +- .../src/controllers/federated_controller.ts | 169 +- server/src/controllers/index.ts | 2 +- server/src/controllers/training_controller.ts | 78 +- server/src/index.ts | 2 +- server/src/main.ts | 16 +- server/src/routes/index.ts | 2 +- server/src/routes/task_router.ts | 160 +- server/src/routes/training_router.ts | 66 +- server/src/server.ts | 40 +- server/src/task_set.ts | 106 +- server/tests/client.spec.ts | 182 +- server/tests/e2e/decentralized.spec.ts | 312 +- server/tests/e2e/federated.spec.ts | 685 ++-- server/tests/utils.ts | 110 +- server/tests/validator.spec.ts | 112 +- server/tsconfig.json | 2 +- server/tsconfig.lib.json | 6 +- server/tsconfig.vitest.json | 6 +- tsconfig.base.json | 1 - vitest.config.ts | 50 +- webapp/cypress/e2e/store/models.cy.ts | 66 +- webapp/cypress/e2e/task-creation.cy.ts | 4 +- webapp/cypress/support/e2e.ts | 46 +- webapp/index.html | 18 +- webapp/package.json | 96 +- webapp/public/404.html | 22 +- webapp/src/assets/css/styles.css | 19 +- webapp/src/assets/css/tailwind.css | 4 +- webapp/src/assets/gif/DecentralizedGIF.vue | 10 +- webapp/src/assets/gif/DiscoGIF.vue | 6 +- webapp/src/assets/gif/FederatedGIF.vue | 10 +- webapp/src/assets/logos/AriadneLabsLogo.vue | 94 +- webapp/src/assets/logos/DiscoLogo.vue | 4 +- webapp/src/assets/logos/MLOLogo.vue | 6 +- webapp/src/assets/logos/TensorflowLogo.vue | 131 +- webapp/src/assets/svg/BinIcon.vue | 10 +- webapp/src/assets/svg/CreateIcon.vue | 18 +- webapp/src/assets/svg/DiscoParticlesIcon.vue | 40 +- webapp/src/assets/svg/DownArrow.vue | 11 +- webapp/src/assets/svg/EvaluateIcon.vue | 15 +- webapp/src/assets/svg/HomeIcon.vue | 10 +- webapp/src/assets/svg/InfoIcon.vue | 10 +- webapp/src/assets/svg/ModelExchangeIcon.vue | 17 +- webapp/src/assets/svg/ModelIcon.vue | 8 +- webapp/src/assets/svg/MoonIcon.vue | 12 +- webapp/src/assets/svg/PeopleIcon.vue | 10 +- webapp/src/assets/svg/PlugIcon.vue | 32 +- webapp/src/assets/svg/StackIcon.vue | 12 +- webapp/src/assets/svg/SunIcon.vue | 51 +- webapp/src/assets/svg/TimerIcon.vue | 10 +- webapp/src/assets/svg/UpArrow.vue | 11 +- webapp/src/components/App.vue | 35 +- .../src/components/containers/ButtonsCard.vue | 3 +- webapp/src/components/containers/Card.vue | 13 +- .../components/containers/IconCardHeader.vue | 18 +- .../dataset_input/DataDescription.vue | 2 +- webapp/src/components/pages/AboutUs.vue | 13 +- webapp/src/components/pages/HomePage.vue | 230 +- webapp/src/components/pages/NotFound.vue | 23 +- webapp/src/components/pages/TaskList.vue | 204 +- .../components/progress_bars/ProgressIcon.vue | 36 +- .../progress_bars/TestingButtons.vue | 10 +- .../components/progress_bars/TrainingBar.vue | 93 +- .../progress_bars/TrainingButtons.vue | 61 +- webapp/src/components/sidebar/SideBar.vue | 39 +- .../src/components/sidebar/SidebarButton.vue | 5 +- webapp/src/components/simple/CTAButton.vue | 12 +- webapp/src/components/simple/CheckBox.vue | 22 +- webapp/src/components/simple/CustomButton.vue | 10 +- .../components/simple/DISCOllaborative.vue | 4 +- .../components/simple/DISCOllaboratives.vue | 4 +- webapp/src/components/simple/ToggleButton.vue | 29 +- .../src/components/testing/PredictSteps.vue | 30 +- .../testing/__tests__/ModelLibrary.spec.ts | 15 +- .../components/training/TrainerDashboard.vue | 200 +- .../training/TrainingDescription.vue | 15 +- .../components/training/TrainingFinished.vue | 4 +- .../src/components/training/TrainingSteps.vue | 10 +- webapp/src/config.ts | 6 +- webapp/src/main.ts | 59 +- webapp/src/router/index.ts | 2 +- webapp/src/router/router.ts | 88 +- webapp/src/shims-vue.d.ts | 8 +- webapp/src/store/information.ts | 22 +- webapp/src/store/tasks.ts | 44 +- webapp/src/store/theme.ts | 17 +- webapp/src/store/training.ts | 62 +- webapp/src/store/tutorial.ts | 426 ++- webapp/src/utils.ts | 4 +- webapp/vite.config.ts | 2 +- webapp/vitest.config.ts | 18 +- 224 files changed, 9472 insertions(+), 6899 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 14e36ea79..407281952 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -1,10 +1,9 @@ --- name: " \U0001F41E Bug report " about: Create a report to help us improve -title: '' +title: "" labels: bug -assignees: '' - +assignees: "" --- **Describe the bug** @@ -12,6 +11,7 @@ A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: + 1. Go to '...' 2. Click on '....' 3. Scroll down to '....' @@ -24,8 +24,9 @@ A clear and concise description of what you expected to happen. If applicable, add screenshots to help explain your problem. **Desktop (please complete the following information):** - - OS: [e.g. iOS] - - Browser [e.g. chrome, safari] + +- OS: [e.g. iOS] +- Browser [e.g. chrome, safari] **Additional context** Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 34f39d2fa..5625e5a2c 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -1,10 +1,9 @@ --- name: "\U0001F680 Feature request" about: Suggest an idea for this project -title: '' +title: "" labels: enhancement -assignees: '' - +assignees: "" --- **Is your feature request related to a problem? Please describe.** diff --git a/.github/workflows/record-cypress.yml b/.github/workflows/record-cypress.yml index 6940efa02..89f11e322 100644 --- a/.github/workflows/record-cypress.yml +++ b/.github/workflows/record-cypress.yml @@ -76,7 +76,7 @@ jobs: working-directory: webapp install: false start: npm start - wait-on: 'http://localhost:8081' # Waits for above + wait-on: "http://localhost:8081" # Waits for above # Records to Cypress Cloud # https://docs.cypress.io/guides/cloud/projects#Set-up-a-project-to-record record: true diff --git a/README.md b/README.md index 7542799d0..704051d21 100644 --- a/README.md +++ b/README.md @@ -2,42 +2,49 @@ -

+

# **DISCO** - DIStributed COllaborative Machine Learning DISCO leverages federated :star2: and decentralized :sparkles: learning to allow several data owners to collaboratively build machine learning models without sharing any original data. The latest version is always running on the following link, for web and mobile: +

:man_dancing: https://discolab.ai/ :man_dancing:

-___ +--- + :magic_wand: **DEVELOPERS:** DISCO is written fully in JavaScript/TypeScript. Have a look at our [developer guide](DEV.md). -___ -:question: **WHY DISCO?** +--- + +:question: **WHY DISCO?** + - To build deep learning models across private datasets without compromising data privacy, ownership, sovereignty, or model performance - To create an easy-to-use platform that allows non-specialists to participate in collaborative learning -___ +--- :gear: **HOW DISCO WORKS** -- DISCO has a *public model – private data* approach -- Private and secure model updates – *not data* – are communicated to either: - - a central server : **federated** learning ( :star2: ) - - directly between users : **decentralized** learning ( :sparkles: ) i.e. no central coordination + +- DISCO has a _public model – private data_ approach +- Private and secure model updates – _not data_ – are communicated to either: + - a central server : **federated** learning ( :star2: ) + - directly between users : **decentralized** learning ( :sparkles: ) i.e. no central coordination - Model updates are then securely aggregated into a trained model - See more [HERE](https://discolab.ai/#/information) -___ -:question: **DISCO TECHNOLOGY** +--- + +:question: **DISCO TECHNOLOGY** + - DISCO runs arbitrary deep learning tasks and model architectures in your browser, via [TF.js](https://www.tensorflow.org/js) - Decentralized learning :sparkles: relies on [peer2peer](https://github.com/feross/simple-peer) communication - Have a look at how DISCO ensures privacy and confidentiality [HERE](docs/PRIVACY.md) -___ +--- :test_tube: **RESEARCH-BASED DESIGN** @@ -56,13 +63,13 @@ And more on the roadmap - :mirror: personalizable ([R10](https://arxiv.org/abs/2103.00710)) - :carrot: fairly incentivizing participation -___ - +--- :checkered_flag: **HOW TO USE DISCO** -- Start by exploring our examples tasks in the [`DISCOllaboratives` page](https://discolab.ai/#/list). + +- Start by exploring our examples tasks in the [`DISCOllaboratives` page](https://discolab.ai/#/list). - The example DISCOllaboratives are based on popular ML tasks such as [GPT2](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf), [Titanic](https://www.kaggle.com/c/titanic), [MNIST](https://www.kaggle.com/c/digit-recognizer) or [CIFAR-10](https://www.kaggle.com/pankrzysiu/cifar10-python) - It is also possible to create your own DISCOllaboratives without coding on the [custom training page](https://discolab.ai/#/create): - - Upload the initial model - - Choose between federated and decentralized for your DISCO training scheme ... connect your data and... done! :bar_chart: - - For more details on ML tasks and custom training have a look at [this guide](./docs/TASK.md) + - Upload the initial model + - Choose between federated and decentralized for your DISCO training scheme ... connect your data and... done! :bar_chart: + - For more details on ML tasks and custom training have a look at [this guide](./docs/TASK.md) diff --git a/app.yaml b/app.yaml index 7d12e8c4b..f70c8352d 100644 --- a/app.yaml +++ b/app.yaml @@ -1,8 +1,8 @@ # Custom means that it will use the Dockerfile -runtime: custom +runtime: custom # Flex environment required for WebSocket support, which is required for PeerJS. -env: flex +env: flex # Limit resources to one instance, one CPU, very little memory or disk. instance_class: F1 diff --git a/cli/README.md b/cli/README.md index 1b6589b65..ddf0c358e 100644 --- a/cli/README.md +++ b/cli/README.md @@ -27,23 +27,32 @@ npm -w cli start -- --help # or -h ``` ## Command arguments + Based on the task specification, we can adjust the command arguments. Available arguments are listed below. Non-mandatory fields will automatically use values from the task specification. + ### Test specification arguments + - `testID`: (mandatory) arbitrary test ID defined by the user for the test run - `task`: (mandatory) pre-defined task (adding a new task is described in the next section) - `numberOfUsers`: number of users participating in the learning round - `save`: whether to save the logs of the test run + ### Learning hyperparameters + - `epochs`: total number of training epochs - `roundDuration`: number of epochs per round - `batchSize`: batch size - `validationSplit`: ratio of the validation set used for evaluation + ### Aggregator parameters + - `aggregator`: aggregator specification - `clippingRadius`, `maxIterations`, `beta`: (optional, for byzantine aggregator settings) byzantine aggregator hyperparameters - `maxShareValue`: (optional, for secure aggregator settings) secure aggregator hyperparameter + ### Differential Privacy parameters + - `epsilon`, `delta`, `dpDefaultClippingRadius`: (optional, for testing with differential privacy) differential privacy hyperparameters ## Adding new tasks @@ -80,10 +89,12 @@ The CLI includes a script to evaluate GPT models on the [HellaSwag](https://rowa To run the evaluation: `npm -w cli run hellaswag_gpt` The script benchmarks the following models: + - A TensorFlow.js implementation of GPT (`gpt-tfjs`) - A pre-trained ONNX model (`Xenova/gpt2`) Both models are evaluated using a shared tokenizer (`Xenova/gpt2`), and the script reports: + - Accuracy (proportion of correct multiple-choice predictions) - Total evaluation time (in seconds) @@ -91,7 +102,6 @@ Both models are evaluated using a shared tokenizer (`Xenova/gpt2`), and the scri Results are printed to the console and saved to a log file: `../datasets/logFile_hellaswag.txt` - This allows for a direct comparison between the inference performance and accuracy of the two architectures. The TFJS implementation is generally slower and more memory-intensive than ONNX, but offers compatibility with browser-based environments and custom training workflows. See the [Benchmarking GPT-TF.js](#benchmarking-gpt-tfjs) section for more details on performance tradeoffs. diff --git a/cli/package.json b/cli/package.json index cc0f741e2..fb40ae17c 100644 --- a/cli/package.json +++ b/cli/package.json @@ -1,26 +1,26 @@ { - "name": "cli", - "private": true, - "type": "module", - "main": "dist/cli.js", - "scripts": { - "watch": "nodemon --ext ts --ignore dist --watch ../discojs-node/dist --watch ../server/dist --watch . --exec npm run", - "start": "npm run build && node dist/cli.js", - "benchmark_gpt": "npm run build && node dist/benchmark_gpt.js", - "train_gpt": "npm run build && node dist/train_gpt.js", - "hellaswag_gpt": "npm run build && node dist/hellaswag_gpt.js", - "build": "tsc --build", - "test": ": nothing" - }, - "author": "", - "license": "ISC", - "dependencies": { - "@epfml/discojs-node": "*", - "server": "*", - "tslib": "2" - }, - "devDependencies": { - "nodemon": "3", - "ts-command-line-args": "2" - } + "name": "cli", + "private": true, + "type": "module", + "main": "dist/cli.js", + "scripts": { + "watch": "nodemon --ext ts --ignore dist --watch ../discojs-node/dist --watch ../server/dist --watch . --exec npm run", + "start": "npm run build && node dist/cli.js", + "benchmark_gpt": "npm run build && node dist/benchmark_gpt.js", + "train_gpt": "npm run build && node dist/train_gpt.js", + "hellaswag_gpt": "npm run build && node dist/hellaswag_gpt.js", + "build": "tsc --build", + "test": ": nothing" + }, + "author": "", + "license": "ISC", + "dependencies": { + "@epfml/discojs-node": "*", + "server": "*", + "tslib": "2" + }, + "devDependencies": { + "nodemon": "3", + "ts-command-line-args": "2" + } } diff --git a/cli/src/args.ts b/cli/src/args.ts index ced893a72..e74d623bb 100644 --- a/cli/src/args.ts +++ b/cli/src/args.ts @@ -1,61 +1,95 @@ -import { parse } from 'ts-command-line-args' -import { Map, Set } from 'immutable' +import { parse } from "ts-command-line-args"; +import { Map, Set } from "immutable"; import type { DataType, Network, TaskProvider } from "@epfml/discojs"; -import { defaultTasks } from '@epfml/discojs' +import { defaultTasks } from "@epfml/discojs"; type AggregationStrategy = "mean" | "byzantine" | "secure"; -function parseAggregator(raw: string): AggregationStrategy{ - if (raw === "mean" || raw == "byzantine" || raw == "secure") - return raw; - else - throw new Error(`Aggregator ${raw} is not supported.`); +function parseAggregator(raw: string): AggregationStrategy { + if (raw === "mean" || raw == "byzantine" || raw == "secure") return raw; + else throw new Error(`Aggregator ${raw} is not supported.`); } export interface BenchmarkArguments { - provider: TaskProvider; - testID: string - numberOfUsers: number - epochs: number - roundDuration: number - batchSize: number - validationSplit: number + provider: TaskProvider; + testID: string; + numberOfUsers: number; + epochs: number; + roundDuration: number; + batchSize: number; + validationSplit: number; // DP - epsilon?: number - delta?: number - dpDefaultClippingRadius?: number + epsilon?: number; + delta?: number; + dpDefaultClippingRadius?: number; // Aggregator - aggregator: AggregationStrategy + aggregator: AggregationStrategy; // Byzantine aggregator - clippingRadius?: number - maxIterations?: number - beta?: number + clippingRadius?: number; + maxIterations?: number; + beta?: number; // Secure aggregator - maxShareValue?: number + maxShareValue?: number; - save: boolean - host: URL + save: boolean; + host: URL; } -type BenchmarkUnsafeArguments = Omit & { - task: string - help?: boolean -} +type BenchmarkUnsafeArguments = Omit & { + task: string; + help?: boolean; +}; -const argExample = 'e.g. npm start -- -u 2 -e 3 # runs 2 users for 3 epochs' +const argExample = "e.g. npm start -- -u 2 -e 3 # runs 2 users for 3 epochs"; const unsafeArgs = parse( { - testID: { type: String, alias: 'i', description: 'ID of the testcase' }, - task: { type: String, alias: 't', description: 'Task: tinder_dog, titanic, simple_face, cifar10 or lus_covid', defaultValue: 'tinder_dog' }, - numberOfUsers: { type: Number, alias: 'u', description: 'Number of users', defaultValue: 2 }, - epochs: { type: Number, alias: 'e', description: 'Number of epochs', defaultValue: 10 }, - roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 }, - batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 }, - validationSplit : { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 }, - save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false }, + testID: { type: String, alias: "i", description: "ID of the testcase" }, + task: { + type: String, + alias: "t", + description: + "Task: tinder_dog, titanic, simple_face, cifar10 or lus_covid", + defaultValue: "tinder_dog", + }, + numberOfUsers: { + type: Number, + alias: "u", + description: "Number of users", + defaultValue: 2, + }, + epochs: { + type: Number, + alias: "e", + description: "Number of epochs", + defaultValue: 10, + }, + roundDuration: { + type: Number, + alias: "r", + description: "Round duration (in epochs)", + defaultValue: 2, + }, + batchSize: { + type: Number, + alias: "b", + description: "Training batch size", + defaultValue: 10, + }, + validationSplit: { + type: Number, + alias: "v", + description: "Validation dataset ratio", + defaultValue: 0.2, + }, + save: { + type: Boolean, + alias: "s", + description: "Save logs of benchmark", + defaultValue: false, + }, host: { type: (raw: string) => new URL(raw), typeLabel: "URL", @@ -64,28 +98,71 @@ const unsafeArgs = parse( }, // Aggregator - aggregator: { type: parseAggregator, description: 'Type of weight aggregator', defaultValue: 'mean' }, + aggregator: { + type: parseAggregator, + description: "Type of weight aggregator", + defaultValue: "mean", + }, // Byzantine aggregator - clippingRadius: { type: Number, description: "Clipping radius for centered clipping", optional: true }, - maxIterations: { type: Number, description: "Maximum centered clipping iterations", optional: true }, - beta: { type: Number, description: "Momentum coefficient to smooth the aggregation over multiple rounds", optional: true }, + clippingRadius: { + type: Number, + description: "Clipping radius for centered clipping", + optional: true, + }, + maxIterations: { + type: Number, + description: "Maximum centered clipping iterations", + optional: true, + }, + beta: { + type: Number, + description: + "Momentum coefficient to smooth the aggregation over multiple rounds", + optional: true, + }, // Secure aggregator - maxShareValue: { type: Number, description: "Maximum absolute value over all the weights", optional: true }, + maxShareValue: { + type: Number, + description: "Maximum absolute value over all the weights", + optional: true, + }, // Differential Privacy - epsilon: { type: Number, description: 'Privacy budget', optional: true, defaultValue: undefined}, - delta: { type: Number, description: 'Probability of failure, slack parameter', optional: true, defaultValue: undefined}, - dpDefaultClippingRadius: {type: Number, description: 'Default clipping radius for DP', optional: true, defaultValue: undefined}, + epsilon: { + type: Number, + description: "Privacy budget", + optional: true, + defaultValue: undefined, + }, + delta: { + type: Number, + description: "Probability of failure, slack parameter", + optional: true, + defaultValue: undefined, + }, + dpDefaultClippingRadius: { + type: Number, + description: "Default clipping radius for DP", + optional: true, + defaultValue: undefined, + }, - help: { type: Boolean, optional: true, alias: 'h', description: 'Prints this usage guide' } + help: { + type: Boolean, + optional: true, + alias: "h", + description: "Prints this usage guide", + }, }, { - helpArg: 'help', - headerContentSections: [{ header: 'DISCO CLI', content: 'npm start -- [Options]\n' + argExample }] - } -) + helpArg: "help", + headerContentSections: [ + { header: "DISCO CLI", content: "npm start -- [Options]\n" + argExample }, + ], + }, +); const supportedTasks = Map( await Promise.all( @@ -108,7 +185,7 @@ const supportedTasks = Map( const provider = supportedTasks.get(unsafeArgs.task); if (provider === undefined) { - throw Error(`${unsafeArgs.task} not implemented.`) + throw Error(`${unsafeArgs.task} not implemented.`); } export const args: BenchmarkArguments = { @@ -123,7 +200,8 @@ export const args: BenchmarkArguments = { task.trainingInformation.epochs = unsafeArgs.epochs; task.trainingInformation.validationSplit = unsafeArgs.validationSplit; - const {aggregator, clippingRadius, maxIterations, beta, maxShareValue} = unsafeArgs; + const { aggregator, clippingRadius, maxIterations, beta, maxShareValue } = + unsafeArgs; // For aggregators if (aggregator !== undefined) @@ -134,11 +212,15 @@ export const args: BenchmarkArguments = { clippingRadius !== undefined && maxIterations !== undefined && beta !== undefined - ){ + ) { if (task.trainingInformation.scheme === "local") - throw new Error("Byzantine aggregator is not supported for local training"); + throw new Error( + "Byzantine aggregator is not supported for local training", + ); if (task.trainingInformation.aggregationStrategy !== "byzantine") - throw new Error("Byzantine parameters can be set only when aggregationStrategy is byzantine"); + throw new Error( + "Byzantine parameters can be set only when aggregationStrategy is byzantine", + ); task.trainingInformation.privacy = { ...task.trainingInformation.privacy, @@ -151,29 +233,34 @@ export const args: BenchmarkArguments = { } // For secure aggregator - if (maxShareValue !== undefined){ - + if (maxShareValue !== undefined) { if (task.trainingInformation.scheme !== "decentralized") - throw new Error("Secure aggation is only supported for decentralized laerning") + throw new Error( + "Secure aggation is only supported for decentralized laerning", + ); if (task.trainingInformation.aggregationStrategy !== "secure") - throw new Error("maxShareValue can be set when aggregationStrategy is secure"); + throw new Error( + "maxShareValue can be set when aggregationStrategy is secure", + ); task.trainingInformation.maxShareValue = maxShareValue; } // For DP - const {dpDefaultClippingRadius, epsilon, delta} = unsafeArgs; + const { dpDefaultClippingRadius, epsilon, delta } = unsafeArgs; if ( // dpDefaultClippingRadius !== undefined && epsilon !== undefined && delta !== undefined - ){ + ) { if (task.trainingInformation.scheme === "local") throw new Error("Can't have differential privacy for local training"); - const defaultRadius = dpDefaultClippingRadius ? dpDefaultClippingRadius : 1; - + const defaultRadius = dpDefaultClippingRadius + ? dpDefaultClippingRadius + : 1; + // for the case where privacy parameters are not defined in the default tasks task.trainingInformation.privacy ??= {}; task.trainingInformation.privacy.differentialPrivacy = { diff --git a/cli/src/benchmark_gpt.ts b/cli/src/benchmark_gpt.ts index ca0f4faed..49ab4ea4b 100644 --- a/cli/src/benchmark_gpt.ts +++ b/cli/src/benchmark_gpt.ts @@ -1,4 +1,4 @@ -import '@tensorflow/tfjs-node'; +import "@tensorflow/tfjs-node"; import { List } from "immutable"; import { parse } from "ts-command-line-args"; @@ -9,56 +9,92 @@ import { fetchTasks, models, } from "@epfml/discojs"; -import { loadModelFromDisk, loadText } from '@epfml/discojs-node' +import { loadModelFromDisk, loadText } from "@epfml/discojs-node"; import { Server } from "server"; -interface CLIArguments{ +interface CLIArguments { modelType?: string; // 'gpt-nano', 'gpt-micro', 'gpt-mini', 'gpt2' contextLength?: number; // 128, 256, 512, 1024, 2048 batchSize?: number; // 8, 16, 32, 64 inference?: boolean; // benchmark inference if true, training otherwise modelPath?: string; - help?: boolean // print help + help?: boolean; // print help } -const parsedArgs = parse({ - modelType: { type: String, optional: true, description: "A GPT architecture: 'gpt-nano', 'gpt-micro', 'gpt-mini', 'gpt2'" }, - contextLength: { type: Number, optional: true, description: "The maximum input sequence length to train the model on" }, - batchSize: { type: Number, optional: true, description: "The model training bat size" }, - inference: { type: Boolean, optional: true, description: "Whether to benchmark the model inference or training" }, - modelPath: { type: String, optional: true, description: "If benchmarking inference, the path to the trained model" }, - help: { type: Boolean, optional: true, alias: 'h', description: 'Prints this usage guide' }, -}, {helpArg: 'help'}); +const parsedArgs = parse( + { + modelType: { + type: String, + optional: true, + description: + "A GPT architecture: 'gpt-nano', 'gpt-micro', 'gpt-mini', 'gpt2'", + }, + contextLength: { + type: Number, + optional: true, + description: "The maximum input sequence length to train the model on", + }, + batchSize: { + type: Number, + optional: true, + description: "The model training bat size", + }, + inference: { + type: Boolean, + optional: true, + description: "Whether to benchmark the model inference or training", + }, + modelPath: { + type: String, + optional: true, + description: "If benchmarking inference, the path to the trained model", + }, + help: { + type: Boolean, + optional: true, + alias: "h", + description: "Prints this usage guide", + }, + }, + { helpArg: "help" }, +); const defaultArgs: Required = { - modelType: 'gpt-nano', + modelType: "gpt-nano", contextLength: 128, batchSize: 8, inference: false, - modelPath: 'models/model.json', - help: false -} + modelPath: "models/model.json", + help: false, +}; // Fill parsed args with default args -const args = { ...defaultArgs, ...parsedArgs } +const args = { ...defaultArgs, ...parsedArgs }; /** * Benchmark results are reported in https://github.com/epfml/disco/pull/659 */ async function main(args: Required): Promise { - const { inference: benchmarkInference, modelType, - contextLength, batchSize, modelPath } = args + const { + inference: benchmarkInference, + modelType, + contextLength, + batchSize, + modelPath, + } = args; // Launch a server instance const server = await Server.with(defaultTasks.wikitext); const [handle, url] = await server.serve(); // Fetch the wikitext task from the server - const tasks = await fetchTasks(url) - const task = tasks.get("llm_task") as Task<"text", Network> | undefined; - if (task === undefined) { throw new Error('task not found') } + const tasks = await fetchTasks(url); + const task = tasks.get("llm_task") as Task<"text", Network> | undefined; + if (task === undefined) { + throw new Error("task not found"); + } const { tokenizer } = task.trainingInformation; /** @@ -66,69 +102,80 @@ async function main(args: Required): Promise { */ if (!benchmarkInference) { // Benchmark parameters - const epochsCount = 1 - const iterationsPerEpoch = 10 + const epochsCount = 1; + const iterationsPerEpoch = 10; const config: models.GPTConfig = { - modelType: modelType as models.GPTConfig['modelType'], + modelType: modelType as models.GPTConfig["modelType"], maxIter: iterationsPerEpoch, lr: 0.0001, contextLength, - } + }; // Load the dataset after setting the Task batch size and max sequence length // to make sure the dataset is batched and tokenized correctly - task.trainingInformation.batchSize = batchSize - task.trainingInformation.contextLength = contextLength - const dataset = loadText('../datasets/wikitext/wiki.train.tokens') + task.trainingInformation.batchSize = batchSize; + task.trainingInformation.contextLength = contextLength; + const dataset = loadText("../datasets/wikitext/wiki.train.tokens") .map((text) => tokenizer.tokenize(text)) .flatten() - .batch(config.contextLength + 1, 1) + .batch(config.contextLength + 1, 1); const preprocessedDataset = dataset .map((tokens) => [tokens.pop(), tokens.last()] as [List, number]) .batch(batchSize); - + // Init and train the model - const model = new models.GPT(config) - console.log(`\tmodel type ${modelType} \n\tbatch size ${batchSize} \n\tcontext length ${contextLength}`) + const model = new models.GPT(config); + console.log( + `\tmodel type ${modelType} \n\tbatch size ${batchSize} \n\tcontext length ${contextLength}`, + ); - let epochTime = performance.now() + let epochTime = performance.now(); for (let epochsCounter = 1; epochsCounter <= epochsCount; epochsCounter++) { - const [_, logs] = await async_iterator.gather(model.train(preprocessedDataset)) - epochTime = (performance.now() - epochTime) - const msPerToken = epochTime / (batchSize * contextLength * iterationsPerEpoch * epochsCounter) - console.log(`\t\tTraining time: ${msPerToken.toFixed(2)} ms/token
${logs.peakMemory.toFixed(2)} GB`) + const [_, logs] = await async_iterator.gather( + model.train(preprocessedDataset), + ); + epochTime = performance.now() - epochTime; + const msPerToken = + epochTime / + (batchSize * contextLength * iterationsPerEpoch * epochsCounter); + console.log( + `\t\tTraining time: ${msPerToken.toFixed(2)} ms/token
${logs.peakMemory.toFixed(2)} GB`, + ); } - /** - * Inference benchmark - */ + /** + * Inference benchmark + */ } else { - const model = await loadModelFromDisk(modelPath) - if (!(model instanceof models.GPT)){ - throw new Error("Loaded model isn't a GPT model") + const model = await loadModelFromDisk(modelPath); + if (!(model instanceof models.GPT)) { + throw new Error("Loaded model isn't a GPT model"); } - + // Benchmark parameters - const prompt = 'The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion,' - const maxNewTokens = 200 - const iterations = 10 - console.log("Generating", maxNewTokens, "new tokens") + const prompt = + "The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion, The game began development in 2010 , carrying over a large portion,"; + const maxNewTokens = 200; + const iterations = 10; + console.log("Generating", maxNewTokens, "new tokens"); let tokens = tokenizer.tokenize(prompt); - let inferenceTime = 0 + let inferenceTime = 0; for (let i = 0; i < iterations; i++) { - const timeStart = performance.now() + const timeStart = performance.now(); for (let n = 0; n < maxNewTokens; n++) { - const next = (await model.predict(List.of(tokens))).first(); - if (next === undefined) throw new Error("empty prediction"); - tokens = tokens.push(next) + const next = (await model.predict(List.of(tokens))).first(); + if (next === undefined) throw new Error("empty prediction"); + tokens = tokens.push(next); } - inferenceTime += performance.now() - timeStart + inferenceTime += performance.now() - timeStart; } - console.log(`Inference time: ${(inferenceTime/ maxNewTokens / iterations).toFixed(2)} ms/token`) + console.log( + `Inference time: ${(inferenceTime / maxNewTokens / iterations).toFixed(2)} ms/token`, + ); } await new Promise((resolve, reject) => { handle.once("close", resolve); @@ -137,4 +184,4 @@ async function main(args: Required): Promise { } // You can run this example with "npm start" from this folder -main(args).catch(console.error) +main(args).catch(console.error); diff --git a/cli/src/cli.ts b/cli/src/cli.ts index 2e23c6514..a2a586cc9 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -1,8 +1,8 @@ // speed things up TODO how to avoid the need to import it -import "@tensorflow/tfjs-node" +import "@tensorflow/tfjs-node"; -import { List, Range } from 'immutable' -import fs from 'node:fs/promises' +import { List, Range } from "immutable"; +import fs from "node:fs/promises"; import { createWriteStream } from "node:fs"; import path from "node:path"; @@ -15,25 +15,28 @@ import type { TaskProvider, Network, } from "@epfml/discojs"; -import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs' +import { + Disco, + aggregator as aggregators, + client as clients, +} from "@epfml/discojs"; -import { getTaskData } from './data.js' -import { args } from './args.js' +import { getTaskData } from "./data.js"; +import { args } from "./args.js"; import { makeUserLogFile } from "./user_log.js"; import type { UserLogFile } from "./user_log.js"; - async function runUser( - task: Task, - url: URL, - data: Dataset, + task: Task, + url: URL, + data: Dataset, userIndex: number, numberOfUsers: number, ): Promise> { // cast as typescript isn't good with generics - const trainingScheme = task.trainingInformation.scheme as N - const aggregator = aggregators.getAggregator(task) - const client = clients.getClient(trainingScheme, url, task, aggregator) + const trainingScheme = task.trainingInformation.scheme as N; + const aggregator = aggregators.getAggregator(task); + const client = clients.getClient(trainingScheme, url, task, aggregator); const disco = new Disco(task, client, { scheme: trainingScheme }); const dir = path.join(".", `${args.testID}`); @@ -44,37 +47,43 @@ async function runUser( // create a write stream that saves learning logs during the train let jsonStream: ReturnType | null = null; - if (args.save){ - jsonStream = createWriteStream(streamPath, {flags: "w"}); + if (args.save) { + jsonStream = createWriteStream(streamPath, { flags: "w" }); } - try{ - for await (const log of disco.trainSummary(data)){ + try { + for await (const log of disco.trainSummary(data)) { finalLog.push(log); - if (jsonStream){ + if (jsonStream) { jsonStream.write(JSON.stringify(log) + "\n"); } } - await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish + await new Promise((res, _) => setTimeout(() => res("timeout"), 1000)); // Wait for other peers to finish // saving the entire per-user logs if (args.save) { const finalPath = path.join(dir, `client${userIndex}_local_log.json`); - const userLog: UserLogFile = makeUserLogFile(task, numberOfUsers, userIndex, client.ownId, finalLog); + const userLog: UserLogFile = makeUserLogFile( + task, + numberOfUsers, + userIndex, + client.ownId, + finalLog, + ); await fs.writeFile(finalPath, JSON.stringify(userLog, null, 2)); } return List(finalLog); - }catch(err){ + } catch (err) { console.error(`Run user failed for client ${userIndex}: `, err); throw err; - }finally{ - try{ - if (jsonStream){ + } finally { + try { + if (jsonStream) { jsonStream.end(); await new Promise((resolve, reject) => { @@ -82,33 +91,48 @@ async function runUser( jsonStream.once("error", reject); }); } - }catch(err){ - console.error(`failed to close log stream for client ${userIndex}: `, err); + } catch (err) { + console.error( + `failed to close log stream for client ${userIndex}: `, + err, + ); } - try{ + try { await disco.close(); - }catch(err){ + } catch (err) { console.error(`failed to close disco for client ${userIndex}: `, err); } } } async function main( - provider: TaskProvider, - numberOfUsers: number, + provider: TaskProvider, + numberOfUsers: number, ): Promise { const task = await provider.getTask(); - console.log(`Test ID: ${args.testID}`) - console.log(`Started ${task.trainingInformation.scheme} training of ${task.id}`) - console.log({ args }) + console.log(`Test ID: ${args.testID}`); + console.log( + `Started ${task.trainingInformation.scheme} training of ${task.id}`, + ); + console.log({ args }); const dataSplits = await Promise.all( - Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers)) - ) + Range(0, numberOfUsers).map(async (i) => + getTaskData(task.id, i, numberOfUsers), + ), + ); const logs = await Promise.all( - dataSplits.map((data, i) => runUser(task, args.host, data as Dataset, i, numberOfUsers)) - ) + dataSplits.map((data, i) => + runUser( + task, + args.host, + data as Dataset, + i, + numberOfUsers, + ), + ), + ); if (args.save) { const dir = path.join(".", `${args.testID}`, `${task.id}`); @@ -119,4 +143,4 @@ async function main( } } -main(args.provider, args.numberOfUsers).catch(console.error) +main(args.provider, args.numberOfUsers).catch(console.error); diff --git a/cli/src/data.ts b/cli/src/data.ts index aa4d0a330..3921781a5 100644 --- a/cli/src/data.ts +++ b/cli/src/data.ts @@ -1,15 +1,13 @@ import path from "node:path"; import { Dataset, processing } from "@epfml/discojs"; -import { - DataFormat, - DataType, - Image, - Task, -} from "@epfml/discojs"; +import { DataFormat, DataType, Image, Task } from "@epfml/discojs"; import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node"; import { Repeat } from "immutable"; -async function loadSimpleFaceData(userIdx: number, totalClient: number): Promise> { +async function loadSimpleFaceData( + userIdx: number, + totalClient: number, +): Promise> { const folder = path.join("..", "datasets", "simple_face"); const [adults, childs]: Dataset<[Image, string]>[] = [ @@ -24,7 +22,10 @@ async function loadSimpleFaceData(userIdx: number, totalClient: number): Promise return sharded; } -async function loadLusCovidData(userIdx: number, totalClient: number): Promise> { +async function loadLusCovidData( + userIdx: number, + totalClient: number, +): Promise> { const folder = path.join("..", "datasets", "lus_covid"); const [positive, negative]: Dataset<[Image, string]>[] = [ @@ -67,38 +68,43 @@ function loadTinderDogData(split: number): Dataset { }); } -function loadData(dataName: string, split: number): Dataset{ +function loadData( + dataName: string, + split: number, +): Dataset { const folder = path.join("..", "datasets", `${dataName}`, `client_${split}`); return loadCSV(path.join(folder, "labels.csv")) .map( - (row) => [ - processing.extractColumn(row, "filename"), - processing.extractColumn(row, "label"), - ] as const, + (row) => + [ + processing.extractColumn(row, "filename"), + processing.extractColumn(row, "label"), + ] as const, ) - .map( - async ([filename, label]) => { - try { - const img = await Promise.any( - ["png", "jpg", "jpeg"].map((ext) => - loadImage(path.join(folder, `${filename}.${ext}`))) - ); - return [img, label] - } catch { - throw Error(`${filename} not found in ${folder}`); - } + .map(async ([filename, label]) => { + try { + const img = await Promise.any( + ["png", "jpg", "jpeg"].map((ext) => + loadImage(path.join(folder, `${filename}.${ext}`)), + ), + ); + return [img, label]; + } catch { + throw Error(`${filename} not found in ${folder}`); } - ); + }); } export async function getTaskData( - taskID: Task.ID, - userIdx: number, - totalClient: number + taskID: Task.ID, + userIdx: number, + totalClient: number, ): Promise> { switch (taskID) { case "simple_face": // remove - return (await loadSimpleFaceData(userIdx, totalClient)) as Dataset; + return (await loadSimpleFaceData(userIdx, totalClient)) as Dataset< + DataFormat.Raw[D] + >; case "titanic": case "titanic_decentralized": const titanicData = loadCSV( @@ -112,7 +118,9 @@ export async function getTaskData( return loadData("cifar10_ext", userIdx) as Dataset; case "lus_covid": case "lus_covid_decentralized": - return (await loadLusCovidData(userIdx, totalClient)) as Dataset; + return (await loadLusCovidData(userIdx, totalClient)) as Dataset< + DataFormat.Raw[D] + >; case "tinder_dog": // remove return loadTinderDogData(userIdx) as Dataset; case "mnist_federated": diff --git a/cli/src/hellaswag_gpt.ts b/cli/src/hellaswag_gpt.ts index 0d86c1eff..ba29a3221 100644 --- a/cli/src/hellaswag_gpt.ts +++ b/cli/src/hellaswag_gpt.ts @@ -1,95 +1,120 @@ -import fsPromise from 'node:fs/promises'; -import { dirname } from 'path'; -import { fileURLToPath } from 'url'; -import { parse } from 'ts-command-line-args' - -import '@tensorflow/tfjs-node'; -import path from 'node:path'; -import { models, serialization, Tokenizer } from '@epfml/discojs'; -import { loadHellaSwag } from '@epfml/discojs-node'; - -const __dirname = dirname(fileURLToPath(import.meta.url)); - -async function evaluateModel(model: models.GPT | models.ONNXModel, numDataPoints = -1) { - const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(numDataPoints) - const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2'); - console.log('Starting the HellaSwag benchmark...'); - - const start = Date.now(); - const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, true); - const duration = ((Date.now() - start) / 1000).toFixed(2); - - console.log(`Final accuracy: ${(accuracy * 100).toFixed(2)}%`); - console.log(`Evaluation Time: ${duration} seconds`); -} - -const ModelTypes = ['onnx', 'gpt-tfjs-random', 'gpt-tfjs-pretrained'] as const; -type ModelType = typeof ModelTypes[number]; - -interface HellaSwagArgs { - model: ModelType - numDataPoints: number - logFile: string - pretrainedModelPath: string - help?: boolean -} - -function castModelType(raw: string): ModelType { - for (const t of ModelTypes) if (raw === t) return t - throw new Error(`Invalid model type: ${raw}`) -} - -async function main(): Promise { - const args = parse({ - model: { - type: (raw: string) => castModelType(raw), - description: `Model type, one of ${ModelTypes.toString()}`, - defaultValue: 'onnx' - }, - numDataPoints: { - type: Number, - description: 'Number of HellaSwag datapoints to evaluate, set -1 for the whole benchmark', - defaultValue: -1 - }, - logFile: { - type: String, - description: 'Relative path to the log file, default to ./hellaswag.log', defaultValue: 'hellaswag.log' - }, - pretrainedModelPath: { - type: String, - description: 'If specifying gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model', - defaultValue: path.join(__dirname, "..", "..", "onnx-converter", "assets", "model.json") - }, - help: { - type: Boolean, - optional: true, - alias: 'h', - description: 'Prints this usage guide' - } - }, { helpArg: 'help' }) - - let model: models.GPT | models.ONNXModel | undefined; - switch (args.model) { - case 'onnx': - console.log("Using ONNX pretrained model Xenova/gpt2") - model = await models.ONNXModel.init_pretrained('Xenova/gpt2'); - break; - case 'gpt-tfjs-random': - console.log("Using GPT-TFJS with random initialization") - model = new models.GPT({ seed: 42 }); - break; - case 'gpt-tfjs-pretrained': - console.log("Using GPT-TFJS with pretrained weights") - if (args.pretrainedModelPath === undefined) { - throw new Error("If choosing gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model `pretrainedModelPath") - } - const encodedModel = await fsPromise.readFile(args.pretrainedModelPath); - model = await serialization.model.decode(encodedModel) as models.GPT; - break; - } - await evaluateModel(model, args.numDataPoints); - - console.log("Benchmark completed!") -} - -main().catch(console.error); +import fsPromise from "node:fs/promises"; +import { dirname } from "path"; +import { fileURLToPath } from "url"; +import { parse } from "ts-command-line-args"; + +import "@tensorflow/tfjs-node"; +import path from "node:path"; +import { models, serialization, Tokenizer } from "@epfml/discojs"; +import { loadHellaSwag } from "@epfml/discojs-node"; + +const __dirname = dirname(fileURLToPath(import.meta.url)); + +async function evaluateModel( + model: models.GPT | models.ONNXModel, + numDataPoints = -1, +) { + const hellaswagDataset: models.HellaSwagDataset = + await loadHellaSwag(numDataPoints); + const tokenizer = await Tokenizer.from_pretrained("Xenova/gpt2"); + console.log("Starting the HellaSwag benchmark..."); + + const start = Date.now(); + const accuracy = await models.evaluate_hellaswag( + model, + tokenizer, + hellaswagDataset, + true, + ); + const duration = ((Date.now() - start) / 1000).toFixed(2); + + console.log(`Final accuracy: ${(accuracy * 100).toFixed(2)}%`); + console.log(`Evaluation Time: ${duration} seconds`); +} + +const ModelTypes = ["onnx", "gpt-tfjs-random", "gpt-tfjs-pretrained"] as const; +type ModelType = (typeof ModelTypes)[number]; + +interface HellaSwagArgs { + model: ModelType; + numDataPoints: number; + logFile: string; + pretrainedModelPath: string; + help?: boolean; +} + +function castModelType(raw: string): ModelType { + for (const t of ModelTypes) if (raw === t) return t; + throw new Error(`Invalid model type: ${raw}`); +} + +async function main(): Promise { + const args = parse( + { + model: { + type: (raw: string) => castModelType(raw), + description: `Model type, one of ${ModelTypes.toString()}`, + defaultValue: "onnx", + }, + numDataPoints: { + type: Number, + description: + "Number of HellaSwag datapoints to evaluate, set -1 for the whole benchmark", + defaultValue: -1, + }, + logFile: { + type: String, + description: + "Relative path to the log file, default to ./hellaswag.log", + defaultValue: "hellaswag.log", + }, + pretrainedModelPath: { + type: String, + description: + "If specifying gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model", + defaultValue: path.join( + __dirname, + "..", + "..", + "onnx-converter", + "assets", + "model.json", + ), + }, + help: { + type: Boolean, + optional: true, + alias: "h", + description: "Prints this usage guide", + }, + }, + { helpArg: "help" }, + ); + + let model: models.GPT | models.ONNXModel | undefined; + switch (args.model) { + case "onnx": + console.log("Using ONNX pretrained model Xenova/gpt2"); + model = await models.ONNXModel.init_pretrained("Xenova/gpt2"); + break; + case "gpt-tfjs-random": + console.log("Using GPT-TFJS with random initialization"); + model = new models.GPT({ seed: 42 }); + break; + case "gpt-tfjs-pretrained": + console.log("Using GPT-TFJS with pretrained weights"); + if (args.pretrainedModelPath === undefined) { + throw new Error( + "If choosing gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model `pretrainedModelPath", + ); + } + const encodedModel = await fsPromise.readFile(args.pretrainedModelPath); + model = (await serialization.model.decode(encodedModel)) as models.GPT; + break; + } + await evaluateModel(model, args.numDataPoints); + + console.log("Benchmark completed!"); +} + +main().catch(console.error); diff --git a/cli/src/python/README.md b/cli/src/python/README.md index 18b488d20..7ae806da6 100644 --- a/cli/src/python/README.md +++ b/cli/src/python/README.md @@ -1,26 +1,36 @@ # Experiment Runner & Log Visualization + This folder contains two Python scripts for running experiments with different configurations and visualizing their results. ## Environment Setup + We recommend using Python virtual environment to manage dependencies. 1. Create a virtual environment + ```bash python3 -m venv .venv ``` + 2. Activate the virtual environment + ```bash # Linux / macOS source .venv/bin/activate -``` +``` + 3. Install the required libraries from requirements.txt + ```bash pip install -r requirements.txt ``` + Once completed, you are ready to run the scripts. ## Key Components + ### `run_experiments.py` + This script executes experiments based on a JSON configuration file. - Reads experiment setting from a JSON file @@ -28,18 +38,22 @@ This script executes experiments based on a JSON configuration file. - Default configuration file: `./experiments/basic_tests.json` **Usage:** + ```bash python3 run_experiments.py ``` ### `visualize_logs.py` + This script visualizes experiment logs using `pandas` and `matplotlib`. **Input:** + - A directory containing log files from a single experiment (typically contains multiple JSON files from different clients) **Generated Outputs:** The script produces the following plots in the same directory + 1. Training loss (all clients) 2. Training accuracy (all clients) 3. Validation loss (all clients) @@ -49,12 +63,15 @@ The script produces the following plots in the same directory 7. Average validation accuracy across clients **Usage:** + ```bash python3 visualize_logs.py ``` ### `experiment/` Directory + This directory contains JSON files defining experiment configurations + ``` { "name": "Optional description of the test suite", @@ -75,14 +92,17 @@ This directory contains JSON files defining experiment configurations ] } ``` + Field Descriptions + - `name` (optional): Description of the test suite - `defaults` (optional): Default parameters applied to all experiments - `experiments`: List of experiment configurations - - `testID` (required): Unique identifier for the experiment - - `task` (required): Task name to run - - Other fields: Training parameters + - `testID` (required): Unique identifier for the experiment + - `task` (required): Task name to run + - Other fields: Training parameters **Important Notes** + 1. **Training scheme (federated, decentralized)** cannot be adjusted in this JSON file. Since training schemes are bound to task objects, we must create the task separately, import it in `args.ts`, and specify the task name in test suite JSON to run the experiments with the intended training scheme. -2. `minNbOfParticipants` cannot be adjusted in test suite JSON file. Similar to training scheme, this must be specified during task creation. \ No newline at end of file +2. `minNbOfParticipants` cannot be adjusted in test suite JSON file. Similar to training scheme, this must be specified during task creation. diff --git a/cli/src/python/experiments/basic_tests.json b/cli/src/python/experiments/basic_tests.json index f7df00563..5037db9f1 100644 --- a/cli/src/python/experiments/basic_tests.json +++ b/cli/src/python/experiments/basic_tests.json @@ -50,7 +50,7 @@ "maxIterations": 1, "beta": 0.9 }, - + { "testID": "lus_covid_fed_mean_cnn2_p3_d32_e30_r2", "task": "lus_covid", @@ -68,7 +68,7 @@ "maxIterations": 1, "beta": 0.9 }, - + { "testID": "cifar10_dec_mean_cnn4_p3_d600_e500_r5", "task": "cifar10", @@ -101,4 +101,4 @@ "maxShareValue": 100 } ] -} \ No newline at end of file +} diff --git a/cli/src/train_gpt.ts b/cli/src/train_gpt.ts index ff162e530..6bd21d61c 100644 --- a/cli/src/train_gpt.ts +++ b/cli/src/train_gpt.ts @@ -1,22 +1,22 @@ -import "@tensorflow/tfjs-node" +import "@tensorflow/tfjs-node"; import { models, Dataset, Tokenizer } from "@epfml/discojs"; import { List } from "immutable"; -async function main(): Promise { - const data = "Lorem ipsum dolor sit amet, consectetur adipis" - const seed = 42 +async function main(): Promise { + const data = "Lorem ipsum dolor sit amet, consectetur adipis"; + const seed = 42; const config: models.GPTConfig = { - modelType: 'gpt-nano', + modelType: "gpt-nano", lr: 0.01, maxIter: 50, - evaluateEvery:50, + evaluateEvery: 50, maxEvalBatches: 10, contextLength: 16, - seed - } + seed, + }; - const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2') + const tokenizer = await Tokenizer.from_pretrained("Xenova/gpt2"); const tokenDataset = new Dataset([data]) .map((text) => tokenizer.tokenize(text)) @@ -25,23 +25,23 @@ async function main(): Promise { .map((tokens) => [tokens.pop(), tokens.last()] as [List, number]) .repeat() .batch(8); - - const model = new models.GPT(config) + + const model = new models.GPT(config); for await (const logs of model.train(tokenDataset, undefined)) { - console.log(logs) + console.log(logs); } let tokens = tokenizer.tokenize("Lorem"); - const maxNewTokens = 14 + const maxNewTokens = 14; for (let n = 0; n < maxNewTokens; n++) { - const next = (await model.predict(List.of(tokens), { seed })).first(); - if (next === undefined) throw new Error("empty prediction"); - tokens = tokens.push(next) + const next = (await model.predict(List.of(tokens), { seed })).first(); + if (next === undefined) throw new Error("empty prediction"); + tokens = tokens.push(next); } - const generation = tokenizer.decode(tokens.toArray()) - console.log(generation) + const generation = tokenizer.decode(tokens.toArray()); + console.log(generation); } // You can run this example with "npm run run_gpt" from this folder -main().catch(console.error) +main().catch(console.error); diff --git a/cli/src/user_log.ts b/cli/src/user_log.ts index 5acdca3dd..2b7fc7bca 100644 --- a/cli/src/user_log.ts +++ b/cli/src/user_log.ts @@ -1,12 +1,12 @@ import { args, BenchmarkArguments } from "./args.js"; import type { SummaryLogs, DataType, Network, Task } from "@epfml/discojs"; -type SerializableArguments = Omit & { +type SerializableArguments = Omit & { host: string; -} +}; export interface UserLogFile { - run: { + run: { testID: string; taskID: string; numberOfUsers: number; @@ -24,37 +24,37 @@ export interface UserLogFile { logs: SummaryLogs[]; } -function serializeArgs(): SerializableArguments{ - const {provider, host, ...rest} = args; - return { - ...rest, - host: host.toString(), - }; +function serializeArgs(): SerializableArguments { + const { provider, host, ...rest } = args; + return { + ...rest, + host: host.toString(), + }; } export function makeUserLogFile( - task: Task, - numberOfUsers: number, - userIndex: number, - clientId: string, - logs: SummaryLogs[] -): UserLogFile{ - return { - run: { - testID: args.testID, - taskID: task.id, - numberOfUsers, - }, - task: { - id: task.id, - dataType: task.dataType, - trainingInformation: task.trainingInformation, - }, - args: serializeArgs(), - user: { - index: userIndex, - clientId: clientId, - }, - logs: logs - } + task: Task, + numberOfUsers: number, + userIndex: number, + clientId: string, + logs: SummaryLogs[], +): UserLogFile { + return { + run: { + testID: args.testID, + taskID: task.id, + numberOfUsers, + }, + task: { + id: task.id, + dataType: task.dataType, + trainingInformation: task.trainingInformation, + }, + args: serializeArgs(), + user: { + index: userIndex, + clientId: clientId, + }, + logs: logs, + }; } diff --git a/cli/tsconfig.json b/cli/tsconfig.json index 7ef565dc5..327092157 100644 --- a/cli/tsconfig.json +++ b/cli/tsconfig.json @@ -14,7 +14,5 @@ "compilerOptions": { "outDir": "dist" }, - "include": [ - "src" - ] -} \ No newline at end of file + "include": ["src"] +} diff --git a/discojs-node/package.json b/discojs-node/package.json index 8205d9636..edbd19cdd 100644 --- a/discojs-node/package.json +++ b/discojs-node/package.json @@ -1,33 +1,33 @@ { - "name": "@epfml/discojs-node", - "version": "3.0.0", - "type": "module", - "exports": "./dist/index.js", - "types": "dist/index.d.ts", - "scripts": { - "watch": "nodemon --ext ts --ignore dist --watch ../discojs/dist --watch . --exec npm run", - "build": "tsc --build", - "test": "cd .. && vitest --run --project=discojs-node" - }, - "repository": { - "type": "git", - "url": "git+https://github.com/epfml/disco.git" - }, - "bugs": { - "url": "https://github.com/epfml/disco/issues" - }, - "homepage": "https://github.com/epfml/disco#readme", - "dependencies": { - "@epfml/discojs": "*", - "@roamhq/wrtc": "0.10", - "@tensorflow/tfjs-node": "4", - "csv-parse": "6", - "sharp": "0.34" - }, - "devDependencies": { - "@types/node": "22", - "nodemon": "3", - "tmp-promise": "3", - "ts-node": "10" - } + "name": "@epfml/discojs-node", + "version": "3.0.0", + "type": "module", + "exports": "./dist/index.js", + "types": "dist/index.d.ts", + "scripts": { + "watch": "nodemon --ext ts --ignore dist --watch ../discojs/dist --watch . --exec npm run", + "build": "tsc --build", + "test": "cd .. && vitest --run --project=discojs-node" + }, + "repository": { + "type": "git", + "url": "git+https://github.com/epfml/disco.git" + }, + "bugs": { + "url": "https://github.com/epfml/disco/issues" + }, + "homepage": "https://github.com/epfml/disco#readme", + "dependencies": { + "@epfml/discojs": "*", + "@roamhq/wrtc": "0.10", + "@tensorflow/tfjs-node": "4", + "csv-parse": "6", + "sharp": "0.34" + }, + "devDependencies": { + "@types/node": "22", + "nodemon": "3", + "tmp-promise": "3", + "ts-node": "10" + } } diff --git a/discojs-node/src/hellaswag.spec.ts b/discojs-node/src/hellaswag.spec.ts index 0c28a7d2b..dc4455d3a 100644 --- a/discojs-node/src/hellaswag.spec.ts +++ b/discojs-node/src/hellaswag.spec.ts @@ -1,19 +1,21 @@ -import { describe, expect, it } from "vitest"; - -import { load as loadHellaSwag } from './hellaswag.js'; - -describe('HellaSwag parser', () => { - it('should load all examples and return them as an array', async () => { - const dataset = await loadHellaSwag(10); - - expect(dataset).to.be.an('array'); - expect(dataset.length).to.be.greaterThan(0); - - // Check the structure of the first example - const example = dataset[0]; - expect(example).to.have.property('ctx').that.is.a('string'); - expect(example).to.have.property('endings').that.is.an('array').with.lengthOf(4); - expect(example).to.have.property('label').that.is.a('number'); - }); -}); - +import { describe, expect, it } from "vitest"; + +import { load as loadHellaSwag } from "./hellaswag.js"; + +describe("HellaSwag parser", () => { + it("should load all examples and return them as an array", async () => { + const dataset = await loadHellaSwag(10); + + expect(dataset).to.be.an("array"); + expect(dataset.length).to.be.greaterThan(0); + + // Check the structure of the first example + const example = dataset[0]; + expect(example).to.have.property("ctx").that.is.a("string"); + expect(example) + .to.have.property("endings") + .that.is.an("array") + .with.lengthOf(4); + expect(example).to.have.property("label").that.is.a("number"); + }); +}); diff --git a/discojs-node/src/hellaswag.ts b/discojs-node/src/hellaswag.ts index af8ec90ed..77aa23998 100644 --- a/discojs-node/src/hellaswag.ts +++ b/discojs-node/src/hellaswag.ts @@ -1,57 +1,59 @@ -import path from "node:path"; -import fetch from 'node-fetch'; -import fs from 'node:fs/promises'; - -import { models } from '@epfml/discojs'; - -import { dirname } from 'path'; -import { fileURLToPath } from 'url'; -const __dirname = dirname(fileURLToPath(import.meta.url)); - -const DATASET_DIR = path.join(__dirname, "..", "..", "datasets"); -const hellaswag_filepath = path.join(DATASET_DIR, "hellaswag_val.jsonl") - -/** - * Loads the HellaSwag dataset from the remote URL in Node.js - * - * @param limit - Maximum number of examples to load (-1 means all) - * @returns A HellaSwagDataset containing the examples. - */ -export async function load(limit = -1): Promise { - let text: string; - try { - // Reads the file if it exists locally - text = (await fs.readFile(hellaswag_filepath)).toString(); - } catch { - console.log("Downloading the Hellaswag benchmark") - // Otherwise fetch it - const response = await fetch(models.HELLASWAG_URL); - if (!response.ok) { - throw new Error(`Failed to fetch dataset from ${models.HELLASWAG_URL}: ${response.statusText}`); - } - - text = await response.text(); - // Save the file locally - await fs.writeFile(hellaswag_filepath, text); - } - - const lines = text.split('\n'); - - const dataset: models.HellaSwagDataset = []; - let count = 0; - for (const line of lines) { - if (line.trim().length === 0) continue; - if (limit !== -1 && count >= limit) break; - - try { - const data = JSON.parse(line.trim()) as models.HellaSwagExample; - dataset.push(data); - count++; - } catch (e) { - console.error(`Failed to parse line:`, line); - throw e; - } - } - - return dataset; -} +import path from "node:path"; +import fetch from "node-fetch"; +import fs from "node:fs/promises"; + +import { models } from "@epfml/discojs"; + +import { dirname } from "path"; +import { fileURLToPath } from "url"; +const __dirname = dirname(fileURLToPath(import.meta.url)); + +const DATASET_DIR = path.join(__dirname, "..", "..", "datasets"); +const hellaswag_filepath = path.join(DATASET_DIR, "hellaswag_val.jsonl"); + +/** + * Loads the HellaSwag dataset from the remote URL in Node.js + * + * @param limit - Maximum number of examples to load (-1 means all) + * @returns A HellaSwagDataset containing the examples. + */ +export async function load(limit = -1): Promise { + let text: string; + try { + // Reads the file if it exists locally + text = (await fs.readFile(hellaswag_filepath)).toString(); + } catch { + console.log("Downloading the Hellaswag benchmark"); + // Otherwise fetch it + const response = await fetch(models.HELLASWAG_URL); + if (!response.ok) { + throw new Error( + `Failed to fetch dataset from ${models.HELLASWAG_URL}: ${response.statusText}`, + ); + } + + text = await response.text(); + // Save the file locally + await fs.writeFile(hellaswag_filepath, text); + } + + const lines = text.split("\n"); + + const dataset: models.HellaSwagDataset = []; + let count = 0; + for (const line of lines) { + if (line.trim().length === 0) continue; + if (limit !== -1 && count >= limit) break; + + try { + const data = JSON.parse(line.trim()) as models.HellaSwagExample; + dataset.push(data); + count++; + } catch (e) { + console.error(`Failed to parse line:`, line); + throw e; + } + } + + return dataset; +} diff --git a/discojs-node/src/index.ts b/discojs-node/src/index.ts index 219b31222..0a53e8cac 100644 --- a/discojs-node/src/index.ts +++ b/discojs-node/src/index.ts @@ -1,3 +1,3 @@ -export * from './loaders/index.js' -export { saveModelToDisk, loadModelFromDisk } from './model_loader.js' -export { load as loadHellaSwag } from './hellaswag.js' \ No newline at end of file +export * from "./loaders/index.js"; +export { saveModelToDisk, loadModelFromDisk } from "./model_loader.js"; +export { load as loadHellaSwag } from "./hellaswag.js"; diff --git a/discojs-node/src/loaders.spec.ts b/discojs-node/src/loaders.spec.ts index fd6b2363f..494842328 100644 --- a/discojs-node/src/loaders.spec.ts +++ b/discojs-node/src/loaders.spec.ts @@ -3,64 +3,64 @@ import path from "node:path"; import { withFile } from "tmp-promise"; import { describe, expect, it } from "vitest"; import { - loadCSV, - loadImage, - loadImagesInDir, - loadText, + loadCSV, + loadImage, + loadImagesInDir, + loadText, } from "./loaders/index.js"; const DATASETS_DIR = path.join(__dirname, "..", "..", "datasets"); // Array.fromAsync not yet widely used (2024) async function arrayFromAsync(iter: AsyncIterable): Promise { - const ret: T[] = []; - for await (const e of iter) ret.push(e); - return ret; + const ret: T[] = []; + for await (const e of iter) ret.push(e); + return ret; } describe("csv parser", () => { - it("parses basic file", async () => { - await withFile(async ({ path }) => { - await fs.writeFile(path, ["a,b,c", "1,2,3", "4,5,6"].join("\n")); + it("parses basic file", async () => { + await withFile(async ({ path }) => { + await fs.writeFile(path, ["a,b,c", "1,2,3", "4,5,6"].join("\n")); - const dataset = loadCSV(path); + const dataset = loadCSV(path); - expect(await arrayFromAsync(dataset)).to.have.deep.ordered.members([ - { a: "1", b: "2", c: "3" }, - { a: "4", b: "5", c: "6" }, - ]); - }); - }); + expect(await arrayFromAsync(dataset)).to.have.deep.ordered.members([ + { a: "1", b: "2", c: "3" }, + { a: "4", b: "5", c: "6" }, + ]); + }); + }); }); describe("image parser", () => { - it("parses mnist example", async () => { - const parsed = await loadImage( - path.join(DATASETS_DIR, "9-mnist-example.png"), - ); + it("parses mnist example", async () => { + const parsed = await loadImage( + path.join(DATASETS_DIR, "9-mnist-example.png"), + ); - expect(parsed).to.have.property("width").that.equals(172); - expect(parsed).to.have.property("height").that.equals(178); - }); + expect(parsed).to.have.property("width").that.equals(172); + expect(parsed).to.have.property("height").that.equals(178); + }); }); describe("image directory parser", () => { - it("parses all cifar10 files", async () => { - const parsed = await loadImagesInDir(path.join(DATASETS_DIR, "CIFAR10")); + it("parses all cifar10 files", async () => { + const parsed = await loadImagesInDir(path.join(DATASETS_DIR, "CIFAR10")); - expect(await parsed.size()).to.equal(24); - }); + expect(await parsed.size()).to.equal(24); + }); }); describe("text parser", () => { - it("parses basic file", async () => { - const text = ["a", "b", "c"].join("\n"); - await withFile(async ({ path }) => { - await fs.writeFile(path, text); + it("parses basic file", async () => { + const text = ["a", "b", "c"].join("\n"); + await withFile(async ({ path }) => { + await fs.writeFile(path, text); - const sequences = await arrayFromAsync(loadText(path)); - expect(sequences.length).to.equal(1); - expect(sequences[0]).to.equal(text); - }); - }); + const sequences = await arrayFromAsync(loadText(path)); + expect(sequences.length).to.equal(1); + expect(sequences[0]).to.equal(text); + }); + }); }); diff --git a/discojs-node/src/loaders/index.ts b/discojs-node/src/loaders/index.ts index bf934732b..8e585bcb1 100644 --- a/discojs-node/src/loaders/index.ts +++ b/discojs-node/src/loaders/index.ts @@ -1,6 +1,3 @@ export { load as loadCSV } from "./csv.js"; -export { - load as loadImage, - loadAllInDir as loadImagesInDir, -} from "./image.js"; +export { load as loadImage, loadAllInDir as loadImagesInDir } from "./image.js"; export { load as loadText } from "./text.js"; diff --git a/discojs-node/src/loaders/text.ts b/discojs-node/src/loaders/text.ts index aa4e66383..cf5d22a48 100644 --- a/discojs-node/src/loaders/text.ts +++ b/discojs-node/src/loaders/text.ts @@ -1,13 +1,13 @@ import createDebug from "debug"; -import { createReadStream } from 'node:fs'; +import { createReadStream } from "node:fs"; import { Dataset, Text } from "@epfml/discojs"; const debug = createDebug("discojs-node:loaders:text"); /** - * Returns chunks of text. Use `minChunkSize` to ensure that + * Returns chunks of text. Use `minChunkSize` to ensure that * each chunk is bigger than the expected sequence length. - * + * * @param path path to the text file to read * @returns a dataset of tokenized input and label sequences */ @@ -16,11 +16,11 @@ export function load(path: string): Dataset { // Create a stream to read the text file chunk by chunk const stream = createReadStream(path, { encoding: "utf8" }); for await (const chunk of stream) { - if (typeof chunk !== 'string') - throw new Error('Expected file stream to yield string') + if (typeof chunk !== "string") + throw new Error("Expected file stream to yield string"); debug("yield chunk of length: %o", chunk.length); - yield chunk + yield chunk; } }); } diff --git a/discojs-node/tsconfig.lib.json b/discojs-node/tsconfig.lib.json index 81e28b0ba..8c8a0ef5d 100644 --- a/discojs-node/tsconfig.lib.json +++ b/discojs-node/tsconfig.lib.json @@ -1,6 +1,6 @@ { - "extends": "../tsconfig.base.json", - "compilerOptions": { "outDir": "dist" }, - "include": ["src"], - "exclude": ["**/*.spec.ts"] + "extends": "../tsconfig.base.json", + "compilerOptions": { "outDir": "dist" }, + "include": ["src"], + "exclude": ["**/*.spec.ts"] } diff --git a/discojs-node/tsconfig.vitest.json b/discojs-node/tsconfig.vitest.json index c4a3913ec..63288c889 100644 --- a/discojs-node/tsconfig.vitest.json +++ b/discojs-node/tsconfig.vitest.json @@ -1,5 +1,5 @@ { - "extends": "../tsconfig.base.json", - "compilerOptions": { "noEmit": true }, - "include": ["src"] + "extends": "../tsconfig.base.json", + "compilerOptions": { "noEmit": true }, + "include": ["src"] } diff --git a/discojs-web/package.json b/discojs-web/package.json index 4e55d18b1..c286f011c 100644 --- a/discojs-web/package.json +++ b/discojs-web/package.json @@ -1,30 +1,30 @@ { - "name": "@epfml/discojs-web", - "version": "3.0.0", - "type": "module", - "exports": "./dist/index.js", - "types": "dist/index.d.ts", - "scripts": { - "watch": "nodemon --ext ts --ignore dist --watch ../discojs/dist --watch . --exec npm run", - "build": "tsc --build", - "test": "cd .. && vitest --run --project=discojs-web" - }, - "repository": { - "type": "git", - "url": "git+https://github.com/epfml/disco.git" - }, - "bugs": { - "url": "https://github.com/epfml/disco/issues" - }, - "homepage": "https://github.com/epfml/disco#readme", - "dependencies": { - "@epfml/discojs": "*", - "@tensorflow/tfjs": "4", - "papaparse": "5" - }, - "devDependencies": { - "@types/papaparse": "5", - "jsdom": "29", - "nodemon": "3" - } + "name": "@epfml/discojs-web", + "version": "3.0.0", + "type": "module", + "exports": "./dist/index.js", + "types": "dist/index.d.ts", + "scripts": { + "watch": "nodemon --ext ts --ignore dist --watch ../discojs/dist --watch . --exec npm run", + "build": "tsc --build", + "test": "cd .. && vitest --run --project=discojs-web" + }, + "repository": { + "type": "git", + "url": "git+https://github.com/epfml/disco.git" + }, + "bugs": { + "url": "https://github.com/epfml/disco/issues" + }, + "homepage": "https://github.com/epfml/disco#readme", + "dependencies": { + "@epfml/discojs": "*", + "@tensorflow/tfjs": "4", + "papaparse": "5" + }, + "devDependencies": { + "@types/papaparse": "5", + "jsdom": "29", + "nodemon": "3" + } } diff --git a/discojs-web/src/hellaswag.spec.ts b/discojs-web/src/hellaswag.spec.ts index 28ce9a700..76b5dbcef 100644 --- a/discojs-web/src/hellaswag.spec.ts +++ b/discojs-web/src/hellaswag.spec.ts @@ -1,19 +1,22 @@ -import { describe, it, expect } from "vitest"; -import { load as loadHellaSwag } from './hellaswag.js'; -import { models } from '@epfml/discojs'; - -describe('hellaswag parser', () => { - it('loads the whole hellaswag dataset', async () => { - const dataset: models.HellaSwagDataset = await loadHellaSwag(2); - - // basic assertions - expect(dataset).to.be.an('array'); - expect(dataset.length).to.equal(2); - - // check structure of the first example - const first = dataset[0]; - expect(first).to.have.property('ctx').that.is.a('string'); - expect(first).to.have.property('endings').that.is.an('array').with.lengthOf(4); - expect(first).to.have.property('label').that.is.a('number'); - }); -}); +import { describe, it, expect } from "vitest"; +import { load as loadHellaSwag } from "./hellaswag.js"; +import { models } from "@epfml/discojs"; + +describe("hellaswag parser", () => { + it("loads the whole hellaswag dataset", async () => { + const dataset: models.HellaSwagDataset = await loadHellaSwag(2); + + // basic assertions + expect(dataset).to.be.an("array"); + expect(dataset.length).to.equal(2); + + // check structure of the first example + const first = dataset[0]; + expect(first).to.have.property("ctx").that.is.a("string"); + expect(first) + .to.have.property("endings") + .that.is.an("array") + .with.lengthOf(4); + expect(first).to.have.property("label").that.is.a("number"); + }); +}); diff --git a/discojs-web/src/hellaswag.ts b/discojs-web/src/hellaswag.ts index bca561e4f..da96aac95 100644 --- a/discojs-web/src/hellaswag.ts +++ b/discojs-web/src/hellaswag.ts @@ -1,35 +1,37 @@ -import { models } from '@epfml/discojs'; - -/** - * Loads the HellaSwag dataset from the remote URL in the browser - * - * @param limit - Maximum number of examples to load (-1 means all) - * @returns A HellaSwagDataset containing the examples - */ -export async function load(limit = -1): Promise { - const response = await fetch(models.HELLASWAG_URL); - if (!response.ok) { - throw new Error(`Failed to fetch dataset from ${models.HELLASWAG_URL}: ${response.statusText}`); - } - - const text = await response.text(); - const lines = text.split('\n'); - - const dataset: models.HellaSwagDataset = []; - let count = 0; - for (const line of lines) { - if (line.trim().length === 0) continue; - if (limit !== -1 && count >= limit) break; - - try { - const data = JSON.parse(line.trim()) as models.HellaSwagExample; - dataset.push(data); - count++; - } catch (e) { - console.error(`Failed to parse line:`, line); - throw e; - } - } - - return dataset; -} +import { models } from "@epfml/discojs"; + +/** + * Loads the HellaSwag dataset from the remote URL in the browser + * + * @param limit - Maximum number of examples to load (-1 means all) + * @returns A HellaSwagDataset containing the examples + */ +export async function load(limit = -1): Promise { + const response = await fetch(models.HELLASWAG_URL); + if (!response.ok) { + throw new Error( + `Failed to fetch dataset from ${models.HELLASWAG_URL}: ${response.statusText}`, + ); + } + + const text = await response.text(); + const lines = text.split("\n"); + + const dataset: models.HellaSwagDataset = []; + let count = 0; + for (const line of lines) { + if (line.trim().length === 0) continue; + if (limit !== -1 && count >= limit) break; + + try { + const data = JSON.parse(line.trim()) as models.HellaSwagExample; + dataset.push(data); + count++; + } catch (e) { + console.error(`Failed to parse line:`, line); + throw e; + } + } + + return dataset; +} diff --git a/discojs-web/src/index.ts b/discojs-web/src/index.ts index 9c7506d09..f8f6f8db1 100644 --- a/discojs-web/src/index.ts +++ b/discojs-web/src/index.ts @@ -1,2 +1,2 @@ export * from "./loaders/index.js"; -export { load as loadHellaSwag } from "./hellaswag.js"; \ No newline at end of file +export { load as loadHellaSwag } from "./hellaswag.js"; diff --git a/discojs-web/src/loaders.spec.ts b/discojs-web/src/loaders.spec.ts index 7bbb6bd6f..2db72c1fa 100644 --- a/discojs-web/src/loaders.spec.ts +++ b/discojs-web/src/loaders.spec.ts @@ -22,14 +22,14 @@ describe("csv parser", () => { describe("text parser", () => { it("loads a simple sequence", async () => { - const text = ["first", "second", "third"].join("\n") - + const text = ["first", "second", "third"].join("\n"); + // jsdom doesn't implement .text on File/Blob // trick from https://github.com/jsdom/jsdom/issues/2555 const file = await ( - await fetch( "data:," + encodeURIComponent(text)) + await fetch("data:," + encodeURIComponent(text)) ).blob(); - const parsed = loadText(file) + const parsed = loadText(file); expect(await parsed.size()).to.equal(1); expect((await arrayFromAsync(parsed))[0]).to.equal(text); }); diff --git a/discojs-web/src/loaders/csv.ts b/discojs-web/src/loaders/csv.ts index 074c2c2ae..db319b903 100644 --- a/discojs-web/src/loaders/csv.ts +++ b/discojs-web/src/loaders/csv.ts @@ -24,7 +24,7 @@ export function load(file: File): Dataset>> { skipEmptyLines: true, // TODO needed to avoid parsing last empty line complete(results) { if (results.errors.length > 0) { - const error = results.errors[0] + const error = results.errors[0]; reject(new Error(error.message)); return; } diff --git a/discojs-web/src/loaders/index.ts b/discojs-web/src/loaders/index.ts index a193d7d66..d4128b507 100644 --- a/discojs-web/src/loaders/index.ts +++ b/discojs-web/src/loaders/index.ts @@ -1,3 +1,3 @@ -export { load as loadCSV } from "./csv.js" -export { load as loadImage } from "./image.js" -export { load as loadText } from "./text.js" +export { load as loadCSV } from "./csv.js"; +export { load as loadImage } from "./image.js"; +export { load as loadText } from "./text.js"; diff --git a/discojs-web/tsconfig.lib.json b/discojs-web/tsconfig.lib.json index 034904728..dce2acb92 100644 --- a/discojs-web/tsconfig.lib.json +++ b/discojs-web/tsconfig.lib.json @@ -1,9 +1,9 @@ { - "extends": "../tsconfig.base.json", - "compilerOptions": { - "lib": ["DOM"], - "outDir": "dist" - }, - "include": ["src"], - "exclude": ["**/*.spec.ts"] + "extends": "../tsconfig.base.json", + "compilerOptions": { + "lib": ["DOM"], + "outDir": "dist" + }, + "include": ["src"], + "exclude": ["**/*.spec.ts"] } diff --git a/discojs-web/tsconfig.vitest.json b/discojs-web/tsconfig.vitest.json index 7dc009819..7b4cdb6a2 100644 --- a/discojs-web/tsconfig.vitest.json +++ b/discojs-web/tsconfig.vitest.json @@ -1,8 +1,8 @@ { - "extends": "../tsconfig.base.json", - "compilerOptions": { - "lib": ["DOM"], - "noEmit": true - }, - "include": ["src"] + "extends": "../tsconfig.base.json", + "compilerOptions": { + "lib": ["DOM"], + "noEmit": true + }, + "include": ["src"] } diff --git a/discojs/README.md b/discojs/README.md index 7af371ce6..37d6b1e3f 100644 --- a/discojs/README.md +++ b/discojs/README.md @@ -5,8 +5,9 @@ Decentralized & federated privacy-preserving ML training in TypeScript. This is the core library of the Disco.js project. It is platform-agnostic, and has two companions library: - - [`discojs-node`](../discojs-node) for Node.js - - [`discojs-web`](../discojs-web) for web browsers + +- [`discojs-node`](../discojs-node) for Node.js +- [`discojs-web`](../discojs-web) for web browsers The easiest way to start using it is through the `Disco` object. Create your own `Task` or load one from our `default_tasks`, diff --git a/discojs/package.json b/discojs/package.json index 4e3b63a5f..74b8ba268 100644 --- a/discojs/package.json +++ b/discojs/package.json @@ -1,39 +1,39 @@ { - "name": "@epfml/discojs", - "version": "3.0.0", - "type": "module", - "exports": "./dist/index.js", - "types": "dist/index.d.ts", - "scripts": { - "watch": "nodemon --ext ts --ignore dist --exec npm run", - "build": "tsc --build", - "test": "cd .. && vitest --run --project=discojs" - }, - "repository": { - "type": "git", - "url": "git+https://github.com/epfml/disco.git" - }, - "bugs": { - "url": "https://github.com/epfml/disco/issues" - }, - "homepage": "https://github.com/epfml/disco#readme", - "dependencies": { - "@epfml/isomorphic-wrtc": "1", - "@jimp/core": "1", - "@jimp/plugin-resize": "1", - "@msgpack/msgpack": "3", - "@tensorflow/tfjs": "4", - "@xenova/transformers": "2", - "isomorphic-ws": "5", - "simple-peer": "9", - "tslib": "2", - "ws": "8", - "zod": "4" - }, - "devDependencies": { - "@tensorflow/tfjs-node": "4", - "@types/simple-peer": "9", - "nodemon": "3", - "ts-node": "10" - } + "name": "@epfml/discojs", + "version": "3.0.0", + "type": "module", + "exports": "./dist/index.js", + "types": "dist/index.d.ts", + "scripts": { + "watch": "nodemon --ext ts --ignore dist --exec npm run", + "build": "tsc --build", + "test": "cd .. && vitest --run --project=discojs" + }, + "repository": { + "type": "git", + "url": "git+https://github.com/epfml/disco.git" + }, + "bugs": { + "url": "https://github.com/epfml/disco/issues" + }, + "homepage": "https://github.com/epfml/disco#readme", + "dependencies": { + "@epfml/isomorphic-wrtc": "1", + "@jimp/core": "1", + "@jimp/plugin-resize": "1", + "@msgpack/msgpack": "3", + "@tensorflow/tfjs": "4", + "@xenova/transformers": "2", + "isomorphic-ws": "5", + "simple-peer": "9", + "tslib": "2", + "ws": "8", + "zod": "4" + }, + "devDependencies": { + "@tensorflow/tfjs-node": "4", + "@types/simple-peer": "9", + "nodemon": "3", + "ts-node": "10" + } } diff --git a/discojs/src/aggregator.spec.ts b/discojs/src/aggregator.spec.ts index cc377b784..40f5806e0 100644 --- a/discojs/src/aggregator.spec.ts +++ b/discojs/src/aggregator.spec.ts @@ -31,14 +31,14 @@ for (const [name, Aggregator] of AGGREGATORS) { const results = new Promise((resolve) => aggregator.on("aggregation", resolve), ); - - let promises = List>() + + let promises = List>(); for (let i = 0; i < 3; i++) - for (let r = 0; r < aggregator.communicationRounds; r++){ - promises = promises.push(aggregator.getPromiseForAggregation()) - aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r) + for (let r = 0; r < aggregator.communicationRounds; r++) { + promises = promises.push(aggregator.getPromiseForAggregation()); + aggregator.add(`client ${i}`, WeightsContainer.of([i]), 0, r); } - await Promise.all(promises) + await Promise.all(promises); await results; // nothing to test expect(aggregator.round).to.equal(1); @@ -59,7 +59,8 @@ for (const [name, Aggregator] of AGGREGATORS) { id, [agg, WeightsContainer.of([ws])], ]), - ), 0 + ), + 0, ) ) .valueSeq() @@ -73,7 +74,7 @@ for (const [name, Aggregator] of AGGREGATORS) { return first; }); }); - }) + }); } export async function wsIntoArrays(ws: WeightsContainer): Promise { @@ -88,7 +89,7 @@ export function setupNetwork( const ret = Map( Range(0, 3).map((i) => [`client ${i}`, new Aggregator()] as [NodeID, A]), ); - for (const secure of ret.values()) secure.setNodes(ret.keySeq().toSet()); + for (const secure of ret.values()) secure.setNodes(ret.keySeq().toSet()); return ret; } @@ -96,7 +97,7 @@ export function setupNetwork( // run all rounds of communication export async function communicate( networkWithContributions: Map, - aggregationRound: number + aggregationRound: number, ): Promise> { const communicationsRound = networkWithContributions.first()?.[0].communicationRounds; @@ -116,14 +117,14 @@ export async function communicate( ]) .toArray(); - for (const [id, agg] of network) { - const contribution = contributions.get(id); - if (contribution === undefined) - throw new Error(`no contribution for ${id}`); + for (const [id, agg] of network) { + const contribution = contributions.get(id); + if (contribution === undefined) + throw new Error(`no contribution for ${id}`); - for (const [to, payload] of agg.makePayloads(contribution)) - network.get(to)?.add(id, payload, aggregationRound, r); - } + for (const [to, payload] of agg.makePayloads(contribution)) + network.get(to)?.add(id, payload, aggregationRound, r); + } contributions = Map(await Promise.all(nextContributions)); } diff --git a/discojs/src/aggregator/aggregator.ts b/discojs/src/aggregator/aggregator.ts index f21aff2aa..7f437a27d 100644 --- a/discojs/src/aggregator/aggregator.ts +++ b/discojs/src/aggregator/aggregator.ts @@ -1,16 +1,16 @@ import createDebug from "debug"; -import { Map, Set } from 'immutable' +import { Map, Set } from "immutable"; -import type { client, WeightsContainer } from '../index.js' +import type { client, WeightsContainer } from "../index.js"; -import { EventEmitter } from '../utils/event_emitter.js' +import { EventEmitter } from "../utils/event_emitter.js"; const debug = createDebug("discojs:aggregator"); export enum AggregationStep { ADD, UPDATE, - AGGREGATE + AGGREGATE, } /** @@ -20,34 +20,36 @@ export enum AggregationStep { * Emits an event whenever an aggregation step is performed with the counrd's aggregated weights. * Users subscribes to this event to get the aggregation result. */ -export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsContainer }> { +export abstract class Aggregator extends EventEmitter<{ + aggregation: WeightsContainer; +}> { /** * Contains the ids of all active nodes, i.e. members of the aggregation group at * a given round. It is a subset of all the nodes available in the network. */ - protected _nodes: Set + protected _nodes: Set; /** * Contains the contributions received from active nodes, accessible by node id. * It defines the effective aggregation group, which is possibly a subset * of all active nodes, depending on the aggregation scheme. */ // communication round -> NodeID -> WeightsContainer - protected contributions: Map> + protected contributions: Map>; /** * The current aggregation round, used for assessing whether a node contribution is recent enough * or not. */ - protected _round = 0 + protected _round = 0; /** * The current communication round. A single aggregation round is made of possibly multiple * communication rounds. This makes the aggregator free to perform intermediate aggregation * steps based off communication with its nodes. Overall, this allows for more complex * aggregation schemes requiring an exchange of information between nodes before aggregating. */ - protected _communicationRound = 0 + protected _communicationRound = 0; - constructor ( + constructor( /** * The round cut-off for contributions. */ @@ -55,22 +57,24 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon /** * The number of communication rounds occurring during any given aggregation round. */ - public readonly communicationRounds = 1 + public readonly communicationRounds = 1, ) { - super() + super(); - this.contributions = Map() - this._nodes = Set() + this.contributions = Map(); + this._nodes = Set(); } /** * Convenience method to subscribe to the 'aggregation' event. * Await this promise returns the aggregated weights for the current round. - * + * * @returns a promise for the aggregated weights */ getPromiseForAggregation(): Promise { - return new Promise((resolve) => this.once('aggregation', resolve)); + return new Promise((resolve) => + this.once("aggregation", resolve), + ); } /** @@ -79,63 +83,79 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon * Within one aggregation round there may be multiple communication rounds (such as for the decentralized secure aggregation * which requires multiple steps to obtain a global model) * The contribution is aggregated during the next aggregation step. - * + * * @param nodeId The node's id * @param contribution The node's contribution */ - add(nodeId: client.NodeID, contribution: WeightsContainer, - aggregationRound: number, communicationRound?: number): void { + add( + nodeId: client.NodeID, + contribution: WeightsContainer, + aggregationRound: number, + communicationRound?: number, + ): void { if (!this.isValidContribution(nodeId, aggregationRound)) - throw new Error("Tried adding an invalid contribution. Handle this case before calling add.") - + throw new Error( + "Tried adding an invalid contribution. Handle this case before calling add.", + ); + // call the abstract method _add, implemented by subclasses - this._add(nodeId, contribution, communicationRound ?? this.communicationRound) + this._add( + nodeId, + contribution, + communicationRound ?? this.communicationRound, + ); // If the aggregator has enough contributions then aggregate the weights // and emit the 'aggregation' event if (this.isFull()) { - const aggregatedWeights = this.aggregate() + const aggregatedWeights = this.aggregate(); // On each aggregation, increment the communication round // If all communication rounds were performed, proceed to the next aggregation round // and empty the past contributions. this._communicationRound++; if (this.communicationRound === this.communicationRounds) { - this._communicationRound = 0 + this._communicationRound = 0; this._round++; - this.contributions = Map() + this.contributions = Map(); } // Emitting the 'aggregation' communicates the weights to subscribers - this.emit('aggregation', aggregatedWeights) + this.emit("aggregation", aggregatedWeights); } } - + // Abstract method to be implemented by subclasses // Handles logging and adding the contribution to the list of the current round's contributions - protected abstract _add(nodeId: client.NodeID, contribution: WeightsContainer, communicationRound?: number): void + protected abstract _add( + nodeId: client.NodeID, + contribution: WeightsContainer, + communicationRound?: number, + ): void; /** * Evaluates whether a given participant contribution can be used in the current aggregation round * the boolean returned by `this.add` is obtained via `this.isValidContribution` - * + * * @param nodeId the node id of the contribution to be added * @param round the aggregation round of the contribution to be added */ isValidContribution(nodeId: client.NodeID, round: number): boolean { if (!this.nodes.has(nodeId)) { - debug("Contribution rejected because node id is not registered") + debug("Contribution rejected because node id is not registered"); return false; } if (!this.isWithinRoundCutoff(round)) { - debug(`Contribution rejected because round ${round} is not within round cutoff`) + debug( + `Contribution rejected because round ${round} is not within round cutoff`, + ); return false; } - return true + return true; } /** * Performs an aggregation step over the received node contributions. * Must store the aggregation's result in the aggregator's result promise. */ - protected abstract aggregate (): WeightsContainer + protected abstract aggregate(): WeightsContainer; /** * Returns whether the given round is recent enough, dependent on the @@ -143,8 +163,8 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon * @param round The round * @returns True if the round is recent enough, false otherwise */ - private isWithinRoundCutoff (round: number): boolean { - return this.round - round <= this.roundCutoff + private isWithinRoundCutoff(round: number): boolean { + return this.round - round <= this.roundCutoff; } /** @@ -152,23 +172,29 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon * @param step The aggregation step * @param from The node which triggered the logging message */ - log (step: AggregationStep, from?: client.NodeID): void { + log(step: AggregationStep, from?: client.NodeID): void { switch (step) { case AggregationStep.ADD: - debug(`Adding contribution from node ${from ?? '"unknown"'} for aggregation round ${this.round} and communication round ${this.communicationRound}`); - break + debug( + `Adding contribution from node ${from ?? '"unknown"'} for aggregation round ${this.round} and communication round ${this.communicationRound}`, + ); + break; case AggregationStep.UPDATE: if (from === undefined) { - return + return; } - debug(`Updating contribution from node ${from} for aggregation round ${this.round} and communication round ${this.communicationRound}`) - break + debug( + `Updating contribution from node ${from} for aggregation round ${this.round} and communication round ${this.communicationRound}`, + ); + break; case AggregationStep.AGGREGATE: - debug(`Buffer is full. Aggregating weights for round aggregation round ${this.round} and communication round ${this.communicationRound}`) - break + debug( + `Buffer is full. Aggregating weights for round aggregation round ${this.round} and communication round ${this.communicationRound}`, + ); + break; default: { - const _: never = step - throw new Error('should never happen') + const _: never = step; + throw new Error("should never happen"); } } } @@ -180,20 +206,20 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon * @param nodeId The node to be added * @returns True is the node wasn't already in the list of nodes, False if already included */ - registerNode (nodeId: client.NodeID): boolean { + registerNode(nodeId: client.NodeID): boolean { if (!this.nodes.has(nodeId)) { - this._nodes = this._nodes.add(nodeId) - return true + this._nodes = this._nodes.add(nodeId); + return true; } - return false + return false; } /** * Remove a node's id from the set of active nodes. * @param nodeId The node to be removed */ - removeNode (nodeId: client.NodeID): void{ - this._nodes = this._nodes.delete(nodeId) + removeNode(nodeId: client.NodeID): void { + this._nodes = this._nodes.delete(nodeId); } /** @@ -202,8 +228,8 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon * during this aggregation round. * @param nodeIds The new set of nodes */ - setNodes (nodeIds: Set): void { - this._nodes = nodeIds + setNodes(nodeIds: Set): void { + this._nodes = nodeIds; } /** @@ -211,9 +237,9 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon * with the network's round. * @param round The new round */ - setRound (round: number): void { + setRound(round: number): void { if (round > this.round) { - this._round = round + this._round = round; } } @@ -221,28 +247,30 @@ export abstract class Aggregator extends EventEmitter<{'aggregation': WeightsCon * Constructs the payloads sent to other nodes as contribution. * @param base Object from which the payload is computed */ - abstract makePayloads (base: WeightsContainer): Map + abstract makePayloads( + base: WeightsContainer, + ): Map; - abstract isFull (): boolean + abstract isFull(): boolean; /** * The set of node ids, representing our neighbors within the network. */ - get nodes (): Set { - return this._nodes + get nodes(): Set { + return this._nodes; } /** * The aggregation round. */ - get round (): number { - return this._round + get round(): number { + return this._round; } /** * The current communication round. */ - get communicationRound (): number { - return this._communicationRound + get communicationRound(): number { + return this._communicationRound; } } diff --git a/discojs/src/aggregator/byzantine.spec.ts b/discojs/src/aggregator/byzantine.spec.ts index d300fbb4c..bd556d78e 100644 --- a/discojs/src/aggregator/byzantine.spec.ts +++ b/discojs/src/aggregator/byzantine.spec.ts @@ -6,19 +6,27 @@ import { ByzantineRobustAggregator } from "./byzantine.js"; // Helper to convert WeightsContainer → number[][] for easy assertions async function WSIntoArrays(ws: WeightsContainer): Promise { - return Promise.all(ws.weights.map(async t => Array.from(await t.data()))); + return Promise.all(ws.weights.map(async (t) => Array.from(await t.data()))); } describe("ByzantineRobustAggregator", () => { it("throws on invalid constructor parameters", () => { - expect(() => new ByzantineRobustAggregator(0, 1, 'absolute', 0, 1, 0.5)).to.throw(); - expect(() => new ByzantineRobustAggregator(0, 1, 'absolute', 1, 0, 0.5)).to.throw(); - expect(() => new ByzantineRobustAggregator(0, 1, 'absolute', 1, 1.1, 0.5)).to.throw(); - expect(() => new ByzantineRobustAggregator(0, 1, 'absolute', 1, 1, 1.5)).to.throw(); + expect( + () => new ByzantineRobustAggregator(0, 1, "absolute", 0, 1, 0.5), + ).to.throw(); + expect( + () => new ByzantineRobustAggregator(0, 1, "absolute", 1, 0, 0.5), + ).to.throw(); + expect( + () => new ByzantineRobustAggregator(0, 1, "absolute", 1, 1.1, 0.5), + ).to.throw(); + expect( + () => new ByzantineRobustAggregator(0, 1, "absolute", 1, 1, 1.5), + ).to.throw(); }); it("performs basic mean when clippingRadius is large and beta = 0", async () => { - const agg = new ByzantineRobustAggregator(0, 2, 'absolute', 1e6, 1, 0); + const agg = new ByzantineRobustAggregator(0, 2, "absolute", 1e6, 1, 0); const [id1, id2] = ["c1", "c2"]; agg.setNodes(Set.of(id1, id2)); @@ -32,7 +40,7 @@ describe("ByzantineRobustAggregator", () => { }); it("clips a single outlier with small radius", async () => { - const agg = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 1, 0); + const agg = new ByzantineRobustAggregator(0, 3, "absolute", 1.0, 1, 0); const [c1, c2, bad] = ["c1", "c2", "bad"]; agg.setNodes(Set.of(c1, c2, bad)); @@ -47,7 +55,7 @@ describe("ByzantineRobustAggregator", () => { }); it("applies multiple clipping iterations (maxIterations > 1)", async () => { - const agg = new ByzantineRobustAggregator(0, 2, 'absolute', 1.0, 3, 0); + const agg = new ByzantineRobustAggregator(0, 2, "absolute", 1.0, 3, 0); const [c1, bad] = ["c1", "bad"]; agg.setNodes(Set.of(c1, bad)); @@ -61,7 +69,7 @@ describe("ByzantineRobustAggregator", () => { }); it("uses momentum when beta > 0", async () => { - const agg = new ByzantineRobustAggregator(0, 2, 'absolute', 1e6, 1, 0.5); + const agg = new ByzantineRobustAggregator(0, 2, "absolute", 1e6, 1, 0.5); const [c1, c2] = ["c1", "c2"]; agg.setNodes(Set.of(c1, c2)); @@ -83,7 +91,7 @@ describe("ByzantineRobustAggregator", () => { }); it("respects roundCutoff — ignores old contributions", async () => { - const agg = new ByzantineRobustAggregator(1, 1, 'absolute', 1e6, 1, 0); + const agg = new ByzantineRobustAggregator(1, 1, "absolute", 1e6, 1, 0); const id = "c1"; agg.setNodes(Set.of(id)); diff --git a/discojs/src/aggregator/byzantine.ts b/discojs/src/aggregator/byzantine.ts index 64d5cbe43..1e89dea31 100644 --- a/discojs/src/aggregator/byzantine.ts +++ b/discojs/src/aggregator/byzantine.ts @@ -1,5 +1,5 @@ import { Map } from "immutable"; -import * as tf from '@tensorflow/tfjs'; +import * as tf from "@tensorflow/tfjs"; import { AggregationStep } from "./aggregator.js"; import { MultiRoundAggregator, ThresholdType } from "./multiround.js"; import { WeightsContainer, client } from "../index.js"; @@ -8,7 +8,7 @@ import { aggregation } from "../index.js"; /** * Byzantine-robust aggregator using Centered Clipping (CC), based on the * "Learning from History for Byzantine Robust Optimization" paper: https://arxiv.org/abs/2012.10333 - * + * * This class implements a gradient aggregation rule that clips updates * in an iterative fashion to mitigate the influence of Byzantine nodes, as well as momentum calculations. */ @@ -39,13 +39,23 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { * - A higher beta gives more weight to past rounds (more smoothing), while a lower beta makes the aggregator more responsive to new updates. */ - - constructor(roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType, clippingRadius = 1.0, maxIterations = 1, beta = 0.9) { + constructor( + roundCutoff = 0, + threshold = 1, + thresholdType?: ThresholdType, + clippingRadius = 1.0, + maxIterations = 1, + beta = 0.9, + ) { super(roundCutoff, threshold, thresholdType); - if (clippingRadius <= 0) throw new Error("Clipping radius needs to be positive number > 0."); - if (maxIterations < 1) throw new Error("There must be at least one iteration for clipping."); - if (!Number.isInteger(maxIterations)) throw new Error("Number of iterations must be an integer."); - if ((beta < 0) || (beta > 1)) throw new Error("Beta must be between 0 and 1, since it is coeficient."); + if (clippingRadius <= 0) + throw new Error("Clipping radius needs to be positive number > 0."); + if (maxIterations < 1) + throw new Error("There must be at least one iteration for clipping."); + if (!Number.isInteger(maxIterations)) + throw new Error("Number of iterations must be an integer."); + if (beta < 0 || beta > 1) + throw new Error("Beta must be between 0 and 1, since it is coeficient."); this.clippingRadius = clippingRadius; this.maxIterations = maxIterations; this.beta = beta; @@ -53,14 +63,18 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { override _add(nodeId: client.NodeID, contribution: WeightsContainer): void { this.log( - this.contributions.hasIn([0, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, + this.contributions.hasIn([0, nodeId]) + ? AggregationStep.UPDATE + : AggregationStep.ADD, nodeId, ); const prevMomentum = this.historyMomentums.get(nodeId); const newMomentum = prevMomentum - ? contribution.mapWith(prevMomentum, (g, m) => g.mul(1 - this.beta).add(m.mul(this.beta))) - : contribution; // no scaling on first momentum + ? contribution.mapWith(prevMomentum, (g, m) => + g.mul(1 - this.beta).add(m.mul(this.beta)), + ) + : contribution; // no scaling on first momentum this.historyMomentums = this.historyMomentums.set(nodeId, newMomentum); this.contributions = this.contributions.setIn([0, nodeId], newMomentum); @@ -68,7 +82,8 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { override aggregate(): WeightsContainer { const currentContributions = this.contributions.get(0); - if (!currentContributions) throw new Error("aggregating without any contribution"); + if (!currentContributions) + throw new Error("aggregating without any contribution"); this.log(AggregationStep.AGGREGATE); @@ -84,23 +99,31 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { } else { // Use shape of the first contribution to create zero vector const first = currentContributions.values().next(); - if (first.done) throw new Error("zero sized contribution") + if (first.done) throw new Error("zero sized contribution"); v = first.value.map((t: tf.Tensor) => tf.zerosLike(t)); } // Step 2: Iterative Centered Clipping for (let l = 0; l < this.maxIterations; l++) { - const clippedDiffs = Array.from(currentContributions.values()).map(m => { - const diff = m.sub(v); - const norm = tf.tidy(() => euclideanNorm(diff)); - const scale = tf.tidy(() => tf.minimum(tf.scalar(1), tf.div(tf.scalar(this.clippingRadius), norm))); - const clipped = diff.mul(scale); - norm.dispose(); scale.dispose(); - return clipped; - }); + const clippedDiffs = Array.from(currentContributions.values()).map( + (m) => { + const diff = m.sub(v); + const norm = tf.tidy(() => euclideanNorm(diff)); + const scale = tf.tidy(() => + tf.minimum( + tf.scalar(1), + tf.div(tf.scalar(this.clippingRadius), norm), + ), + ); + const clipped = diff.mul(scale); + norm.dispose(); + scale.dispose(); + return clipped; + }, + ); const avgClip = aggregation.avg(clippedDiffs); const newV = v.add(avgClip); - clippedDiffs.forEach(d => d.dispose()); + clippedDiffs.forEach((d) => d.dispose()); v.dispose(); // Safe if v is no longer needed v = newV; } @@ -109,8 +132,9 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { return v; } - - override makePayloads(weights: WeightsContainer): Map { + override makePayloads( + weights: WeightsContainer, + ): Map { // Communicate our local weights to every other node, be it a peer or a server return this.nodes.toMap().map(() => weights); } @@ -119,8 +143,8 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { function euclideanNorm(w: WeightsContainer): tf.Scalar { // Computes the Euclidean (L2) norm of all tensors in a WeightsContainer by summing the squares of their elements and taking the square root. return tf.tidy(() => { - const norms: tf.Scalar[] = w.weights.map(t => tf.sum(tf.square(t))); + const norms: tf.Scalar[] = w.weights.map((t) => tf.sum(tf.square(t))); const total = norms.reduce((a, b) => tf.add(a, b)); return tf.sqrt(total); }); -} \ No newline at end of file +} diff --git a/discojs/src/aggregator/get.ts b/discojs/src/aggregator/get.ts index a9ab81c27..ee53e1721 100644 --- a/discojs/src/aggregator/get.ts +++ b/discojs/src/aggregator/get.ts @@ -1,81 +1,83 @@ -import type { DataType, Network, Task } from '../index.js' -import { aggregator } from '../index.js' -import { ByzantineRobustAggregator } from './byzantine.js'; +import type { DataType, Network, Task } from "../index.js"; +import { aggregator } from "../index.js"; +import { ByzantineRobustAggregator } from "./byzantine.js"; type AggregatorOptions = Partial<{ scheme: Task["trainingInformation"]["scheme"]; // if undefined, fallback on task.trainingInformation.scheme - roundCutOff: number, // MeanAggregator - threshold: number, // MeanAggregator - thresholdType: 'relative' | 'absolute', // MeanAggregator -}> + roundCutOff: number; // MeanAggregator + threshold: number; // MeanAggregator + thresholdType: "relative" | "absolute"; // MeanAggregator +}>; /** * Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters. * Here is the ordered list of parameters used to define the aggregator and its default behavior: * task.trainingInformation.aggregationStrategy > options.scheme > task.trainingInformation.scheme - * + * * If `task.trainingInformation.aggregationStrategy` is defined, we initialize the chosen aggregator with `options` parameter values. * Otherwise, we default to a MeanAggregator for both training schemes. - * + * * For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values. * Unless specified otherwise, for federated learning or local training the aggregator default to waiting - * for a single contribution to trigger a model update. + * for a single contribution to trigger a model update. * (the server's model update for federated learning or our own contribution if training locally) * For decentralized learning the aggregator defaults to waiting for every nodes' contribution to trigger a model update. - * + * * @param task The task object associated with the current training session * @param options Options passed down to the aggregator's constructor * @returns The aggregator */ export function getAggregator( - task: Task, + task: Task, options: AggregatorOptions = {}, ): aggregator.Aggregator { - const scheme = options.scheme ?? task.trainingInformation.scheme + const scheme = options.scheme ?? task.trainingInformation.scheme; // If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100% - // If scheme == 'federated' then we only expect the server's contribution at each round + // If scheme == 'federated' then we only expect the server's contribution at each round // so we set the aggregation threshold to 1 contribution // If scheme == 'local' then we only expect our own contribution const networkOptions: Required = { scheme, roundCutOff: 0, - threshold: 1, + threshold: 1, thresholdType: scheme === "decentralized" ? "relative" : "absolute", ...options, // user overrides defaults }; - + switch (task.trainingInformation.aggregationStrategy) { - case "byzantine": { - const { - clippingRadius = 1.0, - maxIterations = 1, - beta = 0.9, - } = task.trainingInformation.privacy.byzantineFaultTolerance; + case "byzantine": { + const { + clippingRadius = 1.0, + maxIterations = 1, + beta = 0.9, + } = task.trainingInformation.privacy.byzantineFaultTolerance; - return new ByzantineRobustAggregator( - networkOptions.roundCutOff, - networkOptions.threshold, - networkOptions.thresholdType, - clippingRadius, - maxIterations, - beta, - ); - } - case 'mean': + return new ByzantineRobustAggregator( + networkOptions.roundCutOff, + networkOptions.threshold, + networkOptions.thresholdType, + clippingRadius, + maxIterations, + beta, + ); + } + case "mean": return new aggregator.MeanAggregator( - networkOptions.roundCutOff, - networkOptions.threshold, - networkOptions.thresholdType - ) - case 'secure': - if (scheme !== 'decentralized') { - throw new Error('secure aggregation is currently supported for decentralized only') + networkOptions.roundCutOff, + networkOptions.threshold, + networkOptions.thresholdType, + ); + case "secure": + if (scheme !== "decentralized") { + throw new Error( + "secure aggregation is currently supported for decentralized only", + ); } return new aggregator.SecureAggregator( - task.trainingInformation.maxShareValue - ) + task.trainingInformation.maxShareValue, + ); } } diff --git a/discojs/src/aggregator/index.ts b/discojs/src/aggregator/index.ts index 6dca014f4..695c60d62 100644 --- a/discojs/src/aggregator/index.ts +++ b/discojs/src/aggregator/index.ts @@ -1,6 +1,6 @@ -export { Aggregator, AggregationStep } from './aggregator.js' -export { MeanAggregator } from './mean.js' -export { SecureAggregator } from './secure.js' -export { ByzantineRobustAggregator } from './byzantine.js' +export { Aggregator, AggregationStep } from "./aggregator.js"; +export { MeanAggregator } from "./mean.js"; +export { SecureAggregator } from "./secure.js"; +export { ByzantineRobustAggregator } from "./byzantine.js"; -export { getAggregator } from './get.js' \ No newline at end of file +export { getAggregator } from "./get.js"; diff --git a/discojs/src/aggregator/mean.spec.ts b/discojs/src/aggregator/mean.spec.ts index d15dcabed..c9eab512f 100644 --- a/discojs/src/aggregator/mean.spec.ts +++ b/discojs/src/aggregator/mean.spec.ts @@ -11,26 +11,28 @@ async function WSIntoArrays(ws: WeightsContainer): Promise { describe("mean aggregator", () => { it("updates only within round cutoff", async () => { - const aggregator = new MeanAggregator(1, 1, 'relative'); // use a round cutoff of 1 + const aggregator = new MeanAggregator(1, 1, "relative"); // use a round cutoff of 1 aggregator.setNodes(Set.of("client 1")); // round 0 - expect(aggregator.round).to.equal(0) + expect(aggregator.round).to.equal(0); expect(aggregator.isValidContribution("client 1", 0)).to.be.true; const client1Round0Promise = aggregator.getPromiseForAggregation(); aggregator.add("client 1", WeightsContainer.of([1]), 0); - expect(WeightsContainer.of([1]).equals(await client1Round0Promise)).to.be.true - expect(aggregator.round).to.equal(1) - + expect(WeightsContainer.of([1]).equals(await client1Round0Promise)).to.be + .true; + expect(aggregator.round).to.equal(1); + // round 1 aggregator.registerNode("client 2"); expect(aggregator.isValidContribution("client 2", 0)).to.be.true; // round 0 should be within the cutoff aggregator.add("client 1", WeightsContainer.of([1]), 1); - const client2Round0Promise = aggregator.getPromiseForAggregation(); + const client2Round0Promise = aggregator.getPromiseForAggregation(); aggregator.add("client 2", WeightsContainer.of([2]), 0); - expect(WeightsContainer.of([1.5]).equals(await client2Round0Promise)).to.be.true - expect(aggregator.round).to.equal(2) - + expect(WeightsContainer.of([1.5]).equals(await client2Round0Promise)).to.be + .true; + expect(aggregator.round).to.equal(2); + // round 2 aggregator.registerNode("client 3"); expect(aggregator.isValidContribution("client 3", 0)).to.be.false; // round 0 is now out of the cutoff @@ -39,13 +41,14 @@ describe("mean aggregator", () => { aggregator.add("client 2", WeightsContainer.of([1]), 2); const client3Round2Promise = aggregator.getPromiseForAggregation(); aggregator.add("client 3", WeightsContainer.of([4]), 1); - expect(WeightsContainer.of([2]).equals(await client3Round2Promise)).to.be.true - expect(aggregator.round).to.equal(3) + expect(WeightsContainer.of([2]).equals(await client3Round2Promise)).to.be + .true; + expect(aggregator.round).to.equal(3); }); it("returns the mean of the weights", async () => { - const aggregator = new MeanAggregator(0, 2, 'absolute'); - const [id1, id2] = ["client 1", "client 2"] + const aggregator = new MeanAggregator(0, 2, "absolute"); + const [id1, id2] = ["client 1", "client 2"]; aggregator.setNodes(Set.of(id1, id2)); @@ -57,33 +60,33 @@ describe("mean aggregator", () => { aggregator.add(id1, WeightsContainer.of([0], [1]), 0); const result2 = aggregator.getPromiseForAggregation(); aggregator.add(id2, WeightsContainer.of([2], [3]), 0); - expect((await result1).equals(await result2)).to.be.true + expect((await result1).equals(await result2)).to.be.true; expect(await WSIntoArrays(await results)).to.deep.equal([[1], [2]]); }); it("waits for 100% of the contributions by default", async () => { const aggregator = new MeanAggregator(); - const [id1, id2] = ["client 1", "client 2"] + const [id1, id2] = ["client 1", "client 2"]; aggregator.setNodes(Set.of(id1, id2)); const result1 = aggregator.getPromiseForAggregation(); aggregator.add(id1, WeightsContainer.of([0], [1]), 0); // Make sure that the aggregation isn't triggered - expect(aggregator.round).equals(0) - + expect(aggregator.round).equals(0); + aggregator.registerNode(id2); const result2 = aggregator.getPromiseForAggregation(); aggregator.add(id2, WeightsContainer.of([2], [3]), 0); - expect((await result1).equals(await result2)).to.be.true - expect(aggregator.round).equals(1) // round should be one now + expect((await result1).equals(await result2)).to.be.true; + expect(aggregator.round).equals(1); // round should be one now }); it("can wait for an absolute number of contributions", async () => { - const aggregator = new MeanAggregator(0, 1, 'absolute'); - const [id1, id2] = ["client 1", "client 2"] - aggregator.setNodes(Set.of(id1, id2)); // register two clients + const aggregator = new MeanAggregator(0, 1, "absolute"); + const [id1, id2] = ["client 1", "client 2"]; + aggregator.setNodes(Set.of(id1, id2)); // register two clients // should aggregate with only one contribution const result = aggregator.getPromiseForAggregation(); @@ -92,31 +95,31 @@ describe("mean aggregator", () => { }); it("can wait for an relative number of contributions", async () => { - const aggregator = new MeanAggregator(0, 0.5, 'relative'); - const [id1, id2] = ["client 1", "client 2"] - aggregator.setNodes(Set.of(id1, id2)); // register two clients + const aggregator = new MeanAggregator(0, 0.5, "relative"); + const [id1, id2] = ["client 1", "client 2"]; + aggregator.setNodes(Set.of(id1, id2)); // register two clients // should aggregate with only 50% of the contribution (1 contribution) const result = aggregator.getPromiseForAggregation(); aggregator.add(id1, WeightsContainer.of([0], [1]), 0); expect(await WSIntoArrays(await result)).to.deep.equal([[0], [1]]); }); - + it("doesn't aggregate when not enough participants", async () => { - const aggregator = new MeanAggregator(0, 1, 'absolute'); // only wait for a single participant - aggregator.minNbOfParticipants = 2 // However the task can specify another minimum number, here 2 - const [id1, id2] = ["client 1", "client 2"] + const aggregator = new MeanAggregator(0, 1, "absolute"); // only wait for a single participant + aggregator.minNbOfParticipants = 2; // However the task can specify another minimum number, here 2 + const [id1, id2] = ["client 1", "client 2"]; aggregator.setNodes(Set.of(id1)); - + const result1 = aggregator.getPromiseForAggregation(); aggregator.add(id1, WeightsContainer.of([0], [1]), 0); // Make sure that the aggregation isn't triggered - expect(aggregator.round).equals(0) - + expect(aggregator.round).equals(0); + aggregator.registerNode(id2); const result2 = aggregator.getPromiseForAggregation(); aggregator.add(id2, WeightsContainer.of([2], [3]), 0); - expect((await result1).equals(await result2)).to.be.true - expect(aggregator.round).equals(1) + expect((await result1).equals(await result2)).to.be.true; + expect(aggregator.round).equals(1); }); }); diff --git a/discojs/src/aggregator/mean.ts b/discojs/src/aggregator/mean.ts index eb24d370a..a4abba6c4 100644 --- a/discojs/src/aggregator/mean.ts +++ b/discojs/src/aggregator/mean.ts @@ -4,9 +4,9 @@ import { MultiRoundAggregator, ThresholdType } from "./multiround.js"; import type { WeightsContainer, client } from "../index.js"; import { aggregation } from "../index.js"; -/** - * Mean aggregator whose aggregation step consists in computing the mean of the received weights. - * +/** + * Mean aggregator whose aggregation step consists in computing the mean of the received weights. + * */ export class MeanAggregator extends MultiRoundAggregator { /** @@ -20,7 +20,9 @@ export class MeanAggregator extends MultiRoundAggregator { override _add(nodeId: client.NodeID, contribution: WeightsContainer): void { this.log( - this.contributions.hasIn([0, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, + this.contributions.hasIn([0, nodeId]) + ? AggregationStep.UPDATE + : AggregationStep.ADD, nodeId, ); this.contributions = this.contributions.setIn([0, nodeId], contribution); @@ -28,7 +30,8 @@ export class MeanAggregator extends MultiRoundAggregator { override aggregate(): WeightsContainer { const currentContributions = this.contributions.get(0); - if (!currentContributions) throw new Error("aggregating without any contribution"); + if (!currentContributions) + throw new Error("aggregating without any contribution"); this.log(AggregationStep.AGGREGATE); diff --git a/discojs/src/aggregator/multiround.ts b/discojs/src/aggregator/multiround.ts index ff6d4f021..4018b55dd 100644 --- a/discojs/src/aggregator/multiround.ts +++ b/discojs/src/aggregator/multiround.ts @@ -1,7 +1,7 @@ import { Aggregator } from "./aggregator.js"; import createDebug from "debug"; -export type ThresholdType = 'relative' | 'absolute'; +export type ThresholdType = "relative" | "absolute"; const debug = createDebug("discojs:aggregator:multiround"); @@ -19,25 +19,25 @@ export abstract class MultiRoundAggregator extends Aggregator { * Abstract class of a multi-round aggregator that wait for a certain number of contributions before aggregating * By default, initializes an aggregator that waits for 100% of the nodes' contributions and that * only accepts contributions from the current round (drops contributions from previous rounds). - * + * * @param threshold - how many contributions trigger an aggregation step. - * It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions. + * It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions. * Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`. * It can be an absolute number, if t >=1 (then t has to be an integer), the aggregator waits fot t contributions * Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update), * set `threshold = 1` and `thresholdType = 'absolute'` - * @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1, + * @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1, * If `threshold != 1` then the specified thresholdType is ignored and overwritten * If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution - * if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions, - * @param roundCutoff - from how many past rounds do we still accept contributions. - * If 0 then only accept contributions from the current round, + * if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions, + * @param roundCutoff - from how many past rounds do we still accept contributions. + * If 0 then only accept contributions from the current round, * if 1 then the current round and the previous one, etc. */ constructor(roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType) { if (threshold <= 0) throw new Error("threshold must be strictly positive"); - if (threshold > 1 && (!Number.isInteger(threshold))) + if (threshold > 1 && !Number.isInteger(threshold)) throw new Error("absolute thresholds must be integral"); super(roundCutoff, 1); @@ -45,17 +45,20 @@ export abstract class MultiRoundAggregator extends Aggregator { if (threshold < 1) { // Throw exception if threshold and thresholdType are conflicting - if (thresholdType === 'absolute') { - throw new Error(`thresholdType has been set to 'absolute' but choosing threshold=${threshold} implies that thresholdType should be 'relative'.`) + if (thresholdType === "absolute") { + throw new Error( + `thresholdType has been set to 'absolute' but choosing threshold=${threshold} implies that thresholdType should be 'relative'.`, + ); } - this.#thresholdType = 'relative' - } - else if (threshold > 1) { + this.#thresholdType = "relative"; + } else if (threshold > 1) { // Throw exception if threshold and thresholdType are conflicting - if (thresholdType === 'relative') { - throw new Error(`thresholdType has been set to 'relative' but choosing threshold=${threshold} implies that thresholdType should be 'absolute'.`) + if (thresholdType === "relative") { + throw new Error( + `thresholdType has been set to 'relative' but choosing threshold=${threshold} implies that thresholdType should be 'absolute'.`, + ); } - this.#thresholdType = 'absolute' + this.#thresholdType = "absolute"; } // remaining case: threshold == 1 else { @@ -64,11 +67,11 @@ export abstract class MultiRoundAggregator extends Aggregator { // TODO enforce validity by splitting the different threshold types into separate classes instead of warning debug( "[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " + - "To instead wait for a single contribution, set thresholdType = 'absolute'" - ) - this.#thresholdType = 'relative' + "To instead wait for a single contribution, set thresholdType = 'absolute'", + ); + this.#thresholdType = "relative"; } else { - this.#thresholdType = thresholdType + this.#thresholdType = thresholdType; } } } @@ -77,11 +80,14 @@ export abstract class MultiRoundAggregator extends Aggregator { override isFull(): boolean { // Make sure that we are over the minimum number of participants // if specified - if (this.#minNbOfParticipants !== undefined && - this.nodes.size < this.#minNbOfParticipants) return false; + if ( + this.#minNbOfParticipants !== undefined && + this.nodes.size < this.#minNbOfParticipants + ) + return false; const thresholdValue = - this.#thresholdType == 'relative' + this.#thresholdType == "relative" ? this.#threshold * this.nodes.size : this.#threshold; diff --git a/discojs/src/aggregator/secure.spec.ts b/discojs/src/aggregator/secure.spec.ts index 41ba04774..078658abe 100644 --- a/discojs/src/aggregator/secure.spec.ts +++ b/discojs/src/aggregator/secure.spec.ts @@ -54,8 +54,8 @@ describe("secret shares test", () => { describe("secure aggregator", () => { it("behaves as mean aggregator", async () => { - const secureNetwork = setupNetwork(SecureAggregator) - const meanNetwork = setupNetwork(MeanAggregator) // waits for 100% of the nodes' contributions by default + const secureNetwork = setupNetwork(SecureAggregator); + const meanNetwork = setupNetwork(MeanAggregator); // waits for 100% of the nodes' contributions by default const meanResults = await communicate( Map( @@ -63,7 +63,8 @@ describe("secure aggregator", () => { .entrySeq() .zip(Range(0, 3)) .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), - ), 0 + ), + 0, ); const secureResults = await communicate( Map( @@ -71,21 +72,22 @@ describe("secure aggregator", () => { .entrySeq() .zip(Range(0, 3)) .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), - ), 0 + ), + 0, ); - // biome-ignore lint/correctness/noFlatMapIdentity: .flatten convert to Collection and zipAll is picky - for (const [secure, mean] of List( - await Promise.all(secureResults.sort().valueSeq().map(wsIntoArrays)), - ) - .flatMap((x) => x) - .flatMap((x) => x) - .zipAll( - // biome-ignore lint/correctness/noFlatMapIdentity: .flatten convert to Collection and zipAll is picky - List(await Promise.all(meanResults.sort().valueSeq().map(wsIntoArrays))) - .flatMap((x) => x) - .flatMap((x) => x), - )) - expect(secure).to.be.closeTo(mean, 0.001); + // biome-ignore lint/correctness/noFlatMapIdentity: .flatten convert to Collection and zipAll is picky + for (const [secure, mean] of List( + await Promise.all(secureResults.sort().valueSeq().map(wsIntoArrays)), + ) + .flatMap((x) => x) + .flatMap((x) => x) + .zipAll( + // biome-ignore lint/correctness/noFlatMapIdentity: .flatten convert to Collection and zipAll is picky + List(await Promise.all(meanResults.sort().valueSeq().map(wsIntoArrays))) + .flatMap((x) => x) + .flatMap((x) => x), + )) + expect(secure).to.be.closeTo(mean, 0.001); }); }); diff --git a/discojs/src/aggregator/secure.ts b/discojs/src/aggregator/secure.ts index 064454e62..c20221311 100644 --- a/discojs/src/aggregator/secure.ts +++ b/discojs/src/aggregator/secure.ts @@ -56,8 +56,9 @@ export class SecureAggregator extends Aggregator { } this.log( - this.contributions.hasIn([communicationRound, nodeId]) ? - AggregationStep.UPDATE : AggregationStep.ADD, + this.contributions.hasIn([communicationRound, nodeId]) + ? AggregationStep.UPDATE + : AggregationStep.ADD, nodeId.slice(0, 4), ); diff --git a/discojs/src/aggregator/secure_history.spec.ts b/discojs/src/aggregator/secure_history.spec.ts index aa9a56776..3426ad066 100644 --- a/discojs/src/aggregator/secure_history.spec.ts +++ b/discojs/src/aggregator/secure_history.spec.ts @@ -3,10 +3,7 @@ import { describe, expect, it, assert } from "vitest"; import * as tf from "@tensorflow/tfjs"; -import { - aggregation, - WeightsContainer, -} from "../index.js"; +import { aggregation, WeightsContainer } from "../index.js"; import { SecureHistoryAggregator } from "./secure_history.js"; import { SecureAggregator } from "./secure.js"; @@ -14,140 +11,152 @@ import { SecureAggregator } from "./secure.js"; import { wsIntoArrays, communicate, setupNetwork } from "../aggregator.spec.js"; describe("Secure history aggregator", function () { - const epsilon = 1e-4; - - const secrets = List.of( - WeightsContainer.of([1, 2, 3, -1], [-5, 6]), - WeightsContainer.of([2, 3, 7, 1], [-10, 5]), - WeightsContainer.of([3, 1, 5, 3], [-15, 19]), + const epsilon = 1e-4; + + const secrets = List.of( + WeightsContainer.of([1, 2, 3, -1], [-5, 6]), + WeightsContainer.of([2, 3, 7, 1], [-10, 5]), + WeightsContainer.of([3, 1, 5, 3], [-15, 19]), + ); + + function buildShares(): List> { + const nodes = Set(secrets.keys()).map(String); + return secrets.map((secret) => { + const aggregator = new SecureHistoryAggregator(); + aggregator.setNodes(nodes); + return aggregator.generateAllShares(secret); + }); + } + + it("recovers secrets from shares", () => { + const recovered = buildShares().map((shares) => aggregation.sum(shares)); + assert.isTrue( + ( + recovered.zip(secrets) as List<[WeightsContainer, WeightsContainer]> + ).every(([actual, expected]) => actual.equals(expected, epsilon)), ); + }); - function buildShares(): List> { - const nodes = Set(secrets.keys()).map(String); - return secrets.map((secret) => { - const aggregator = new SecureHistoryAggregator(); - aggregator.setNodes(nodes); - return aggregator.generateAllShares(secret); - }); - } + it("aggregates partial sums with momentum smoothing", async () => { + const aggregator = new SecureHistoryAggregator(100, 0.8); + const nodes = Set(secrets.keys()).map(String); + aggregator.setNodes(nodes); - it("recovers secrets from shares", () => { - const recovered = buildShares().map((shares) => aggregation.sum(shares)); - assert.isTrue( - ( - recovered.zip(secrets) as List<[WeightsContainer, WeightsContainer]> - ).every(([actual, expected]) => actual.equals(expected, epsilon)), + // Prepare to capture aggregation result + const aggregationPromise = aggregator.getPromiseForAggregation(); + + const sharesRound0 = buildShares(); + + const partialSums = Range(0, nodes.size) + .map((receiverIdx) => { + const receivedShares = sharesRound0.map( + (shares) => shares.get(receiverIdx)!, ); + return aggregation.sum(receivedShares); + }) + .toList(); + + // Add one total contribution per node + partialSums.forEach((partialSum, idx) => { + const nodeId = idx.toString(); + aggregator.add(nodeId, partialSum, 0); }); - it("aggregates partial sums with momentum smoothing", async () => { - const aggregator = new SecureHistoryAggregator(100, 0.8); - const nodes = Set(secrets.keys()).map(String); - aggregator.setNodes(nodes); + const sumRound0 = await aggregationPromise; - // Prepare to capture aggregation result - const aggregationPromise = aggregator.getPromiseForAggregation(); + const expectedSum = aggregation.sum( + sharesRound0.flatMap((x) => x), // flatten to List + ); + expect(sumRound0.equals(expectedSum, epsilon)).to.be.true; - const sharesRound0 = buildShares(); + // simulate second communication round partial sums + const aggregationPromise2 = aggregator.getPromiseForAggregation(); - const partialSums = Range(0, nodes.size).map((receiverIdx) => { - const receivedShares = sharesRound0.map(shares => shares.get(receiverIdx)!); - return aggregation.sum(receivedShares); - }).toList(); + partialSums.forEach((partialSum, idx) => { + const nodeId = idx.toString(); + aggregator.add(nodeId, partialSum, 0); + }); + const sumRound1 = await aggregationPromise2; - // Add one total contribution per node - partialSums.forEach((partialSum, idx) => { - const nodeId = idx.toString(); - aggregator.add(nodeId, partialSum, 0); - }); + // First aggregation with momentum - no previous momentum, so just average + const avgPartialSum = aggregation.avg(partialSums); + expect(sumRound1.equals(avgPartialSum, epsilon)).to.be.true; - const sumRound0 = await aggregationPromise; + // Now we simulate a second round of aggregation with momentum smoothing + const dummyPromise = aggregator.getPromiseForAggregation(); + partialSums.forEach((partialSum, idx) => { + const nodeId = idx.toString(); + aggregator.add(nodeId, partialSum, 1); // round 0 of next aggregation round + }); + await dummyPromise; - const expectedSum = aggregation.sum( - sharesRound0.flatMap(x => x) // flatten to List - ); - expect(sumRound0.equals(expectedSum, epsilon)).to.be.true; - - - // simulate second communication round partial sums - const aggregationPromise2 = aggregator.getPromiseForAggregation(); - - partialSums.forEach((partialSum, idx) => { - const nodeId = idx.toString(); - aggregator.add(nodeId, partialSum, 0); - }); - const sumRound1 = await aggregationPromise2; - - // First aggregation with momentum - no previous momentum, so just average - const avgPartialSum = aggregation.avg(partialSums); - expect(sumRound1.equals(avgPartialSum, epsilon)).to.be.true; - - // Now we simulate a second round of aggregation with momentum smoothing - const dummyPromise = aggregator.getPromiseForAggregation(); - partialSums.forEach((partialSum, idx) => { - const nodeId = idx.toString(); - aggregator.add(nodeId, partialSum, 1); // round 0 of next aggregation round - }); - await dummyPromise; - - const aggregationPromise3 = aggregator.getPromiseForAggregation(); - // Add another set of partial sums with slight modification - const partialSums2 = partialSums.map(ws => - ws.map((tensor) => tf.mul(tensor, 1.1)) - ); + const aggregationPromise3 = aggregator.getPromiseForAggregation(); + // Add another set of partial sums with slight modification + const partialSums2 = partialSums.map((ws) => + ws.map((tensor) => tf.mul(tensor, 1.1)), + ); - // Add the modified partial sums to the aggregator - partialSums2.forEach((partialSum, idx) => { - const nodeId = idx.toString(); - aggregator.add(nodeId, partialSum, 1); - }); - const sumRound2 = await aggregationPromise3; + // Add the modified partial sums to the aggregator + partialSums2.forEach((partialSum, idx) => { + const nodeId = idx.toString(); + aggregator.add(nodeId, partialSum, 1); + }); + const sumRound2 = await aggregationPromise3; - const avgPartialSum2 = aggregation.avg(partialSums2); - const expectedSumRound2 = avgPartialSum.mapWith(avgPartialSum2, (prev, curr) => - prev.mul(0.8).add(curr.mul(0.2)) // 0.8 = beta, 0.2 = (1 - beta) - ); + const avgPartialSum2 = aggregation.avg(partialSums2); + const expectedSumRound2 = avgPartialSum.mapWith( + avgPartialSum2, + (prev, curr) => prev.mul(0.8).add(curr.mul(0.2)), // 0.8 = beta, 0.2 = (1 - beta) + ); - // Compare the actual result to the expected smoothed result using momentum - expect(sumRound2.equals(expectedSumRound2, 1e-3)).to.be.true; - }); + // Compare the actual result to the expected smoothed result using momentum + expect(sumRound2.equals(expectedSumRound2, 1e-3)).to.be.true; + }); - it("behaves similar to SecureAggregator without momentum (beta=0)", async () => { - class TestSecureHistoryAggregator extends SecureHistoryAggregator { - constructor() { - super(0, 0); // beta=0 disables momentum smoothing - } - } - const secureHistoryNetwork = setupNetwork(TestSecureHistoryAggregator); // beta=0 disables momentum smoothing - const secureNetwork = setupNetwork(SecureAggregator); - - const secureHistoryResults = await communicate( - Map( - secureHistoryNetwork - .entrySeq() - .zip(Range(0, 3)) - .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), - ), - 0, - ); - const secureResults = await communicate( - Map( - secureNetwork - .entrySeq() - .zip(Range(0, 3)) - .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), - ), - 0, - ); + it("behaves similar to SecureAggregator without momentum (beta=0)", async () => { + class TestSecureHistoryAggregator extends SecureHistoryAggregator { + constructor() { + super(0, 0); // beta=0 disables momentum smoothing + } + } + const secureHistoryNetwork = setupNetwork(TestSecureHistoryAggregator); // beta=0 disables momentum smoothing + const secureNetwork = setupNetwork(SecureAggregator); + + const secureHistoryResults = await communicate( + Map( + secureHistoryNetwork + .entrySeq() + .zip(Range(0, 3)) + .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), + ), + 0, + ); + const secureResults = await communicate( + Map( + secureNetwork + .entrySeq() + .zip(Range(0, 3)) + .map(([[id, agg], i]) => [id, [agg, WeightsContainer.of([i])]]), + ), + 0, + ); - List(await Promise.all(secureHistoryResults.sort().valueSeq().map(wsIntoArrays))) - .flatMap((x) => x) - .flatMap((x) => x) - .zipAll( - List(await Promise.all(secureResults.sort().valueSeq().map(wsIntoArrays))) - .flatMap((x) => x) - .flatMap((x) => x), - ) - .forEach(([secureHistory, secure]) => expect(secureHistory).to.be.closeTo(secure, 0.001)); - }); + List( + await Promise.all( + secureHistoryResults.sort().valueSeq().map(wsIntoArrays), + ), + ) + .flatMap((x) => x) + .flatMap((x) => x) + .zipAll( + List( + await Promise.all(secureResults.sort().valueSeq().map(wsIntoArrays)), + ) + .flatMap((x) => x) + .flatMap((x) => x), + ) + .forEach(([secureHistory, secure]) => + expect(secureHistory).to.be.closeTo(secure, 0.001), + ); + }); }); diff --git a/discojs/src/aggregator/secure_history.ts b/discojs/src/aggregator/secure_history.ts index a24b5acab..71c77b340 100644 --- a/discojs/src/aggregator/secure_history.ts +++ b/discojs/src/aggregator/secure_history.ts @@ -1,5 +1,5 @@ import type { WeightsContainer } from "../index.js"; -import { SecureAggregator } from "./secure.js"; +import { SecureAggregator } from "./secure.js"; import { aggregation } from "../index.js"; /** @@ -40,7 +40,8 @@ export class SecureHistoryAggregator extends SecureAggregator { // For communication round 1, do average + momentum smoothing const currentContributions = this.contributions.get(1); - if (!currentContributions) throw new Error("aggregating without any contribution"); + if (!currentContributions) + throw new Error("aggregating without any contribution"); const avg = aggregation.avg(currentContributions.values()); @@ -50,11 +51,11 @@ export class SecureHistoryAggregator extends SecureAggregator { } const updatedMomentum = this.prevAggregate.mapWith(avg, (prevT, currT) => - prevT.mul(this.beta).add(currT.mul(1 - this.beta)) + prevT.mul(this.beta).add(currT.mul(1 - this.beta)), ); // Dispose old tensors to avoid memory leaks - this.prevAggregate.weights.forEach(t => t.dispose()); + this.prevAggregate.weights.forEach((t) => t.dispose()); this.prevAggregate = updatedMomentum; return updatedMomentum; diff --git a/discojs/src/client/client.ts b/discojs/src/client/client.ts index 9e1298b93..038d8ffad 100644 --- a/discojs/src/client/client.ts +++ b/discojs/src/client/client.ts @@ -8,11 +8,11 @@ import type { Task, WeightsContainer, } from "../index.js"; -import { serialization } from '../index.js' -import type { NodeID } from './types.js' -import type { EventConnection } from './event_connection.js' -import type { Aggregator } from '../aggregator/index.js' -import { EventEmitter } from '../utils/event_emitter.js' +import { serialization } from "../index.js"; +import type { NodeID } from "./types.js"; +import type { EventConnection } from "./event_connection.js"; +import type { Aggregator } from "../aggregator/index.js"; +import { EventEmitter } from "../utils/event_emitter.js"; import { type } from "./messages.js"; const debug = createDebug("discojs:client"); @@ -22,15 +22,15 @@ const debug = createDebug("discojs:client"); * communication with other nodes, be it peers or a server. */ export abstract class Client extends EventEmitter<{ - status: RoundStatus; - participants: number; + status: RoundStatus; + participants: number; }> { // Own ID provided by the network's server. - protected _ownId?: NodeID + protected _ownId?: NodeID; // The network's server. - protected _server?: EventConnection + protected _server?: EventConnection; // The aggregator's result produced after aggregation. - protected aggregationResult?: Promise + protected aggregationResult?: Promise; /** * When the server notifies clients to pause and wait until more * participants join, we rely on this promise to wait @@ -49,12 +49,12 @@ export abstract class Client extends EventEmitter<{ // Current number of participants including this client in the training session #nbOfParticipants: number = 1; - constructor ( + constructor( public readonly url: URL, // The network server's URL to connect to public readonly task: Task, // The client's corresponding task public readonly aggregator: Aggregator, ) { - super() + super(); } /** @@ -67,7 +67,9 @@ export abstract class Client extends EventEmitter<{ * @param weights The local weight update resulting for the current local training round * @returns aggregated weights or the local weights upon error */ - abstract onRoundEndCommunication(weights: WeightsContainer): Promise; + abstract onRoundEndCommunication( + weights: WeightsContainer, + ): Promise; /** * Handles the connection process from the client to any sort of network server. @@ -75,61 +77,65 @@ export abstract class Client extends EventEmitter<{ * By default, it fetches and returns the server's base model */ async connect(): Promise> { - return this.getLatestModel() + return this.getLatestModel(); } /** * Handles the disconnection process of the client from any sort of network server. */ - async disconnect(): Promise { } + async disconnect(): Promise {} /** - * Emits the round status specified. It also stores the status emitted such that + * Emits the round status specified. It also stores the status emitted such that * if the server tells the client to wait for more participants, it can display * the waiting status and once enough participants join, it can display the previous status again - */ + */ protected saveAndEmit(status: RoundStatus) { - this.#previousStatus = status - this.emit("status", status) + this.#previousStatus = status; + this.emit("status", status); } - + /** * For both federated and decentralized clients, we listen to the server to tell * us whether there are enough participants to train. If not, we pause until further notice. - * When a client connects to the server, the server answers with the session information (id, - * number of participants) and whether there are enough participants. - * When there are the server sends a new EnoughParticipant message to update the client. - * + * When a client connects to the server, the server answers with the session information (id, + * number of participants) and whether there are enough participants. + * When there are the server sends a new EnoughParticipant message to update the client. + * * `setMessageInversionFlag` is used to address the following scenario: * 1. Client 1 connect to the server * 2. Server answers with message A containing "not enough participants" * 3. Before A arrives a new client joins. There are enough participants now. * 4. Server updates client 1 with message B saying "there are enough participants" - * 5. Due to network and message sizes message B can arrive before A. + * 5. Due to network and message sizes message B can arrive before A. * i.e. "there are enough participants" arrives before "not enough participants" * ending up with client 1 thinking it needs to wait for more participants. - * - * To keep track of this message inversion, `setMessageInversionFlag` + * + * To keep track of this message inversion, `setMessageInversionFlag` * tells us whether a message inversion occurred (by setting a boolean to true) - * + * * @param setMessageInversionFlag function flagging whether a message inversion occurred * between a NewNodeInfo message and an EnoughParticipant message. */ protected setupServerCallbacks(setMessageInversionFlag: () => void) { - // Setup an event callback if the server signals that we should + // Setup an event callback if the server signals that we should // wait for more participants this.server.on(type.WaitingForMoreParticipants, (event) => { if (this.promiseForMoreParticipants !== undefined) - throw new Error("Server sent multiple WaitingForMoreParticipants messages") - debug(`[${shortenId(this.ownId)}] received WaitingForMoreParticipants message from server`) + throw new Error( + "Server sent multiple WaitingForMoreParticipants messages", + ); + debug( + `[${shortenId(this.ownId)}] received WaitingForMoreParticipants message from server`, + ); // Display the waiting status right away - this.emit("status", "not enough participants") - this.nbOfParticipants = event.nbOfParticipants // emits the `participants` event + this.emit("status", "not enough participants"); + this.nbOfParticipants = event.nbOfParticipants; // emits the `participants` event // Upon receiving a WaitingForMoreParticipants message, // the client will await for this promise to resolve before sending its // local weight update - this.promiseForMoreParticipants = this.createPromiseForMoreParticipants() - }) + this.promiseForMoreParticipants = this.createPromiseForMoreParticipants(); + }); // As an example assume we need at least 2 participants to train, // When two participants join almost at the same time, the server @@ -139,15 +145,15 @@ export abstract class Client extends EventEmitter<{ // so we check whether we received the EnoughParticipants before being assigned a node ID this.server.once(type.EnoughParticipants, (event) => { if (this._ownId === undefined) { - setMessageInversionFlag() - this.nbOfParticipants = event.nbOfParticipants + setMessageInversionFlag(); + this.nbOfParticipants = event.nbOfParticipants; } - }) + }); } /** - * Method called when the server notifies the client that there aren't enough + * Method called when the server notifies the client that there aren't enough * participants (anymore) to start/continue training - * The method creates a promise that will resolve once the server notifies + * The method creates a promise that will resolve once the server notifies * the client that the training can resume via a subsequent EnoughParticipants message * @returns a promise which resolves when enough participants joined the session */ @@ -155,52 +161,57 @@ export abstract class Client extends EventEmitter<{ return new Promise((resolve) => { // "once" is important because we can't resolve the same promise multiple times this.server.once(type.EnoughParticipants, (event) => { - debug(`[${shortenId(this.ownId)}] received EnoughParticipants message from server`) + debug( + `[${shortenId(this.ownId)}] received EnoughParticipants message from server`, + ); // Emit the last status emitted before waiting if defined - if (this.#previousStatus !== undefined) this.emit("status", this.#previousStatus) - this.nbOfParticipants = event.nbOfParticipants - resolve() - }) - }) + if (this.#previousStatus !== undefined) + this.emit("status", this.#previousStatus); + this.nbOfParticipants = event.nbOfParticipants; + resolve(); + }); + }); } - protected async waitForParticipantsIfNeeded(): Promise{ + protected async waitForParticipantsIfNeeded(): Promise { // we check if we are waiting for more participants before sending our weight update if (this.waitingForMoreParticipants) { // wait for the promise to resolve, which takes as long as it takes for new participants to join - debug(`[${shortenId(this.ownId)}] is awaiting the promise for more participants`) - this.emit("status", "not enough participants") - await this.promiseForMoreParticipants + debug( + `[${shortenId(this.ownId)}] is awaiting the promise for more participants`, + ); + this.emit("status", "not enough participants"); + await this.promiseForMoreParticipants; // Make sure to set the promise back to undefined once resolved - this.promiseForMoreParticipants = undefined + this.promiseForMoreParticipants = undefined; } } /** * Fetches the latest model available on the network's server, for the adequate task. * @returns The latest model */ - async getLatestModel (): Promise> { - const url = new URL('', this.url.href) - if (!url.pathname.endsWith('/')) { - url.pathname += '/' + async getLatestModel(): Promise> { + const url = new URL("", this.url.href); + if (!url.pathname.endsWith("/")) { + url.pathname += "/"; } - url.pathname += `tasks/${this.task.id}/model.json` + url.pathname += `tasks/${this.task.id}/model.json`; const response = await fetch(url); if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`); - const encoded = new Uint8Array(await response.arrayBuffer()) - return await serialization.model.decode(encoded) + const encoded = new Uint8Array(await response.arrayBuffer()); + return await serialization.model.decode(encoded); } /** - * Number of contributors to a collaborative session - * If decentralized, it should be the number of peers - * If federated, it should the number of participants excluding the server - * If local it should be 1 - */ + * Number of contributors to a collaborative session + * If decentralized, it should be the number of peers + * If federated, it should the number of participants excluding the server + * If local it should be 1 + */ public get nbOfParticipants(): number { - return this.#nbOfParticipants + return this.#nbOfParticipants; } /** @@ -208,33 +219,32 @@ export abstract class Client extends EventEmitter<{ * It emits the number of participants to the client */ public set nbOfParticipants(nbOfParticipants: number) { - this.#nbOfParticipants = nbOfParticipants - this.emit("participants", nbOfParticipants) + this.#nbOfParticipants = nbOfParticipants; + this.emit("participants", nbOfParticipants); } get ownId(): NodeID { if (this._ownId === undefined) { - throw new Error('the node is not connected') + throw new Error("the node is not connected"); } - return this._ownId + return this._ownId; } - get server (): EventConnection { + get server(): EventConnection { if (this._server === undefined) { - throw new Error('server undefined, not connected') + throw new Error("server undefined, not connected"); } - return this._server + return this._server; } /** * Whether the client should wait until more * participants join the session, i.e. a promise has been created */ get waitingForMoreParticipants(): boolean { - return this.promiseForMoreParticipants !== undefined + return this.promiseForMoreParticipants !== undefined; } - } export function shortenId(id: string): string { - return id.slice(0, 4) + return id.slice(0, 4); } diff --git a/discojs/src/client/decentralized/decentralized_client.ts b/discojs/src/client/decentralized/decentralized_client.ts index 6f9da6e77..474968102 100644 --- a/discojs/src/client/decentralized/decentralized_client.ts +++ b/discojs/src/client/decentralized/decentralized_client.ts @@ -1,15 +1,20 @@ import createDebug from "debug"; -import { Map, Set } from 'immutable' +import { Map, Set } from "immutable"; import type { DataType, Model, WeightsContainer } from "../../index.js"; import { serialization } from "../../index.js"; -import { Client, shortenId } from '../client.js' -import { type NodeID } from '../index.js' -import { type, type ClientConnected } from '../messages.js' -import { timeout } from '../utils.js' -import { WebSocketServer, waitMessage, type PeerConnection, waitMessageWithTimeout } from '../event_connection.js' -import { PeerPool } from './peer_pool.js' -import * as messages from './messages.js' +import { Client, shortenId } from "../client.js"; +import { type NodeID } from "../index.js"; +import { type, type ClientConnected } from "../messages.js"; +import { timeout } from "../utils.js"; +import { + WebSocketServer, + waitMessage, + type PeerConnection, + waitMessageWithTimeout, +} from "../event_connection.js"; +import { PeerPool } from "./peer_pool.js"; +import * as messages from "./messages.js"; const debug = createDebug("discojs:client:decentralized"); @@ -23,66 +28,73 @@ export class DecentralizedClient extends Client<"decentralized"> { /** * The pool of peers to communicate with during the current training round. */ - #pool?: PeerPool - #connections?: Map + #pool?: PeerPool; + #connections?: Map; // Used to handle timeouts and promise resolving after calling disconnect - private get isDisconnected() : boolean { - return this._server === undefined + private get isDisconnected(): boolean { + return this._server === undefined; } private setAggregatorNodes(nodes: Set) { - this.aggregator.setNodes(nodes) + this.aggregator.setNodes(nodes); // Emits the `participants` event - this.nbOfParticipants = this.aggregator.nodes.size === 0 ? 1 : this.aggregator.nodes.size + this.nbOfParticipants = + this.aggregator.nodes.size === 0 ? 1 : this.aggregator.nodes.size; } - + /** * Public method called by disco.ts when starting training. This method sends * a message to the server asking to join the task and be assigned a client ID. - * - * The peer also establishes a WebSocket connection with the server to then + * + * The peer also establishes a WebSocket connection with the server to then * create peer-to-peer WebRTC connections with peers. The server is used to exchange * peers network information. */ override async connect(): Promise> { - const model = await super.connect() // Get the server base model - const serverURL = new URL('', this.url.href) + const model = await super.connect(); // Get the server base model + const serverURL = new URL("", this.url.href); switch (this.url.protocol) { - case 'http:': - serverURL.protocol = 'ws:' - break - case 'https:': - serverURL.protocol = 'wss:' - break + case "http:": + serverURL.protocol = "ws:"; + break; + case "https:": + serverURL.protocol = "wss:"; + break; default: - throw new Error(`unknown protocol: ${this.url.protocol}`) + throw new Error(`unknown protocol: ${this.url.protocol}`); } - serverURL.pathname += `decentralized/${this.task.id}` + serverURL.pathname += `decentralized/${this.task.id}`; // Create a WebSocket connection with the server // The client then waits for the server to forward it other client's network information. // Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection. - this._server = await WebSocketServer.connect(serverURL, messages.isMessageFromServer, messages.isMessageToServer) + this._server = await WebSocketServer.connect( + serverURL, + messages.isMessageFromServer, + messages.isMessageToServer, + ); this.server.on(type.SignalForPeer, (event) => { - if (this.#pool === undefined) throw new Error('received signal but peer pool is undefined') + if (this.#pool === undefined) + throw new Error("received signal but peer pool is undefined"); // Create a WebRTC connection with the peer - this.#pool.signal(event.peer, event.signal) - }) + this.#pool.signal(event.peer, event.signal); + }); // c.f. setupServerCallbacks doc for explanation - let receivedEnoughParticipants = false - this.setupServerCallbacks(() => receivedEnoughParticipants = true) - + let receivedEnoughParticipants = false; + this.setupServerCallbacks(() => (receivedEnoughParticipants = true)); + const msg: ClientConnected = { - type: type.ClientConnected - } - this.server.send(msg) - - const { id, waitForMoreParticipants, - nbOfParticipants } = await waitMessage(this.server, type.NewDecentralizedNodeInfo) - - this.nbOfParticipants = nbOfParticipants - + type: type.ClientConnected, + }; + this.server.send(msg); + + const { id, waitForMoreParticipants, nbOfParticipants } = await waitMessage( + this.server, + type.NewDecentralizedNodeInfo, + ); + + this.nbOfParticipants = nbOfParticipants; // This should come right after receiving the message to make sure // we don't miss a subsequent message from the server @@ -92,34 +104,34 @@ export class DecentralizedClient extends Client<"decentralized"> { if (waitForMoreParticipants && !receivedEnoughParticipants) { // Create a promise that resolves when enough participants join // The client will await this promise before sending its local weight update - this.promiseForMoreParticipants = this.createPromiseForMoreParticipants() + this.promiseForMoreParticipants = this.createPromiseForMoreParticipants(); } debug(`[${shortenId(id)}] assigned id generated by server`); if (this._ownId !== undefined) { - throw new Error('received id from server but was already received') + throw new Error("received id from server but was already received"); } - this._ownId = id - this.#pool = new PeerPool(id) - return model + this._ownId = id; + this.#pool = new PeerPool(id); + return model; } - override async disconnect (): Promise { + override async disconnect(): Promise { // Disconnect from peers - await this.#pool?.shutdown() - this.#pool = undefined + await this.#pool?.shutdown(); + this.#pool = undefined; if (this.#connections !== undefined) { - const peers = this.#connections.keySeq().toSet() - this.setAggregatorNodes(this.aggregator.nodes.subtract(peers)) + const peers = this.#connections.keySeq().toSet(); + this.setAggregatorNodes(this.aggregator.nodes.subtract(peers)); } // Disconnect from server - await this.server?.disconnect() - this._server = undefined - this._ownId = undefined - - return Promise.resolve() + await this.server?.disconnect(); + this._server = undefined; + this._ownId = undefined; + + return Promise.resolve(); } /** @@ -128,32 +140,34 @@ export class DecentralizedClient extends Client<"decentralized"> { * Given the list, the peers then create peer-to-peer connections with each other. * When connected, one peer creates a promise for every other peer's weight update * and waits for it to resolve. - * + * */ override async onRoundBeginCommunication(): Promise { // Notify the server we want to join the next round so that the server // waits for us to be ready before sending the list of peers for the round - this.server.send({ type: type.JoinRound }) + this.server.send({ type: type.JoinRound }); // Store the promise for the current round's aggregation result. // We will await for it to resolve at the end of the round when exchanging weight updates. - this.aggregationResult = this.aggregator.getPromiseForAggregation() - this.saveAndEmit("local training") - return Promise.resolve() + this.aggregationResult = this.aggregator.getPromiseForAggregation(); + this.saveAndEmit("local training"); + return Promise.resolve(); } - override async onRoundEndCommunication (weights: WeightsContainer): Promise { + override async onRoundEndCommunication( + weights: WeightsContainer, + ): Promise { if (this.aggregationResult === undefined) { - throw new TypeError('aggregation result promise is undefined') + throw new TypeError("aggregation result promise is undefined"); } // Save the status in case participants leave and we switch to waiting for more participants // Once enough new participants join we can display the previous status again - this.saveAndEmit("connecting to peers") + this.saveAndEmit("connecting to peers"); // First we check if we are waiting for more participants before sending our weight update - await this.waitForParticipantsIfNeeded() + await this.waitForParticipantsIfNeeded(); // Create peer-to-peer connections with all peers for the round - await this.establishPeerConnections() + await this.establishPeerConnections(); // Exchange weight updates with peers and return aggregated weights - return await this.exchangeWeightUpdates(weights) + return await this.exchangeWeightUpdates(weights); } /** @@ -163,30 +177,40 @@ export class DecentralizedClient extends Client<"decentralized"> { */ private async establishPeerConnections(): Promise { if (this.server === undefined) { - throw new Error("peer's server is undefined, make sure to call `client.connect()` first") - } if (this.#pool === undefined) { - throw new Error('peer pool is undefined, make sure to call `client.connect()` first') + throw new Error( + "peer's server is undefined, make sure to call `client.connect()` first", + ); + } + if (this.#pool === undefined) { + throw new Error( + "peer pool is undefined, make sure to call `client.connect()` first", + ); } // Reset peers list at each round of training to make sure client works with an updated peers // list, maintained by the server. Adds any received weights to the aggregator. // Tell the server we are ready for the next round - const readyMessage: messages.PeerIsReady = { type: type.PeerIsReady } - this.server.send(readyMessage) + const readyMessage: messages.PeerIsReady = { type: type.PeerIsReady }; + this.server.send(readyMessage); // Wait for the server to answer with the list of peers for the round try { - debug(`[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`); - const receivedMessage = await waitMessage(this.server, type.PeersForRound) - - const peers = Set(receivedMessage.peers) + debug( + `[${shortenId(this.ownId)}] is waiting for peer list for round ${this.aggregator.round}`, + ); + const receivedMessage = await waitMessage( + this.server, + type.PeersForRound, + ); + + const peers = Set(receivedMessage.peers); if (this.ownId !== undefined && peers.has(this.ownId)) { - throw new Error('received peer list contains our own id') + throw new Error("received peer list contains our own id"); } // Store the list of peers for the current round including ourselves - this.setAggregatorNodes(peers.add(this.ownId)) - this.aggregator.setRound(receivedMessage.aggregationRound) // the server gives us the round number + this.setAggregatorNodes(peers.add(this.ownId)); + this.aggregator.setRound(receivedMessage.aggregationRound); // the server gives us the round number // Initiate peer to peer connections with each peer // When connected, create a promise waiting for each peer's round contribution @@ -195,112 +219,166 @@ export class DecentralizedClient extends Client<"decentralized"> { this.server, // Init receipt of peers weights. this awaits the peer's // weight update and adds it to our aggregator upon reception - (conn) => this.receivePayloads(conn) - ) + (conn) => this.receivePayloads(conn), + ); - debug(`[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, connections.keySeq().toJS()); - this.#connections = connections + debug( + `[${shortenId(this.ownId)}] received peers for round ${this.aggregator.round}: %o`, + connections.keySeq().toJS(), + ); + this.#connections = connections; } catch (e) { - debug(`Error for [${shortenId(this.ownId)}] while beginning round: %o`, e); - this.setAggregatorNodes(Set(this.ownId)) - this.#connections = Map() + debug( + `Error for [${shortenId(this.ownId)}] while beginning round: %o`, + e, + ); + this.setAggregatorNodes(Set(this.ownId)); + this.#connections = Map(); } } /** * At each communication rounds, awaits peers contributions and add them to the client's aggregator. * This method is used as callback by getPeers when connecting to the rounds' peers - * @param connections - * @param round + * @param connections + * @param round */ - private receivePayloads (connections: Map): void { + private receivePayloads(connections: Map): void { connections.forEach(async (connection, peerId) => { debug(`waiting for peer ${peerId}`); for (let r = 0; r < this.aggregator.communicationRounds; r++) { try { - const message = await waitMessageWithTimeout(connection, type.Payload, - 60_000, "Timeout waiting for a contribution from peer " + peerId) - const decoded = serialization.weights.decode(message.payload) + const message = await waitMessageWithTimeout( + connection, + type.Payload, + 60_000, + "Timeout waiting for a contribution from peer " + peerId, + ); + const decoded = serialization.weights.decode(message.payload); - if (!this.aggregator.isValidContribution(peerId, message.aggregationRound)) { - debug(`[${shortenId(this.ownId)}] failed to add contribution from peer ${shortenId(peerId)}`); - } - else { - debug(`[${shortenId(this.ownId)}] received payload from peer ${shortenId(peerId)}` + - ` for round (%d, %d)`, message.aggregationRound, message.communicationRound); + if ( + !this.aggregator.isValidContribution( + peerId, + message.aggregationRound, + ) + ) { + debug( + `[${shortenId(this.ownId)}] failed to add contribution from peer ${shortenId(peerId)}`, + ); + } else { + debug( + `[${shortenId(this.ownId)}] received payload from peer ${shortenId(peerId)}` + + ` for round (%d, %d)`, + message.aggregationRound, + message.communicationRound, + ); this.aggregator.once("aggregation", () => - debug(`[${shortenId(this.ownId)}] aggregated the model` + - ` for round (%d, %d)`, message.aggregationRound, message.communicationRound) - ) - this.aggregator.add(peerId, decoded, message.aggregationRound, message.communicationRound) + debug( + `[${shortenId(this.ownId)}] aggregated the model` + + ` for round (%d, %d)`, + message.aggregationRound, + message.communicationRound, + ), + ); + this.aggregator.add( + peerId, + decoded, + message.aggregationRound, + message.communicationRound, + ); } } catch (e) { - if (this.isDisconnected) return - debug(`Error for [${shortenId(this.ownId)}] while receiving payloads: %o`, e); + if (this.isDisconnected) return; + debug( + `Error for [${shortenId(this.ownId)}] while receiving payloads: %o`, + e, + ); } } - }) + }); } - private async exchangeWeightUpdates(weights: WeightsContainer): Promise { + private async exchangeWeightUpdates( + weights: WeightsContainer, + ): Promise { if (this.aggregationResult === undefined) { - throw new TypeError('aggregation result promise is undefined') + throw new TypeError("aggregation result promise is undefined"); } - this.saveAndEmit("updating model") + this.saveAndEmit("updating model"); // Perform the required communication rounds. Each communication round consists in sending our local payload, // followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator. // A communication round's payload is the aggregation result of the previous communication round. The first // communication round simply sends our training result, i.e. model weights updates. This scheme allows for // the aggregator to define any complex multi-round aggregation mechanism. let result = weights; - for (let communicationRound = 0; communicationRound < this.aggregator.communicationRounds; communicationRound++) { - const connections = this.#connections - if (connections === undefined) throw new Error("peer's connections is undefined") + for ( + let communicationRound = 0; + communicationRound < this.aggregator.communicationRounds; + communicationRound++ + ) { + const connections = this.#connections; + if (connections === undefined) + throw new Error("peer's connections is undefined"); // Generate our payloads for this communication round and send them to all ready connected peers - const payloads = this.aggregator.makePayloads(result) + const payloads = this.aggregator.makePayloads(result); payloads.forEach(async (payload, id) => { // add our own contribution to the aggregator if (id === this.ownId) { - this.aggregator.add(this.ownId, payload, this.aggregator.round, communicationRound) - return + this.aggregator.add( + this.ownId, + payload, + this.aggregator.round, + communicationRound, + ); + return; } // Send our payload to each peer - const peer = connections.get(id) + const peer = connections.get(id); if (peer !== undefined) { - const encoded = await serialization.weights.encode(payload) + const encoded = await serialization.weights.encode(payload); const msg: messages.PeerMessage = { type: type.Payload, peer: id, aggregationRound: this.aggregator.round, communicationRound, - payload: encoded - } - peer.send(msg) - debug(`[${shortenId(this.ownId)}] send weight update to peer ${shortenId(msg.peer)}` + - ` for round (%d, %d)`, this.aggregator.round, communicationRound); + payload: encoded, + }; + peer.send(msg); + debug( + `[${shortenId(this.ownId)}] send weight update to peer ${shortenId(msg.peer)}` + + ` for round (%d, %d)`, + this.aggregator.round, + communicationRound, + ); } - }) + }); // Wait for aggregation before proceeding to the next communication round. // The current result will be used as payload for the eventual next communication round. - try { + try { result = await Promise.race([ this.aggregationResult, - timeout(undefined, "Timeout waiting on the aggregation result promise to resolve") - ]) + timeout( + undefined, + "Timeout waiting on the aggregation result promise to resolve", + ), + ]); } catch (e) { if (this.isDisconnected) { - return weights + return weights; } - debug(`[${shortenId(this.ownId)}] while waiting for aggregation: %o`, e); - break + debug( + `[${shortenId(this.ownId)}] while waiting for aggregation: %o`, + e, + ); + break; } // There is at least one communication round remaining if (communicationRound < this.aggregator.communicationRounds - 1) { // Reuse the aggregation result - this.aggregationResult = this.aggregator.getPromiseForAggregation() + this.aggregationResult = this.aggregator.getPromiseForAggregation(); } } - return await this.aggregationResult + return await this.aggregationResult; } } diff --git a/discojs/src/client/decentralized/index.ts b/discojs/src/client/decentralized/index.ts index 471e888bd..6e2c43bfa 100644 --- a/discojs/src/client/decentralized/index.ts +++ b/discojs/src/client/decentralized/index.ts @@ -1,2 +1,2 @@ -export { DecentralizedClient } from './decentralized_client.js' -export * as messages from './messages.js' +export { DecentralizedClient } from "./decentralized_client.js"; +export * as messages from "./messages.js"; diff --git a/discojs/src/client/decentralized/messages.ts b/discojs/src/client/decentralized/messages.ts index 626062ad4..30991c3eb 100644 --- a/discojs/src/client/decentralized/messages.ts +++ b/discojs/src/client/decentralized/messages.ts @@ -1,118 +1,124 @@ import { serialization } from "../../index.js"; -import { type SignalData } from './peer.js' -import { isNodeID, type NodeID } from '../types.js' -import { type, hasMessageType } from '../messages.js' -import type { ClientConnected, WaitingForMoreParticipants, EnoughParticipants } from '../messages.js' - +import { type SignalData } from "./peer.js"; +import { isNodeID, type NodeID } from "../types.js"; +import { type, hasMessageType } from "../messages.js"; +import type { + ClientConnected, + WaitingForMoreParticipants, + EnoughParticipants, +} from "../messages.js"; /// Phase 0 communication (between server and peers) export interface NewDecentralizedNodeInfo { - type: type.NewDecentralizedNodeInfo - id: NodeID - waitForMoreParticipants: boolean - nbOfParticipants: number + type: type.NewDecentralizedNodeInfo; + id: NodeID; + waitForMoreParticipants: boolean; + nbOfParticipants: number; } // WebRTC signal to forward to other node export interface SignalForPeer { - type: type.SignalForPeer - peer: NodeID - signal: SignalData + type: type.SignalForPeer; + peer: NodeID; + signal: SignalData; } // peer wants to join the next round export interface JoinRound { - type: type.JoinRound + type: type.JoinRound; } // peer who sent is ready export interface PeerIsReady { - type: type.PeerIsReady + type: type.PeerIsReady; } // server sends to each peer the list of peers to connect to export interface PeersForRound { - type: type.PeersForRound - peers: NodeID[] - aggregationRound: number + type: type.PeersForRound; + peers: NodeID[]; + aggregationRound: number; } /// Phase 1 communication (between peers) export interface Payload { - type: type.Payload - peer: NodeID - aggregationRound: number - communicationRound: number - payload: serialization.Encoded + type: type.Payload; + peer: NodeID; + aggregationRound: number; + communicationRound: number; + payload: serialization.Encoded; } /// Phase 2 communication (between peers) export type MessageFromServer = - NewDecentralizedNodeInfo | - SignalForPeer | - PeersForRound | - WaitingForMoreParticipants | - EnoughParticipants + | NewDecentralizedNodeInfo + | SignalForPeer + | PeersForRound + | WaitingForMoreParticipants + | EnoughParticipants; export type MessageToServer = - ClientConnected | - SignalForPeer | - PeerIsReady | - JoinRound + | ClientConnected + | SignalForPeer + | PeerIsReady + | JoinRound; -export type PeerMessage = Payload +export type PeerMessage = Payload; -export function isMessageFromServer (o: unknown): o is MessageFromServer { - if (!hasMessageType(o)) return false +export function isMessageFromServer(o: unknown): o is MessageFromServer { + if (!hasMessageType(o)) return false; switch (o.type) { case type.NewDecentralizedNodeInfo: - return 'id' in o && isNodeID(o.id) && - 'waitForMoreParticipants' in o && - typeof o.waitForMoreParticipants === 'boolean' + return ( + "id" in o && + isNodeID(o.id) && + "waitForMoreParticipants" in o && + typeof o.waitForMoreParticipants === "boolean" + ); case type.SignalForPeer: - return 'peer' in o && isNodeID(o.peer) && - 'signal' in o // TODO check signal content? + return "peer" in o && isNodeID(o.peer) && "signal" in o; // TODO check signal content? case type.PeersForRound: - return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID) + return "peers" in o && Array.isArray(o.peers) && o.peers.every(isNodeID); case type.WaitingForMoreParticipants: case type.EnoughParticipants: - return true + return true; } - return false + return false; } -export function isMessageToServer (o: unknown): o is MessageToServer { - if (!hasMessageType(o)) return false +export function isMessageToServer(o: unknown): o is MessageToServer { + if (!hasMessageType(o)) return false; switch (o.type) { case type.ClientConnected: - return true + return true; case type.SignalForPeer: - return 'peer' in o && isNodeID(o.peer) && - 'signal' in o // TODO check signal content? + return "peer" in o && isNodeID(o.peer) && "signal" in o; // TODO check signal content? case type.JoinRound: case type.PeerIsReady: - return true + return true; } - return false + return false; } -export function isPeerMessage (o: unknown): o is PeerMessage { - if (!hasMessageType(o)) return false +export function isPeerMessage(o: unknown): o is PeerMessage { + if (!hasMessageType(o)) return false; switch (o.type) { case type.Payload: return ( - 'peer' in o && isNodeID(o.peer) && - 'payload' in o && serialization.isEncoded(o.payload) - ) + "peer" in o && + isNodeID(o.peer) && + "payload" in o && + serialization.isEncoded(o.payload) + ); } - return false + return false; } diff --git a/discojs/src/client/decentralized/peer.spec.ts b/discojs/src/client/decentralized/peer.spec.ts index ac6dbcd3e..b089b43fb 100644 --- a/discojs/src/client/decentralized/peer.spec.ts +++ b/discojs/src/client/decentralized/peer.spec.ts @@ -2,52 +2,65 @@ import { List, Range, Set } from "immutable"; import { assert, afterEach, beforeEach, describe, it } from "vitest"; import { Peer } from "./peer.js"; -describe('peer', () => { - let peer1: Peer - let peer2: Peer +describe("peer", () => { + let peer1: Peer; + let peer2: Peer; beforeEach(async () => { - peer1 = new Peer('1') - peer2 = new Peer('2', true) - const peers = Set.of(peer1, peer2) + peer1 = new Peer("1"); + peer2 = new Peer("2", true); + const peers = Set.of(peer1, peer2); - peer1.on('signal', (signal) => { peer2.signal(signal) }) - peer2.on('signal', (signal) => { peer1.signal(signal) }) + peer1.on("signal", (signal) => { + peer2.signal(signal); + }); + peer2.on("signal", (signal) => { + peer1.signal(signal); + }); - await Promise.all(peers.map(async (peer) => { await new Promise((resolve) => { peer.on('connect', resolve) }) } - ).toArray()) - }) + await Promise.all( + peers + .map(async (peer) => { + await new Promise((resolve) => { + peer.on("connect", resolve); + }); + }) + .toArray(), + ); + }); afterEach(async () => { - await Promise.all([peer1.destroy(), peer2.destroy()]) - }) + await Promise.all([peer1.destroy(), peer2.destroy()]); + }); - it('can send and receives a message', async () => { - const message = 'small message' + it("can send and receives a message", async () => { + const message = "small message"; - peer1.send(Buffer.from(message)) - const received = await new Promise((resolve) => { peer2.on('data', (msg) => { resolve(msg.toString()) }) }) + peer1.send(Buffer.from(message)); + const received = await new Promise((resolve) => { + peer2.on("data", (msg) => { + resolve(msg.toString()); + }); + }); - assert.strictEqual(received, message) - }) + assert.strictEqual(received, message); + }); - it('can send and receives multiple messages', async () => { - const messages = - Range(0, 5) - .map((i) => `message ${i}`) + it("can send and receives multiple messages", async () => { + const messages = Range(0, 5).map((i) => `message ${i}`); - for (const m of messages) peer1.send(Buffer.from(m)); + for (const m of messages) peer1.send(Buffer.from(m)); const receiveds: List = await new Promise((resolve) => { - let buffer = List() - peer2.on('data', (data) => { - buffer = buffer.push(data.toString()) + let buffer = List(); + peer2.on("data", (data) => { + buffer = buffer.push(data.toString()); if (buffer.size === messages.size) { - resolve(buffer) + resolve(buffer); } - }) - }) + }); + }); - assert.deepStrictEqual(receiveds.toArray(), messages.toArray()) - }) -}) + assert.deepStrictEqual(receiveds.toArray(), messages.toArray()); + }); +}); diff --git a/discojs/src/client/decentralized/peer.ts b/discojs/src/client/decentralized/peer.ts index ad746628e..22667d68a 100644 --- a/discojs/src/client/decentralized/peer.ts +++ b/discojs/src/client/decentralized/peer.ts @@ -1,33 +1,33 @@ -import { List, Map, Range, Seq } from 'immutable' +import { List, Map, Range, Seq } from "immutable"; import wrtc from "@epfml/isomorphic-wrtc"; -import SimplePeer from 'simple-peer' +import SimplePeer from "simple-peer"; -import type { NodeID } from '../types.js' +import type { NodeID } from "../types.js"; -type MessageID = number -type ChunkID = number +type MessageID = number; +type ChunkID = number; // message id + (chunk counter == 0) + chunk count -const FIRST_HEADER_SIZE = 2 + 1 + 1 +const FIRST_HEADER_SIZE = 2 + 1 + 1; // message id + chunk counter -const HEADER_SIZE = 2 + 1 +const HEADER_SIZE = 2 + 1; // at which interval to poll -const TICK = 10 +const TICK = 10; // we can't use the definition in DOM as we're platform independent export type SignalData = - | { type: 'answer' | 'offer' | 'pranswer' | 'rollback', sdp?: string } - | { type: 'transceiverRequest', transceiverRequest: { kind: string } } - | { type: 'renegotiate', renegotiate: true } - | { type: 'candidate', candidate: RTCIceCandidate } + | { type: "answer" | "offer" | "pranswer" | "rollback"; sdp?: string } + | { type: "transceiverRequest"; transceiverRequest: { kind: string } } + | { type: "renegotiate"; renegotiate: true } + | { type: "candidate"; candidate: RTCIceCandidate }; interface Events { - 'close': () => void - 'connect': () => void - 'signal': (signal: SignalData) => void - 'data': (data: Buffer) => void + close: () => void; + connect: () => void; + signal: (signal: SignalData) => void; + data: (data: Buffer) => void; } // Peer wraps a SimplePeer, adding message fragmentation @@ -44,209 +44,227 @@ interface Events { // // see feross/simple-peer#393 for more info export class Peer { - private readonly peer: SimplePeer.Instance + private readonly peer: SimplePeer.Instance; - private bufferSize?: number + private bufferSize?: number; - private sendCounter: MessageID = 0 - private sendQueue = List() + private sendCounter: MessageID = 0; + private sendQueue = List(); - private receiving = Map - }>() + private receiving = Map< + MessageID, + { + total?: number; + chunks: Map; + } + >(); - constructor ( + constructor( public readonly id: NodeID, - initiator: boolean = false + initiator: boolean = false, ) { - this.peer = new SimplePeer({ wrtc, initiator }) + this.peer = new SimplePeer({ wrtc, initiator }); } - send (msg: Buffer): void { - const chunks = this.chunk(msg) - this.sendQueue = this.sendQueue.concat(chunks) - this.flush() + send(msg: Buffer): void { + const chunks = this.chunk(msg); + this.sendQueue = this.sendQueue.concat(chunks); + this.flush(); } - private flush (): void { + private flush(): void { if (this.bufferSize === undefined) { - throw new Error('flush without known buffer size') + throw new Error("flush without known buffer size"); } - const chunk = this.sendQueue.first() + const chunk = this.sendQueue.first(); if (chunk === undefined) { - return // nothing to flush + return; // nothing to flush } - const remainingBufferSize = this.bufferSize - this.peer.bufferSize + const remainingBufferSize = this.bufferSize - this.peer.bufferSize; if (chunk.length > remainingBufferSize) { - setTimeout(() => { this.flush() }, TICK) - return + setTimeout(() => { + this.flush(); + }, TICK); + return; } - this.sendQueue = this.sendQueue.shift() - this.peer.send(chunk) + this.sendQueue = this.sendQueue.shift(); + this.peer.send(chunk); // and loop - this.flush() + this.flush(); } - get maxChunkSize (): number { + get maxChunkSize(): number { if (this.bufferSize === undefined) { - throw new Error('chunk without known buffer size') + throw new Error("chunk without known buffer size"); } - return this.bufferSize + return this.bufferSize; } - private chunk (b: Buffer): Seq.Indexed { - const messageID = this.sendCounter - this.sendCounter++ - if (this.sendCounter > 0xFFFF) { - throw new Error('too much messages sent to this peer') + private chunk(b: Buffer): Seq.Indexed { + const messageID = this.sendCounter; + this.sendCounter++; + if (this.sendCounter > 0xffff) { + throw new Error("too much messages sent to this peer"); } // special case as Range(1, 0) yields a value - let tail = Seq.Indexed([]) + let tail = Seq.Indexed([]); if (b.length > this.maxChunkSize) { tail = Range( this.maxChunkSize - FIRST_HEADER_SIZE, b.length, - this.maxChunkSize - HEADER_SIZE - ).map((offset) => b.subarray( - offset, - offset + this.maxChunkSize - HEADER_SIZE - )) + this.maxChunkSize - HEADER_SIZE, + ).map((offset) => + b.subarray(offset, offset + this.maxChunkSize - HEADER_SIZE), + ); } - const totalChunkCount = 1 + tail.count() - if (totalChunkCount > 0xFF) { - throw new Error(`The payload is too big: ${totalChunkCount * this.maxChunkSize} bytes > 255,` + - ' consider reducing the model size or increasing the chunk size') + const totalChunkCount = 1 + tail.count(); + if (totalChunkCount > 0xff) { + throw new Error( + `The payload is too big: ${totalChunkCount * this.maxChunkSize} bytes > 255,` + + " consider reducing the model size or increasing the chunk size", + ); } const firstChunk = Buffer.alloc( - (b.length > this.maxChunkSize - FIRST_HEADER_SIZE) + b.length > this.maxChunkSize - FIRST_HEADER_SIZE ? this.maxChunkSize - : FIRST_HEADER_SIZE + b.length - ) - firstChunk.writeUint16BE(messageID) - firstChunk.writeUint8(0 as ChunkID, 2) - firstChunk.writeUint8(totalChunkCount, 3) - b.copy(firstChunk, FIRST_HEADER_SIZE, 0, this.maxChunkSize - FIRST_HEADER_SIZE) - - return Seq.Indexed([firstChunk]) - .concat( - Range(1 as ChunkID, Number.POSITIVE_INFINITY) - .zip(tail) - .map(([id, raw]) => { - const chunk = Buffer.alloc(HEADER_SIZE + raw.length) - chunk.writeUint16BE(messageID) - chunk.writeUint8(id, 2) - raw.copy(chunk, HEADER_SIZE, 0) - return chunk - }) - ) + : FIRST_HEADER_SIZE + b.length, + ); + firstChunk.writeUint16BE(messageID); + firstChunk.writeUint8(0 as ChunkID, 2); + firstChunk.writeUint8(totalChunkCount, 3); + b.copy( + firstChunk, + FIRST_HEADER_SIZE, + 0, + this.maxChunkSize - FIRST_HEADER_SIZE, + ); + + return Seq.Indexed([firstChunk]).concat( + Range(1 as ChunkID, Number.POSITIVE_INFINITY) + .zip(tail) + .map(([id, raw]) => { + const chunk = Buffer.alloc(HEADER_SIZE + raw.length); + chunk.writeUint16BE(messageID); + chunk.writeUint8(id, 2); + raw.copy(chunk, HEADER_SIZE, 0); + return chunk; + }), + ); } - async destroy (): Promise { + async destroy(): Promise { return new Promise((resolve, reject) => { - this.peer.once('error', reject) - this.peer.once('close', resolve) - this.peer.destroy() - }) + this.peer.once("error", reject); + this.peer.once("close", resolve); + this.peer.destroy(); + }); } - signal (signal: SignalData): void { + signal(signal: SignalData): void { // extract max buffer size from the signal - if (signal.type === 'offer' || signal.type === 'answer') { + if (signal.type === "offer" || signal.type === "answer") { if (signal.sdp === undefined) { - throw new Error('signal answer|offer without session description') + throw new Error("signal answer|offer without session description"); } if (this.bufferSize !== undefined) { - throw new Error('buffer size set twice') + throw new Error("buffer size set twice"); } - const match = signal.sdp.match(/a=max-message-size:(\d+)/) + const match = signal.sdp.match(/a=max-message-size:(\d+)/); if (match === null) { // TODO default value instead? - throw new Error('no max-message-size found in signal') + throw new Error("no max-message-size found in signal"); } - const max = parseInt(match[1], 10) + const max = parseInt(match[1], 10); if (isNaN(max)) { - throw new Error(`unable to parse max-message-size as int: ${match[1]}`) + throw new Error(`unable to parse max-message-size as int: ${match[1]}`); } - this.bufferSize = max + this.bufferSize = max; } - this.peer.signal(signal) + this.peer.signal(signal); } on(event: K, listener: Events[K]): void { - if (event !== 'data') { - this.peer.on(event, listener) - return + if (event !== "data") { + this.peer.on(event, listener); + return; } // gotta help typescript here - const dataListener = listener as Events['data'] + const dataListener = listener as Events["data"]; - this.peer.on('data', (data: unknown) => { + this.peer.on("data", (data: unknown) => { if (!Buffer.isBuffer(data) || data.length < HEADER_SIZE) { - throw new Error('received invalid message type') + throw new Error("received invalid message type"); } - const messageID: MessageID = data.readUInt16BE() //readUint16BE (case sensitive) fails at runtime - const chunkID: ChunkID = data.readUInt8(2) // same for readUint8 + const messageID: MessageID = data.readUInt16BE(); //readUint16BE (case sensitive) fails at runtime + const chunkID: ChunkID = data.readUInt8(2); // same for readUint8 const received = this.receiving.get(messageID, { total: undefined, - chunks: Map() - }) - let total = received.total - const chunks = received.chunks + chunks: Map(), + }); + let total = received.total; + const chunks = received.chunks; if (chunks.has(chunkID)) { - throw new Error(`chunk ${messageID}:${chunkID} already received`) + throw new Error(`chunk ${messageID}:${chunkID} already received`); } - let chunk: Buffer + let chunk: Buffer; if (chunkID !== 0) { - chunk = Buffer.alloc(data.length - HEADER_SIZE) - data.copy(chunk, 0, HEADER_SIZE) + chunk = Buffer.alloc(data.length - HEADER_SIZE); + data.copy(chunk, 0, HEADER_SIZE); } else { if (data.length < FIRST_HEADER_SIZE) { - throw new Error('received invalid message type') + throw new Error("received invalid message type"); } if (total !== undefined) { - throw new Error('first header received twice') + throw new Error("first header received twice"); } - const readTotal = data.readUInt8(3) - total = readTotal - chunk = Buffer.alloc(data.length - FIRST_HEADER_SIZE) - data.copy(chunk, 0, FIRST_HEADER_SIZE) + const readTotal = data.readUInt8(3); + total = readTotal; + chunk = Buffer.alloc(data.length - FIRST_HEADER_SIZE); + data.copy(chunk, 0, FIRST_HEADER_SIZE); if (chunks.keySeq().some((id) => id > readTotal)) { - throw new Error('received total of chunk but got now-out-of-bound chunks') + throw new Error( + "received total of chunk but got now-out-of-bound chunks", + ); } } this.receiving = this.receiving.set(messageID, { total, - chunks: chunks.set(chunkID, chunk) - }) + chunks: chunks.set(chunkID, chunk), + }); const readyMessages = this.receiving - .filter(({ total, chunks }) => total !== undefined && chunks.size === total) + .filter( + ({ total, chunks }) => total !== undefined && chunks.size === total, + ) .sort() - .map(({ chunks }) => chunks.entrySeq().toList().sortBy(([id, _]) => id)) - .map((chunks) => Buffer.concat(chunks.map(([_, b]) => b).toArray())) - this.receiving = this.receiving.deleteAll(readyMessages.keys()) - - readyMessages - .forEach((message) => { - // TODO debug - dataListener(message) - }) - }) + .map(({ chunks }) => + chunks + .entrySeq() + .toList() + .sortBy(([id, _]) => id), + ) + .map((chunks) => Buffer.concat(chunks.map(([_, b]) => b).toArray())); + this.receiving = this.receiving.deleteAll(readyMessages.keys()); + + readyMessages.forEach((message) => { + // TODO debug + dataListener(message); + }); + }); } } diff --git a/discojs/src/client/decentralized/peer_pool.spec.ts b/discojs/src/client/decentralized/peer_pool.spec.ts index e748df496..435f7ef0e 100644 --- a/discojs/src/client/decentralized/peer_pool.spec.ts +++ b/discojs/src/client/decentralized/peer_pool.spec.ts @@ -1,125 +1,142 @@ -import { Map, Range } from 'immutable' +import { Map, Range } from "immutable"; import { assert, afterEach, beforeEach, describe, it } from "vitest"; -import type { EventConnection, PeerConnection } from '../event_connection.js' -import { type } from '../messages.js' -import type { NodeID } from '../types.js' +import type { EventConnection, PeerConnection } from "../event_connection.js"; +import { type } from "../messages.js"; +import type { NodeID } from "../types.js"; -import type { messages } from './index.js' -import { PeerPool } from './peer_pool.js' +import type { messages } from "./index.js"; +import { PeerPool } from "./peer_pool.js"; describe("peer pool", { timeout: 10_000 }, () => { - let pools: Map + let pools: Map; beforeEach(() => { - const count = 3 + const count = 3; - pools = Map(Range(1, count + 1).map(String).map((id) => - [id, new PeerPool(id)] - )) - }) + pools = Map( + Range(1, count + 1) + .map(String) + .map((id) => [id, new PeerPool(id)]), + ); + }); afterEach(async () => { - await Promise.all(pools.valueSeq().map((p) => p.shutdown())) - }) + await Promise.all(pools.valueSeq().map((p) => p.shutdown())); + }); - function mockServer (poolId: string): EventConnection { + function mockServer(poolId: string): EventConnection { return { send: (msg): void => { - const signal = msg as messages.SignalForPeer - const otherPool = pools.get(signal.peer) + const signal = msg as messages.SignalForPeer; + const otherPool = pools.get(signal.peer); if (otherPool === undefined) { - throw new Error(`signal for unknown pool: ${signal.peer}`) + throw new Error(`signal for unknown pool: ${signal.peer}`); } - otherPool.signal(poolId, signal.signal) + otherPool.signal(poolId, signal.signal); }, - on: (): void => { /* nothing */ }, - once: (): void => { /* nothing */ }, - disconnect: (): Promise => Promise.resolve() - } + on: (): void => { + /* nothing */ + }, + once: (): void => { + /* nothing */ + }, + disconnect: (): Promise => Promise.resolve(), + }; } - function mockWeights (id: NodeID): messages.Payload { + function mockWeights(id: NodeID): messages.Payload { return { type: type.Payload, peer: id, payload: Uint8Array.of(1, 2, 3), aggregationRound: 0, - communicationRound: 0 - } + communicationRound: 0, + }; + } + + async function getAllPeers( + pools: Map, + ): Promise>> { + const ids = pools.keySeq().toSet(); + + return Map( + await Promise.all( + pools + .map( + async (pool, poolID) => + await pool.getPeers( + ids.remove(poolID), + mockServer(poolID), + () => { + // empty + }, + ), + ) + .entrySeq() + .map( + async ([id, p]) => + [id, await p] as [NodeID, Map], + ) + .toArray(), + ), + ); } - async function getAllPeers (pools: Map): - Promise>> { - const ids = pools.keySeq().toSet() - - return Map( - await Promise.all( - pools - .map( - async (pool, poolID) => - await pool.getPeers( - ids.remove(poolID), - mockServer(poolID), - () => { - // empty - }, - ), - ) - .entrySeq() - .map( - async ([id, p]) => - [id, await p] as [NodeID, Map], - ) - .toArray(), - ), - ); - } - - async function assertCanSendMessagesToEach ( - peersSets: Map> + async function assertCanSendMessagesToEach( + peersSets: Map>, ): Promise { - const messages = - peersSets - .entrySeq() - .map(([_, peers]) => - peers - .keySeq().map((id) => mockWeights(id)) - .toArray()) - .toArray() - .flat() - - for (const [poolID, peers] of peersSets) - for (const peer of peers.values()) peer.send(mockWeights(poolID)); - - const exchanged = (await Promise.all( - peersSets - .valueSeq() - .map(async (peers) => await Promise.all( - peers - .valueSeq() - .map(async (peer) => - await new Promise((resolve) => { peer.on(type.Payload, (data) => { resolve(data) }) } - ) - ) - .toArray() - )) - .toArray() - )).flat() - - assert.sameDeepMembers(exchanged, messages) + const messages = peersSets + .entrySeq() + .map(([_, peers]) => + peers + .keySeq() + .map((id) => mockWeights(id)) + .toArray(), + ) + .toArray() + .flat(); + + for (const [poolID, peers] of peersSets) + for (const peer of peers.values()) peer.send(mockWeights(poolID)); + + const exchanged = ( + await Promise.all( + peersSets + .valueSeq() + .map( + async (peers) => + await Promise.all( + peers + .valueSeq() + .map( + async (peer) => + await new Promise((resolve) => { + peer.on(type.Payload, (data) => { + resolve(data); + }); + }), + ) + .toArray(), + ), + ) + .toArray(), + ) + ).flat(); + + assert.sameDeepMembers(exchanged, messages); } - it('gets peers to connect to', async () => { - const poolsPeers = await getAllPeers(pools) - await assertCanSendMessagesToEach(poolsPeers) - }) + it("gets peers to connect to", async () => { + const poolsPeers = await getAllPeers(pools); + await assertCanSendMessagesToEach(poolsPeers); + }); it("doesn't reconnect known peers", async () => { - const poolsPeersFirst = await getAllPeers(pools) - await assertCanSendMessagesToEach(poolsPeersFirst) + const poolsPeersFirst = await getAllPeers(pools); + await assertCanSendMessagesToEach(poolsPeersFirst); - const poolsPeersSecond = await getAllPeers(pools) - await assertCanSendMessagesToEach(poolsPeersSecond) - }) -}) + const poolsPeersSecond = await getAllPeers(pools); + await assertCanSendMessagesToEach(poolsPeersSecond); + }); +}); diff --git a/discojs/src/client/decentralized/peer_pool.ts b/discojs/src/client/decentralized/peer_pool.ts index ba33690bd..dd004b3a5 100644 --- a/discojs/src/client/decentralized/peer_pool.ts +++ b/discojs/src/client/decentralized/peer_pool.ts @@ -1,51 +1,49 @@ import createDebug from "debug"; -import { Map, type Set } from 'immutable' +import { Map, type Set } from "immutable"; -import { Peer, type SignalData } from './peer.js' -import type { NodeID } from '../types.js' -import { PeerConnection, type EventConnection } from '../event_connection.js' +import { Peer, type SignalData } from "./peer.js"; +import type { NodeID } from "../types.js"; +import { PeerConnection, type EventConnection } from "../event_connection.js"; const debug = createDebug("discojs:client:decentralized:pool"); // TODO cleanup old peers export class PeerPool { - private peers = Map() + private peers = Map(); - constructor ( - private readonly id: NodeID - ) {} + constructor(private readonly id: NodeID) {} - async shutdown (): Promise { + async shutdown(): Promise { debug(`[${this.id}] is shutting down all its connections`); // Add a timeout o.w. the promise hangs forever if the other peer is already disconnected await Promise.race([ Promise.all(this.peers.valueSeq().map((peer) => peer.disconnect())), - new Promise((res, _) => setTimeout(res, 1000)) // Wait for other peers to finish - ]) - this.peers = Map() + new Promise((res, _) => setTimeout(res, 1000)), // Wait for other peers to finish + ]); + this.peers = Map(); } - signal (peerId: NodeID, signal: SignalData): void { + signal(peerId: NodeID, signal: SignalData): void { debug(`[${this.id}] signals for %s`, peerId); - const peer = this.peers.get(peerId) + const peer = this.peers.get(peerId); if (peer === undefined) { - throw new Error(`received signal for unknown peer: ${peerId}`) + throw new Error(`received signal for unknown peer: ${peerId}`); } - peer.signal(signal) + peer.signal(signal); } - async getPeers ( + async getPeers( peersToConnect: Set, signallingServer: EventConnection, // TODO as event? - clientHandle: (connections: Map) => void + clientHandle: (connections: Map) => void, ): Promise> { if (peersToConnect.contains(this.id)) { - throw new Error('peers to connect contains our id') + throw new Error("peers to connect contains our id"); } debug(`[${this.id}] is connecting peers: %o`, peersToConnect.toArray()); @@ -53,21 +51,30 @@ export class PeerPool { const newPeers = Map( peersToConnect .filter((id) => !this.peers.has(id)) - .map((id) => [id, new Peer(id, id < this.id)] as [string, Peer]) - ) + .map((id) => [id, new Peer(id, id < this.id)] as [string, Peer]), + ); - debug(`[${this.id}] asked to connect new peers: %o`, newPeers.keySeq().toArray()); - const newPeersConnections = newPeers.map((peer) => new PeerConnection(this.id, peer, signallingServer)) + debug( + `[${this.id}] asked to connect new peers: %o`, + newPeers.keySeq().toArray(), + ); + const newPeersConnections = newPeers.map( + (peer) => new PeerConnection(this.id, peer, signallingServer), + ); // adding peers to pool before connecting them because they must be set to call signal on them - this.peers = this.peers.merge(newPeersConnections) + this.peers = this.peers.merge(newPeersConnections); - clientHandle(this.peers) + clientHandle(this.peers); - await Promise.all(newPeersConnections.valueSeq().map((conn) => conn.connect())) - debug(`[${this.id}] knowns connected peers: %o`, this.peers.keySeq().toArray()) + await Promise.all( + newPeersConnections.valueSeq().map((conn) => conn.connect()), + ); + debug( + `[${this.id}] knowns connected peers: %o`, + this.peers.keySeq().toArray(), + ); - return this.peers - .filter((_, id) => peersToConnect.has(id)) + return this.peers.filter((_, id) => peersToConnect.has(id)); } } diff --git a/discojs/src/client/event_connection.ts b/discojs/src/client/event_connection.ts index 3e3aec409..ca253ba64 100644 --- a/discojs/src/client/event_connection.ts +++ b/discojs/src/client/event_connection.ts @@ -1,144 +1,173 @@ import createDebug from "debug"; import WebSocket from "isomorphic-ws"; import * as msgpack from "@msgpack/msgpack"; -import type { Peer, SignalData } from './decentralized/peer.js' -import type { NodeID } from './types.js' -import * as decentralizedMessages from './decentralized/messages.js' -import { type, type NarrowMessage, type Message } from './messages.js' -import { timeout } from './utils.js' +import type { Peer, SignalData } from "./decentralized/peer.js"; +import type { NodeID } from "./types.js"; +import * as decentralizedMessages from "./decentralized/messages.js"; +import { type, type NarrowMessage, type Message } from "./messages.js"; +import { timeout } from "./utils.js"; -import { EventEmitter } from '../utils/event_emitter.js' +import { EventEmitter } from "../utils/event_emitter.js"; const debug = createDebug("discojs:client:connections"); export interface EventConnection { - on: (type: K, handler: (event: NarrowMessage) => void) => void - once: (type: K, handler: (event: NarrowMessage) => void) => void - send: (msg: T) => void - disconnect: () => Promise + on: ( + type: K, + handler: (event: NarrowMessage) => void, + ) => void; + once: ( + type: K, + handler: (event: NarrowMessage) => void, + ) => void; + send: (msg: T) => void; + disconnect: () => Promise; } -export async function waitMessage (connection: EventConnection, type: T): Promise> { +export async function waitMessage( + connection: EventConnection, + type: T, +): Promise> { return await new Promise((resolve) => { // "once" is important because we can't resolve the same promise multiple times connection.once(type, (event) => { - resolve(event) - }) - }) + resolve(event); + }); + }); } export async function waitMessageWithTimeout( connection: EventConnection, - type: T, timeoutMs?: number, - errorMsg: string = 'timeout'): Promise> { - - return await Promise.race([waitMessage(connection, type), timeout(timeoutMs, errorMsg)]) + type: T, + timeoutMs?: number, + errorMsg: string = "timeout", +): Promise> { + return await Promise.race([ + waitMessage(connection, type), + timeout(timeoutMs, errorMsg), + ]); } -export class PeerConnection extends EventEmitter<{ [K in type]: NarrowMessage }> implements EventConnection { - constructor ( +export class PeerConnection + extends EventEmitter<{ [K in type]: NarrowMessage }> + implements EventConnection +{ + constructor( private readonly _ownId: NodeID, private readonly peer: Peer, - private readonly signallingServer: EventConnection + private readonly signallingServer: EventConnection, ) { - super() + super(); } - async connect (): Promise { - this.peer.on('signal', (signal) => { + async connect(): Promise { + this.peer.on("signal", (signal) => { const msg: decentralizedMessages.SignalForPeer = { type: type.SignalForPeer, peer: this.peer.id, - signal - } - this.signallingServer.send(msg) - }) + signal, + }; + this.signallingServer.send(msg); + }); - this.peer.on('data', (data) => { - const msg: unknown = msgpack.decode(data) + this.peer.on("data", (data) => { + const msg: unknown = msgpack.decode(data); if (!decentralizedMessages.isPeerMessage(msg)) { - throw new Error(`invalid message received: ${JSON.stringify(msg)}`) + throw new Error(`invalid message received: ${JSON.stringify(msg)}`); } - this.emit(msg.type, msg) - }) + this.emit(msg.type, msg); + }); this.peer.on("close", () => { debug(`[${this._ownId}] peer ${this.peer.id} closed connection`); }); await new Promise((resolve) => { - this.peer.on('connect', resolve) - }) + this.peer.on("connect", resolve); + }); } - signal (signal: SignalData): void { - this.peer.signal(signal) + signal(signal: SignalData): void { + this.peer.signal(signal); } - send (msg: T): void { + send(msg: T): void { if (!decentralizedMessages.isPeerMessage(msg)) { - throw new Error(`can't send this type of message: ${JSON.stringify(msg)}`) + throw new Error( + `can't send this type of message: ${JSON.stringify(msg)}`, + ); } - this.peer.send(Buffer.from(msgpack.encode(msg))) + this.peer.send(Buffer.from(msgpack.encode(msg))); } async disconnect(): Promise { - await this.peer.destroy() + await this.peer.destroy(); } } -export class WebSocketServer extends EventEmitter<{ [K in type]: NarrowMessage }> implements EventConnection { - private constructor ( +export class WebSocketServer + extends EventEmitter<{ [K in type]: NarrowMessage }> + implements EventConnection +{ + private constructor( private readonly socket: WebSocket.WebSocket, - private readonly validateSent?: (msg: Message) => boolean - ) { super() } + private readonly validateSent?: (msg: Message) => boolean, + ) { + super(); + } - static async connect (url: URL, + static async connect( + url: URL, validateReceived: (msg: unknown) => msg is Message, - validateSent: (msg: Message) => boolean): Promise { - const ws = new WebSocket(url) - ws.binaryType = 'arraybuffer' + validateSent: (msg: Message) => boolean, + ): Promise { + const ws = new WebSocket(url); + ws.binaryType = "arraybuffer"; - const server: WebSocketServer = new WebSocketServer(ws, validateSent) + const server: WebSocketServer = new WebSocketServer(ws, validateSent); ws.onmessage = (event: WebSocket.MessageEvent) => { if (!(event.data instanceof ArrayBuffer)) { - throw new Error('server did not send an ArrayBuffer') + throw new Error("server did not send an ArrayBuffer"); } - const msg: unknown = msgpack.decode(new Uint8Array(event.data)) + const msg: unknown = msgpack.decode(new Uint8Array(event.data)); // Validate message format if (!validateReceived(msg)) { - throw new Error(`invalid message received: ${JSON.stringify(msg)}`) + throw new Error(`invalid message received: ${JSON.stringify(msg)}`); } - server.emit(msg.type, msg) - } + server.emit(msg.type, msg); + }; return await new Promise((resolve, reject) => { ws.onerror = (err: WebSocket.ErrorEvent) => { - reject(new Error(`Server unreachable: ${err.message}`)) - } - ws.onopen = () => { resolve(server) } - }) + reject(new Error(`Server unreachable: ${err.message}`)); + }; + ws.onopen = () => { + resolve(server); + }; + }); } disconnect(): Promise { return new Promise((resolve, reject) => { - this.socket.onclose = () => resolve() - this.socket.onerror = (e) => reject(new Error(e.message)) - this.socket.close() - }) + this.socket.onclose = () => resolve(); + this.socket.onerror = (e) => reject(new Error(e.message)); + this.socket.close(); + }); } - send (msg: Message): void { + send(msg: Message): void { if (this.validateSent !== undefined && !this.validateSent(msg)) { - throw new Error(`can't send this type of message: ${JSON.stringify(msg)}`) + throw new Error( + `can't send this type of message: ${JSON.stringify(msg)}`, + ); } - this.socket.send(msgpack.encode(msg)) + this.socket.send(msgpack.encode(msg)); } } diff --git a/discojs/src/client/federated/federated_client.ts b/discojs/src/client/federated/federated_client.ts index a89c65ad6..9e21bd46a 100644 --- a/discojs/src/client/federated/federated_client.ts +++ b/discojs/src/client/federated/federated_client.ts @@ -4,10 +4,7 @@ import { serialization } from "../../index.js"; import type { DataType, Model, WeightsContainer } from "../../index.js"; import { Client, shortenId } from "../client.js"; import { type, type ClientConnected } from "../messages.js"; -import { - waitMessage, - WebSocketServer, -} from "../event_connection.js"; +import { waitMessage, WebSocketServer } from "../event_connection.js"; import * as messages from "./messages.js"; const debug = createDebug("discojs:client:federated"); @@ -16,22 +13,21 @@ const debug = createDebug("discojs:client:federated"); * Arbitrary node id assigned to the federated server which we are communicating with. * Indeed, the server acts as a node within the network. In the federated setting described * by this client class, the server is the only node which we are communicating with. -*/ + */ const SERVER_NODE_ID = "federated-server-node-id"; /** * Client class that communicates with a centralized, federated server, when training * a specific task in the federated setting. */ -export class FederatedClient extends Client<"federated"> { - +export class FederatedClient extends Client<"federated"> { /** * Initializes the connection to the server, gets our node ID * as well as the latest training information: latest global model, current round and * whether we are waiting for more participants. */ override async connect(): Promise> { - const model = await super.connect() // Get the server base model + const model = await super.connect(); // Get the server base model const serverURL = new URL("", this.url.href); switch (this.url.protocol) { @@ -53,8 +49,8 @@ export class FederatedClient extends Client<"federated"> { ); // c.f. setupServerCallbacks doc for explanation - let receivedEnoughParticipants = false - this.setupServerCallbacks(() => receivedEnoughParticipants = true) + let receivedEnoughParticipants = false; + this.setupServerCallbacks(() => (receivedEnoughParticipants = true)); this.aggregator.registerNode(SERVER_NODE_ID); @@ -63,11 +59,9 @@ export class FederatedClient extends Client<"federated"> { }; this.server.send(msg); - const { - id, waitForMoreParticipants, payload, - round, nbOfParticipants - } = await waitMessage(this.server, type.NewFederatedNodeInfo); - + const { id, waitForMoreParticipants, payload, round, nbOfParticipants } = + await waitMessage(this.server, type.NewFederatedNodeInfo); + // This should come right after receiving the message to make sure // we don't miss a subsequent message from the server // We check if the server is telling us to wait for more participants @@ -76,20 +70,23 @@ export class FederatedClient extends Client<"federated"> { if (waitForMoreParticipants && !receivedEnoughParticipants) { // Create a promise that resolves when enough participants join // The client will await this promise before sending its local weight update - this.promiseForMoreParticipants = this.createPromiseForMoreParticipants() + this.promiseForMoreParticipants = this.createPromiseForMoreParticipants(); } if (this._ownId !== undefined) { - throw new Error('received id from server but was already received') + throw new Error("received id from server but was already received"); } this._ownId = id; debug(`[${shortenId(id)}] joined session at round ${round} `); - this.aggregator.setRound(round) - this.nbOfParticipants = nbOfParticipants + this.aggregator.setRound(round); + this.nbOfParticipants = nbOfParticipants; // Upon connecting, the server answers with a boolean // which indicates whether there are enough participants or not - debug(`[${shortenId(this.ownId)}] upon connecting, wait for participant flag %o`, this.waitingForMoreParticipants) - model.weights = serialization.weights.decode(payload) - return model + debug( + `[${shortenId(this.ownId)}] upon connecting, wait for participant flag %o`, + this.waitingForMoreParticipants, + ); + model.weights = serialization.weights.decode(payload); + return model; } /** @@ -105,8 +102,10 @@ export class FederatedClient extends Client<"federated"> { override onRoundBeginCommunication(): Promise { // Prepare the result promise for the incoming round - this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve)) - this.saveAndEmit("local training") + this.aggregationResult = new Promise((resolve) => + this.aggregator.once("aggregation", resolve), + ); + this.saveAndEmit("local training"); return Promise.resolve(); } @@ -120,23 +119,25 @@ export class FederatedClient extends Client<"federated"> { * @param weights Local weights sent to the server at the end of the local training round * @returns the new global weights sent by the server */ - override async onRoundEndCommunication(weights: WeightsContainer): Promise { - if (this._ownId === undefined) - throw new Error("no received ID from server"); + override async onRoundEndCommunication( + weights: WeightsContainer, + ): Promise { + if (this._ownId === undefined) + throw new Error("no received ID from server"); if (this.aggregationResult === undefined) { throw new Error("local aggregation result was not set"); } // First we check if we are waiting for more participants before sending our weight update - await this.waitForParticipantsIfNeeded() - this.saveAndEmit("updating model") + await this.waitForParticipantsIfNeeded(); + this.saveAndEmit("updating model"); // Send our local contribution to the server // and receive the server global update for this round as an answer to our contribution - const payloadToServer = this.aggregator - .makePayloads(weights) - .get(SERVER_NODE_ID); - if (payloadToServer === undefined) - throw new Error("aggregator didn't make a payload for the server"); + const payloadToServer = this.aggregator + .makePayloads(weights) + .get(SERVER_NODE_ID); + if (payloadToServer === undefined) + throw new Error("aggregator didn't make a payload for the server"); const msg: messages.SendPayload = { type: type.SendPayload, payload: await serialization.weights.encode(payloadToServer), @@ -145,18 +146,22 @@ export class FederatedClient extends Client<"federated"> { // Need to await the resulting global model right after sending our local contribution // to make sure we don't miss it - debug(`[${shortenId(this.ownId)}] sent its local update to the server for round ${this.aggregator.round}`); + debug( + `[${shortenId(this.ownId)}] sent its local update to the server for round ${this.aggregator.round}`, + ); this.server.send(msg); - debug(`[${shortenId(this.ownId)}] is waiting for server update for round ${this.aggregator.round + 1}`); + debug( + `[${shortenId(this.ownId)}] is waiting for server update for round ${this.aggregator.round + 1}`, + ); const { payload: payloadFromServer, round: serverRound, - nbOfParticipants - } = await waitMessage( this.server, type.ReceiveServerPayload); // Wait indefinitely for the server update - this.nbOfParticipants = nbOfParticipants // Save the current participants + nbOfParticipants, + } = await waitMessage(this.server, type.ReceiveServerPayload); // Wait indefinitely for the server update + this.nbOfParticipants = nbOfParticipants; // Save the current participants const serverResult = serialization.weights.decode(payloadFromServer); this.aggregator.setRound(serverRound); - return serverResult + return serverResult; } } diff --git a/discojs/src/client/federated/index.ts b/discojs/src/client/federated/index.ts index 63db2e0c7..056fd8f9f 100644 --- a/discojs/src/client/federated/index.ts +++ b/discojs/src/client/federated/index.ts @@ -1,2 +1,2 @@ -export { FederatedClient } from './federated_client.js' -export * as messages from './messages.js' +export { FederatedClient } from "./federated_client.js"; +export * as messages from "./messages.js"; diff --git a/discojs/src/client/federated/messages.ts b/discojs/src/client/federated/messages.ts index 3733d2c1c..4d4ee0e2a 100644 --- a/discojs/src/client/federated/messages.ts +++ b/discojs/src/client/federated/messages.ts @@ -1,43 +1,47 @@ import type { serialization } from "../../index.js"; -import { type NodeID } from '..//types.js' +import { type NodeID } from "..//types.js"; -import { type, hasMessageType } from '../messages.js' - import type { ClientConnected, WaitingForMoreParticipants, EnoughParticipants } from '../messages.js' +import { type, hasMessageType } from "../messages.js"; +import type { + ClientConnected, + WaitingForMoreParticipants, + EnoughParticipants, +} from "../messages.js"; - // See ../messages.ts for doc +// See ../messages.ts for doc export type MessageFederated = - ClientConnected | - NewFederatedNodeInfo | - SendPayload | - ReceiveServerPayload | - WaitingForMoreParticipants | - EnoughParticipants + | ClientConnected + | NewFederatedNodeInfo + | SendPayload + | ReceiveServerPayload + | WaitingForMoreParticipants + | EnoughParticipants; export interface NewFederatedNodeInfo { - type: type.NewFederatedNodeInfo - id: NodeID - waitForMoreParticipants: boolean + type: type.NewFederatedNodeInfo; + id: NodeID; + waitForMoreParticipants: boolean; payload: serialization.Encoded; - round: number - nbOfParticipants: number + round: number; + nbOfParticipants: number; } export interface SendPayload { - type: type.SendPayload + type: type.SendPayload; payload: serialization.Encoded; - round: number + round: number; } export interface ReceiveServerPayload { - type: type.ReceiveServerPayload + type: type.ReceiveServerPayload; payload: serialization.Encoded; - round: number, - nbOfParticipants: number // number of peers contributing to a federated training + round: number; + nbOfParticipants: number; // number of peers contributing to a federated training } -export function isMessageFederated (raw: unknown): raw is MessageFederated { +export function isMessageFederated(raw: unknown): raw is MessageFederated { if (!hasMessageType(raw)) { - return false + return false; } switch (raw.type) { @@ -47,8 +51,8 @@ export function isMessageFederated (raw: unknown): raw is MessageFederated { case type.ReceiveServerPayload: case type.WaitingForMoreParticipants: case type.EnoughParticipants: - return true + return true; } - return false + return false; } diff --git a/discojs/src/client/index.ts b/discojs/src/client/index.ts index ca2994190..f084d5a2f 100644 --- a/discojs/src/client/index.ts +++ b/discojs/src/client/index.ts @@ -1,11 +1,11 @@ -export { Client } from './client.js' +export { Client } from "./client.js"; -export * from './types.js' +export * from "./types.js"; -export * as aggregator from '../aggregator/index.js' -export * as decentralized from './decentralized/index.js' -export * as federated from './federated/index.js' -export * as messages from './messages.js' -export { getClient, timeout } from './utils.js' +export * as aggregator from "../aggregator/index.js"; +export * as decentralized from "./decentralized/index.js"; +export * as federated from "./federated/index.js"; +export * as messages from "./messages.js"; +export { getClient, timeout } from "./utils.js"; -export { LocalClient } from './local_client.js' +export { LocalClient } from "./local_client.js"; diff --git a/discojs/src/client/local_client.ts b/discojs/src/client/local_client.ts index 4ef477f2a..1b40cc6bb 100644 --- a/discojs/src/client/local_client.ts +++ b/discojs/src/client/local_client.ts @@ -6,12 +6,13 @@ import { Client } from "./client.js"; * with anyone. Thus LocalClient doesn't do anything during communication */ export class LocalClient extends Client<"local"> { - override onRoundBeginCommunication(): Promise { return Promise.resolve(); } - // Simply return the local weights - override onRoundEndCommunication(weights: WeightsContainer): Promise { + // Simply return the local weights + override onRoundEndCommunication( + weights: WeightsContainer, + ): Promise { return Promise.resolve(weights); } } diff --git a/discojs/src/client/messages.ts b/discojs/src/client/messages.ts index f5b5f9bb4..afc2f6556 100644 --- a/discojs/src/client/messages.ts +++ b/discojs/src/client/messages.ts @@ -1,21 +1,21 @@ -import type * as decentralized from './decentralized/messages.js' -import type * as federated from './federated/messages.js' +import type * as decentralized from "./decentralized/messages.js"; +import type * as federated from "./federated/messages.js"; export enum type { - // Sent from client to server as first point of contact to join a task. + // Sent from client to server as first point of contact to join a task. // The server answers with an node id in a NewFederatedNodeInfo // or NewDecentralizedNodeInfo message - ClientConnected, - + ClientConnected, + /* Decentralized */ // When a user joins a task with a ClientConnected message, the server // answers with its peer id and also tells the client whether we are waiting // for more participants before starting training NewDecentralizedNodeInfo, - // Message sent by peers to the server to signal they want to + // Message sent by peers to the server to signal they want to // join the next round JoinRound, - // Message sent by nodes to server signaling they are ready to + // Message sent by nodes to server signaling they are ready to // start the next round PeerIsReady, // Sent by the server to participating peers containing the list @@ -26,7 +26,7 @@ export enum type { SignalForPeer, // The weight update Payload, - + /* Federated */ // The server answers the ClientConnected message with the necessary information // to start training: node id, latest model global weights, current round etc @@ -42,37 +42,37 @@ export enum type { } export interface ClientConnected { - type: type.ClientConnected + type: type.ClientConnected; } export interface EnoughParticipants { - type: type.EnoughParticipants - nbOfParticipants: number + type: type.EnoughParticipants; + nbOfParticipants: number; } export interface WaitingForMoreParticipants { - type: type.WaitingForMoreParticipants - nbOfParticipants: number + type: type.WaitingForMoreParticipants; + nbOfParticipants: number; } export type Message = - decentralized.MessageFromServer | - decentralized.MessageToServer | - decentralized.PeerMessage | - federated.MessageFederated + | decentralized.MessageFromServer + | decentralized.MessageToServer + | decentralized.PeerMessage + | federated.MessageFederated; // Retrieve a specific message interface from the type D. i.e. NarrowMessage => messages.PeerId type -export type NarrowMessage = Extract +export type NarrowMessage = Extract; -export function hasMessageType (raw: unknown): raw is { type: type } & Record { - if (typeof raw !== 'object' || raw === null) return false +export function hasMessageType( + raw: unknown, +): raw is { type: type } & Record { + if (typeof raw !== "object" || raw === null) return false; - const o = raw as Record - if ( - !('type' in o && typeof o.type === 'number' && o.type in type) - ) { - return false + const o = raw as Record; + if (!("type" in o && typeof o.type === "number" && o.type in type)) { + return false; } - return true + return true; } diff --git a/discojs/src/client/types.ts b/discojs/src/client/types.ts index a73dc4a75..a42d585a0 100644 --- a/discojs/src/client/types.ts +++ b/discojs/src/client/types.ts @@ -1,6 +1,6 @@ -export type NodeID = string +export type NodeID = string; // TODO @s314cy: regexp test just like server-side -export function isNodeID (raw: unknown): raw is NodeID { - return typeof raw === 'string' +export function isNodeID(raw: unknown): raw is NodeID { + return typeof raw === "string"; } diff --git a/discojs/src/client/utils.ts b/discojs/src/client/utils.ts index 0987a7cd1..cfc3e6922 100644 --- a/discojs/src/client/utils.ts +++ b/discojs/src/client/utils.ts @@ -1,47 +1,52 @@ import type { DataType, Network, Task } from "../index.js"; -import { client as clients, type aggregator } from '../index.js' +import { client as clients, type aggregator } from "../index.js"; // Time to wait for the others in milliseconds. -const MAX_WAIT_PER_ROUND = 15_000 +const MAX_WAIT_PER_ROUND = 15_000; -export async function timeout (ms = MAX_WAIT_PER_ROUND, errorMsg: string = 'timeout'): Promise { +export async function timeout( + ms = MAX_WAIT_PER_ROUND, + errorMsg: string = "timeout", +): Promise { return await new Promise((_, reject) => { - setTimeout(() => { reject(new Error(errorMsg)) }, ms) - }) + setTimeout(() => { + reject(new Error(errorMsg)); + }, ms); + }); } export function getClient( - scheme: N | "local", - serverURL: URL, - task: Task, - aggregator: aggregator.Aggregator, + scheme: N | "local", + serverURL: URL, + task: Task, + aggregator: aggregator.Aggregator, ): clients.Client { - switch (scheme) { - case "decentralized": { - const t = task as Task; - t.trainingInformation.scheme = scheme; + switch (scheme) { + case "decentralized": { + const t = task as Task; + t.trainingInformation.scheme = scheme; - return new clients.decentralized.DecentralizedClient( - serverURL, - t, - aggregator, - ); - } - case "federated": { - const t = task as Task; - t.trainingInformation.scheme = scheme; + return new clients.decentralized.DecentralizedClient( + serverURL, + t, + aggregator, + ); + } + case "federated": { + const t = task as Task; + t.trainingInformation.scheme = scheme; - return new clients.federated.FederatedClient(serverURL, t, aggregator); - } - case "local": { - const t = task as Task; - t.trainingInformation.scheme = scheme; + return new clients.federated.FederatedClient(serverURL, t, aggregator); + } + case "local": { + const t = task as Task; + t.trainingInformation.scheme = scheme; - return new clients.LocalClient(serverURL, t, aggregator); - } - default: { - const _: never = scheme; - throw new Error("should never happen"); - } - } + return new clients.LocalClient(serverURL, t, aggregator); + } + default: { + const _: never = scheme; + throw new Error("should never happen"); + } + } } diff --git a/discojs/src/dataset/dataset.spec.ts b/discojs/src/dataset/dataset.spec.ts index f448121f1..65c91e937 100644 --- a/discojs/src/dataset/dataset.spec.ts +++ b/discojs/src/dataset/dataset.spec.ts @@ -148,50 +148,53 @@ describe("dataset", () => { expect( (await arrayFromAsync(batched)).map((l) => l.toArray()), - ).to.have.deep.ordered.members([[1, 2], [2, 3]]); + ).to.have.deep.ordered.members([ + [1, 2], + [2, 3], + ]); }); it("batch with overlap yields correct batches", async () => { - const expectedTokens = Range(0, 53).toList() - const contextLength = 4 + const expectedTokens = Range(0, 53).toList(); + const contextLength = 4; const parsed = new Dataset([expectedTokens]) .flatten() - .batch(contextLength + 1, 1) - + .batch(contextLength + 1, 1); + // -1 because the last sequence is dropped as there is no next token label - const expectedLength = Math.ceil(expectedTokens.size / contextLength) - 1 + const expectedLength = Math.ceil(expectedTokens.size / contextLength) - 1; expect(await parsed.size()).to.equal(expectedLength); - + // exclude the last sequence because it has been padded - let sequences = List(await arrayFromAsync(parsed)) + let sequences = List(await arrayFromAsync(parsed)); // we expect the last sequence to have contextLength + 1 tokens via padding - expect(sequences.last()?.size).to.equal(contextLength + 1) - sequences = sequences.pop() - let i = 0 + expect(sequences.last()?.size).to.equal(contextLength + 1); + sequences = sequences.pop(); + let i = 0; for (const tokens of sequences) { // each sequence has length contextLength + 1 (for the label) expect(tokens.toArray()).to.deep.equal( - expectedTokens.slice(i, i + contextLength + 1).toArray() + expectedTokens.slice(i, i + contextLength + 1).toArray(), ); // but the window should move by contextLength only - i += contextLength + i += contextLength; } - }) + }); it("repeats content infinitely", async () => { const dataset = new Dataset([0, 1, 2]).repeat(); - const iter = dataset[Symbol.asyncIterator]() + const iter = dataset[Symbol.asyncIterator](); for (const i of Range(0, 10)) { - const e = await iter.next() - expect(e.done).to.be.false - expect(e.value).to.equal(i % 3) + const e = await iter.next(); + expect(e.done).to.be.false; + expect(e.value).to.equal(i % 3); } }); it("repeats content a fixed number of times", async () => { const dataset = new Dataset([0, 1]).repeat(3); - expect([0,1,0,1,0,1]).to.deep.equal(await arrayFromAsync(dataset)) + expect([0, 1, 0, 1, 0, 1]).to.deep.equal(await arrayFromAsync(dataset)); }); }); diff --git a/discojs/src/dataset/dataset.ts b/discojs/src/dataset/dataset.ts index 12e2094e3..be7085ef8 100644 --- a/discojs/src/dataset/dataset.ts +++ b/discojs/src/dataset/dataset.ts @@ -13,8 +13,9 @@ type DatasetLike = | (() => Iterator); /** Convert a DatasetLike object to an async generator */ -async function* datasetLikeToGenerator(content: DatasetLike): - AsyncGenerator { +async function* datasetLikeToGenerator( + content: DatasetLike, +): AsyncGenerator { let iter: AsyncIterator | Iterator; if (typeof content === "function") iter = content(); else if (Symbol.asyncIterator in content) @@ -40,7 +41,7 @@ export class Dataset implements AsyncIterable { constructor(content: DatasetLike) { this.#content = async function* () { yield* datasetLikeToGenerator(content); - } + }; } [Symbol.asyncIterator](): AsyncIterator { @@ -120,11 +121,11 @@ export class Dataset implements AsyncIterable { /** Create batches of `size` elements with potential overlap. * Last batch is smaller if dataset isn't perfectly divisible - * - * If overlap is set to a positive integer, the last `overlap` elements of a batch + * + * If overlap is set to a positive integer, the last `overlap` elements of a batch * are the first `overlap` elements of the next batch. - * - * This method is tailored to create text sequences where each token's label is the following token. + * + * This method is tailored to create text sequences where each token's label is the following token. * In order to have a label for the last token of the input sequence, we include the first token * of the next sequence (i.e. with an overlap of 1). * @@ -132,8 +133,7 @@ export class Dataset implements AsyncIterable { * @param overlap number of elements overlapping between two consecutive batches */ batch(size: number, overlap = 0): Dataset> { - if (size <= 0 || !Number.isInteger(size)) - throw new Error("invalid size"); + if (size <= 0 || !Number.isInteger(size)) throw new Error("invalid size"); if (overlap >= size || !Number.isInteger(overlap)) throw new Error("invalid overlap"); @@ -146,8 +146,8 @@ export class Dataset implements AsyncIterable { const batch = List( // get the first elements of the next batch await Promise.all( - Range(overlapped.size, size).map(() => iter.next()) - ) + Range(overlapped.size, size).map(() => iter.next()), + ), ).flatMap((res) => { if (res.done) return []; else return [res.value]; @@ -214,8 +214,8 @@ export class Dataset implements AsyncIterable { let loop = 0; do { yield* this; - loop++ - } while (times === undefined || loop < times) + loop++; + } while (times === undefined || loop < times); }.bind(this), ); } @@ -239,64 +239,65 @@ export class Dataset implements AsyncIterable { } /** Shuffle the dataset - * + * * Shuffle within the sliding window */ - shuffle(windowSize: number){ - if (!Number.isInteger(windowSize) || windowSize < 1){ + shuffle(windowSize: number) { + if (!Number.isInteger(windowSize) || windowSize < 1) { throw new Error("Shuffle window size should be a positive integer"); } return new Dataset( - async function*(this: Dataset){ + async function* (this: Dataset) { const iter = this[Symbol.asyncIterator](); const buffer: T[] = []; // 1. Construct the initial buffer - while (buffer.length < windowSize){ + while (buffer.length < windowSize) { const n = await iter.next(); if (n.done) break; buffer.push(n.value); } // 2. Shuffle - while (buffer.length > 0){ + while (buffer.length > 0) { const pick = Math.floor(Math.random() * buffer.length); const chosen = buffer[pick]; const n = await iter.next(); - if (n.done){ + if (n.done) { // move the last element to the pick position buffer[pick] = buffer.pop() as T; - }else{ + } else { buffer[pick] = n.value; } yield chosen; } - }.bind(this) + }.bind(this), ); } /** Filter the dataset using the condition - * + * * Used for splitting dataset for each client (filter by client's id) */ filter( - condition: (value: T, index: number) => boolean | Promise - ): Dataset{ - return new Dataset(async function* (this: Dataset): AsyncGenerator{ - let i = 0; - for await(const v of this){ - if (await condition(v, i)){ - yield v; + condition: (value: T, index: number) => boolean | Promise, + ): Dataset { + return new Dataset( + async function* (this: Dataset): AsyncGenerator { + let i = 0; + for await (const v of this) { + if (await condition(v, i)) { + yield v; + } + i += 1; } - i += 1 - } - }.bind(this)); + }.bind(this), + ); } - } /** diff --git a/discojs/src/dataset/types.ts b/discojs/src/dataset/types.ts index dfd190aba..915cedc71 100644 --- a/discojs/src/dataset/types.ts +++ b/discojs/src/dataset/types.ts @@ -1,6 +1,6 @@ import { List } from "immutable"; -import { Image } from "./image.js" +import { Image } from "./image.js"; export type Batched = List; diff --git a/discojs/src/default_tasks/cifar10.ts b/discojs/src/default_tasks/cifar10.ts index 3a3537440..79a9d7520 100644 --- a/discojs/src/default_tasks/cifar10.ts +++ b/discojs/src/default_tasks/cifar10.ts @@ -1,23 +1,27 @@ -import * as tf from '@tensorflow/tfjs' +import * as tf from "@tensorflow/tfjs"; import type { Model, TaskProvider } from "../index.js"; -import { models } from '../index.js' +import { models } from "../index.js"; -import baseModel from '../models/mobileNet_v1_025_224.js' +import baseModel from "../models/mobileNet_v1_025_224.js"; export const cifar10: TaskProvider<"image", "decentralized"> = { getTask() { return Promise.resolve({ - id: 'cifar10', + id: "cifar10", dataType: "image", displayInformation: { - title: 'CIFAR10', + title: "CIFAR10", summary: { - preview: 'CIFAR-10 is a classic image classification task, and one of the most widely used datasets for machine learning research.', - overview: "The dataset contains 60,000 32x32 color images in 10 different classes: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. The official CIFAR-10 website can be found at https://www.cs.toronto.edu/~kriz/cifar.html . You can find a link to a sample dataset at the next step." + preview: + "CIFAR-10 is a classic image classification task, and one of the most widely used datasets for machine learning research.", + overview: + "The dataset contains 60,000 32x32 color images in 10 different classes: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. The official CIFAR-10 website can be found at https://www.cs.toronto.edu/~kriz/cifar.html . You can find a link to a sample dataset at the next step.", }, - model: 'The model is a pretrained MobileNetV1 model trained in Tensorflow.js. The last output layer is replaced with a fully connected layer with softmax activation and one output neuron per CIFAR10 category. The data preprocessing reshapes images into 224x224 pixels and normalizes values between 0 and 1. The neural network is optimized via Stochastic Gradient Descent and a categorical Cross Entropy loss.', - dataFormatInformation: 'Images should be of .png format and of size 32x32.
The CSV file should start with the exact header "filename,label", and each row should contain an image filename (without extension) and its label.

For example if you have images: 0.png (of a frog) and 1.png (of a car)
The CSV file should be:
filename, label

0, frog
1, car', + model: + "The model is a pretrained MobileNetV1 model trained in Tensorflow.js. The last output layer is replaced with a fully connected layer with softmax activation and one output neuron per CIFAR10 category. The data preprocessing reshapes images into 224x224 pixels and normalizes values between 0 and 1. The neural network is optimized via Stochastic Gradient Descent and a categorical Cross Entropy loss.", + dataFormatInformation: + 'Images should be of .png format and of size 32x32.
The CSV file should start with the exact header "filename,label", and each row should contain an image filename (without extension) and its label.

For example if you have images: 0.png (of a frog) and 1.png (of a car)
The CSV file should be:
filename, label

0, frog
1, car', dataExample: "https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/cifar10-example.png", sampleDataset: { @@ -33,11 +37,22 @@ export const cifar10: TaskProvider<"image", "decentralized"> = { batchSize: 10, IMAGE_H: 224, IMAGE_W: 224, - LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], - scheme: 'decentralized', - maxIterations: 1, - beta: 0.9, - aggregationStrategy: 'mean', + LABEL_LIST: [ + "airplane", + "automobile", + "bird", + "cat", + "deer", + "dog", + "frog", + "horse", + "ship", + "truck", + ], + scheme: "decentralized", + maxIterations: 1, + beta: 0.9, + aggregationStrategy: "mean", privacy: { differentialPrivacy: { clippingRadius: 1, @@ -47,33 +62,33 @@ export const cifar10: TaskProvider<"image", "decentralized"> = { }, minNbOfParticipants: 3, maxShareValue: 100, - tensorBackend: 'tfjs' - } + tensorBackend: "tfjs", + }, }); }, - async getModel (): Promise> { + async getModel(): Promise> { const mobilenet = await tf.loadLayersModel({ load: async () => Promise.resolve(baseModel), - }) + }); - const x = mobilenet.getLayer('global_average_pooling2d_1') + const x = mobilenet.getLayer("global_average_pooling2d_1"); const predictions = tf.layers - .dense({ units: 10, activation: 'softmax', name: 'denseModified' }) - .apply(x.output) as tf.SymbolicTensor + .dense({ units: 10, activation: "softmax", name: "denseModified" }) + .apply(x.output) as tf.SymbolicTensor; const model = tf.model({ inputs: mobilenet.input, outputs: predictions, - name: 'modelModified' - }) + name: "modelModified", + }); model.compile({ - optimizer: 'sgd', - loss: 'categoricalCrossentropy', - metrics: ['accuracy'], - }) + optimizer: "sgd", + loss: "categoricalCrossentropy", + metrics: ["accuracy"], + }); - return new models.TFJS('image', model) - } -} + return new models.TFJS("image", model); + }, +}; diff --git a/discojs/src/default_tasks/index.ts b/discojs/src/default_tasks/index.ts index 43adf0d3c..6cbe47c8a 100644 --- a/discojs/src/default_tasks/index.ts +++ b/discojs/src/default_tasks/index.ts @@ -1,7 +1,7 @@ -export { cifar10 } from './cifar10.js' -export { lusCovid } from './lus_covid.js' -export { mnist } from './mnist.js' -export { simpleFace } from './simple_face.js' -export { titanic } from './titanic.js' -export { wikitext } from './wikitext.js' -export { tinderDog } from './tinder_dog.js' \ No newline at end of file +export { cifar10 } from "./cifar10.js"; +export { lusCovid } from "./lus_covid.js"; +export { mnist } from "./mnist.js"; +export { simpleFace } from "./simple_face.js"; +export { titanic } from "./titanic.js"; +export { wikitext } from "./wikitext.js"; +export { tinderDog } from "./tinder_dog.js"; diff --git a/discojs/src/default_tasks/lus_covid.ts b/discojs/src/default_tasks/lus_covid.ts index 234370735..eaee56885 100644 --- a/discojs/src/default_tasks/lus_covid.ts +++ b/discojs/src/default_tasks/lus_covid.ts @@ -1,21 +1,25 @@ -import * as tf from '@tensorflow/tfjs' +import * as tf from "@tensorflow/tfjs"; import type { Model, TaskProvider } from "../index.js"; -import { models } from '../index.js' +import { models } from "../index.js"; export const lusCovid: TaskProvider<"image", "federated"> = { getTask() { return Promise.resolve({ - id: 'lus_covid', + id: "lus_covid", dataType: "image", displayInformation: { - title: 'Lung Ultrasound Image Classification', + title: "Lung Ultrasound Image Classification", summary: { - preview: "Medical images are a typical example of data that exists in huge quantity yet that can't be shared due to confidentiality reasons. Medical applications would immensely benefit from training on data currently locked. More data diversity leads to better generalization and bias mitigation.", - overview: "Disco allows data owners to collaboratively train machine learning models using their respective data without any privacy breach. This example problem is about diagnosing whether patients are positive or negative to COVID-19 from lung ultrasounds images. You can find a link to a sample dataset at the next step." + preview: + "Medical images are a typical example of data that exists in huge quantity yet that can't be shared due to confidentiality reasons. Medical applications would immensely benefit from training on data currently locked. More data diversity leads to better generalization and bias mitigation.", + overview: + "Disco allows data owners to collaboratively train machine learning models using their respective data without any privacy breach. This example problem is about diagnosing whether patients are positive or negative to COVID-19 from lung ultrasounds images. You can find a link to a sample dataset at the next step.", }, - model: "The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 100x100 pixels and normalizes values between 0 and 1", - dataFormatInformation: 'This model takes as input an image dataset of lung ultrasounds. The images are resized automatically.', + model: + "The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 100x100 pixels and normalizes values between 0 and 1", + dataFormatInformation: + "This model takes as input an image dataset of lung ultrasounds. The images are resized automatically.", dataExample: "https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/2_QAID_1.masked.reshaped.squared.224.png", sampleDataset: { @@ -31,70 +35,76 @@ export const lusCovid: TaskProvider<"image", "federated"> = { batchSize: 5, IMAGE_H: 100, IMAGE_W: 100, - LABEL_LIST: ['COVID-Positive', 'COVID-Negative'], - scheme: 'federated', - aggregationStrategy: 'mean', + LABEL_LIST: ["COVID-Positive", "COVID-Negative"], + scheme: "federated", + aggregationStrategy: "mean", minNbOfParticipants: 2, - tensorBackend: 'tfjs' - } + tensorBackend: "tfjs", + }, }); }, - // Model architecture from tensorflow.js docs: + // Model architecture from tensorflow.js docs: // https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html#4 - async getModel (): Promise> { - const imageHeight = 100 - const imageWidth = 100 - const imageChannels = 3 - const numOutputClasses = 2 - const model = tf.sequential() + async getModel(): Promise> { + const imageHeight = 100; + const imageWidth = 100; + const imageChannels = 3; + const numOutputClasses = 2; + const model = tf.sequential(); // In the first layer of our convolutional neural network we have // to specify the input shape. Then we specify some parameters for // the convolution operation that takes place in this layer. - model.add(tf.layers.conv2d({ - inputShape: [imageHeight, imageWidth, imageChannels], - kernelSize: 5, - filters: 8, - strides: 1, - activation: 'relu', - kernelInitializer: 'varianceScaling' - })) + model.add( + tf.layers.conv2d({ + inputShape: [imageHeight, imageWidth, imageChannels], + kernelSize: 5, + filters: 8, + strides: 1, + activation: "relu", + kernelInitializer: "varianceScaling", + }), + ); // The MaxPooling layer acts as a sort of downsampling using max values // in a region instead of averaging. - model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] })) + model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] })); // Repeat the conv2d + maxPooling block. // Note that we have more filters in the convolution. - model.add(tf.layers.conv2d({ - kernelSize: 5, - filters: 16, - strides: 1, - activation: 'relu', - kernelInitializer: 'varianceScaling' - })) - model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] })) + model.add( + tf.layers.conv2d({ + kernelSize: 5, + filters: 16, + strides: 1, + activation: "relu", + kernelInitializer: "varianceScaling", + }), + ); + model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] })); // Now we flatten the output from the 2D filters into a 1D vector to prepare // it for input into our last layer. This is common practice when feeding // higher dimensional data to a final classification output layer. - model.add(tf.layers.flatten()) + model.add(tf.layers.flatten()); // Our last layer is a dense layer which has 2 output units, one for each // output class. - model.add(tf.layers.dense({ - units: numOutputClasses, - kernelInitializer: 'varianceScaling', - activation: 'softmax' - })) + model.add( + tf.layers.dense({ + units: numOutputClasses, + kernelInitializer: "varianceScaling", + activation: "softmax", + }), + ); model.compile({ - optimizer: 'sgd', - loss: 'binaryCrossentropy', - metrics: ['accuracy'] - }) + optimizer: "sgd", + loss: "binaryCrossentropy", + metrics: ["accuracy"], + }); - return Promise.resolve(new models.TFJS('image', model)) - } -} + return Promise.resolve(new models.TFJS("image", model)); + }, +}; diff --git a/discojs/src/default_tasks/mnist.ts b/discojs/src/default_tasks/mnist.ts index 71b148507..073359378 100644 --- a/discojs/src/default_tasks/mnist.ts +++ b/discojs/src/default_tasks/mnist.ts @@ -1,21 +1,25 @@ -import * as tf from '@tensorflow/tfjs' +import * as tf from "@tensorflow/tfjs"; import type { Model, TaskProvider } from "../index.js"; -import { models } from '../index.js' +import { models } from "../index.js"; export const mnist: TaskProvider<"image", "decentralized"> = { getTask() { return Promise.resolve({ - id: 'mnist', + id: "mnist", dataType: "image", displayInformation: { - title: 'Handwritten Digit Recognition', + title: "Handwritten Digit Recognition", summary: { - preview: "The MNIST handwritten digit classification problem is a classic dataset used in computer vision and deep learning. The objective is to classify handwritten digits from 28x28 pixel images.", - overview: "Download the classic MNIST dataset of hand-written numbers at https://www.kaggle.com/scolianni/mnistasjpg . You can also find a sample dataset at the next step." + preview: + "The MNIST handwritten digit classification problem is a classic dataset used in computer vision and deep learning. The objective is to classify handwritten digits from 28x28 pixel images.", + overview: + "Download the classic MNIST dataset of hand-written numbers at https://www.kaggle.com/scolianni/mnistasjpg . You can also find a sample dataset at the next step.", }, - model: "The model is a simple Convolutional Neural Network composed of three convolutional layers with ReLU activations and max pooling layers, followed by two fully connected layers. The data preprocessing simply normalizes values between 0 and 1. The neural network is optimized via RMSProp and a categorical cross-entropy loss.", - dataFormatInformation: 'This model is trained on images corresponding to digits 0 to 9. You can connect your own images of each digit in the box corresponding to its label. The model takes images of size 28x28 as input.', + model: + "The model is a simple Convolutional Neural Network composed of three convolutional layers with ReLU activations and max pooling layers, followed by two fully connected layers. The data preprocessing simply normalizes values between 0 and 1. The neural network is optimized via RMSProp and a categorical cross-entropy loss.", + dataFormatInformation: + "This model is trained on images corresponding to digits 0 to 9. You can connect your own images of each digit in the box corresponding to its label. The model takes images of size 28x28 as input.", dataExample: "http://storage.googleapis.com/deai-313515.appspot.com/example_training_data/9-mnist-example.png", sampleDataset: { @@ -31,51 +35,53 @@ export const mnist: TaskProvider<"image", "decentralized"> = { batchSize: 64, IMAGE_H: 28, IMAGE_W: 28, - LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], - scheme: 'decentralized', - aggregationStrategy: "byzantine", - privacy: { - byzantineFaultTolerance: { - clippingRadius: 10, - maxIterations: 1, - beta: 0.9, - }, - }, + LABEL_LIST: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], + scheme: "decentralized", + aggregationStrategy: "byzantine", + privacy: { + byzantineFaultTolerance: { + clippingRadius: 10, + maxIterations: 1, + beta: 0.9, + }, + }, minNbOfParticipants: 3, maxShareValue: 100, - tensorBackend: 'tfjs' - } + tensorBackend: "tfjs", + }, }); }, - getModel(): Promise> { + getModel(): Promise> { // Architecture from the PyTorch MNIST example (I made it slightly smaller, 650kB instead of 5MB) // https://github.com/pytorch/examples/blob/main/mnist/main.py - const model = tf.sequential() + const model = tf.sequential(); model.add( tf.layers.conv2d({ inputShape: [28, 28, 3], kernelSize: 5, filters: 8, - activation: 'relu', - }) - ) - model.add(tf.layers.conv2d({ kernelSize: 5, filters: 16, activation: 'relu' })) - model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })) - model.add(tf.layers.dropout({ rate: 0.25 })) + activation: "relu", + }), + ); + model.add( + tf.layers.conv2d({ kernelSize: 5, filters: 16, activation: "relu" }), + ); + model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })); + model.add(tf.layers.dropout({ rate: 0.25 })); - model.add(tf.layers.flatten()) - model.add(tf.layers.dense({ units: 32, activation: 'relu' })) - model.add(tf.layers.dropout({rate:0.25})) - model.add(tf.layers.dense({ units: 10, activation: 'softmax' })) + model.add(tf.layers.flatten()); + model.add(tf.layers.dense({ units: 32, activation: "relu" })); + model.add(tf.layers.dropout({ rate: 0.25 })); + model.add(tf.layers.dense({ units: 10, activation: "softmax" })); model.compile({ - optimizer: 'adam', - loss: 'categoricalCrossentropy', - metrics: ['accuracy'] - }) + optimizer: "adam", + loss: "categoricalCrossentropy", + metrics: ["accuracy"], + }); - return Promise.resolve(new models.TFJS('image', model)) - } -} + return Promise.resolve(new models.TFJS("image", model)); + }, +}; diff --git a/discojs/src/default_tasks/simple_face.ts b/discojs/src/default_tasks/simple_face.ts index 96e94e6b3..ad45a8d8a 100644 --- a/discojs/src/default_tasks/simple_face.ts +++ b/discojs/src/default_tasks/simple_face.ts @@ -1,21 +1,23 @@ -import * as tf from '@tensorflow/tfjs' +import * as tf from "@tensorflow/tfjs"; import type { Model, TaskProvider } from "../index.js"; -import { models } from '../index.js' -import baseModel from '../models/mobileNetV2_35_alpha_2_classes.js' +import { models } from "../index.js"; +import baseModel from "../models/mobileNetV2_35_alpha_2_classes.js"; export const simpleFace: TaskProvider<"image", "federated"> = { getTask() { return Promise.resolve({ - id: 'simple_face', + id: "simple_face", dataType: "image", displayInformation: { - title: 'Simple Face', + title: "Simple Face", summary: { - preview: 'Can you detect if the person in a picture is a child or an adult?', - overview: 'Simple face is a small subset of the public face_task dataset from Kaggle' + preview: + "Can you detect if the person in a picture is a child or an adult?", + overview: + "Simple face is a small subset of the public face_task dataset from Kaggle", }, - dataFormatInformation: '', + dataFormatInformation: "", dataExample: "https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png", sampleDataset: { @@ -31,26 +33,26 @@ export const simpleFace: TaskProvider<"image", "federated"> = { batchSize: 10, IMAGE_H: 200, IMAGE_W: 200, - LABEL_LIST: ['child', 'adult'], - scheme: 'federated', - aggregationStrategy: 'mean', + LABEL_LIST: ["child", "adult"], + scheme: "federated", + aggregationStrategy: "mean", minNbOfParticipants: 2, - tensorBackend: 'tfjs' - } + tensorBackend: "tfjs", + }, }); }, - async getModel (): Promise> { + async getModel(): Promise> { const model = await tf.loadLayersModel({ load: async () => Promise.resolve(baseModel), }); model.compile({ optimizer: tf.train.sgd(0.001), - loss: 'categoricalCrossentropy', - metrics: ['accuracy'] - }) + loss: "categoricalCrossentropy", + metrics: ["accuracy"], + }); - return new models.TFJS('image', model) - } -} + return new models.TFJS("image", model); + }, +}; diff --git a/discojs/src/default_tasks/tinder_dog.ts b/discojs/src/default_tasks/tinder_dog.ts index 17babbfc5..e59c347fe 100644 --- a/discojs/src/default_tasks/tinder_dog.ts +++ b/discojs/src/default_tasks/tinder_dog.ts @@ -1,21 +1,23 @@ -import * as tf from '@tensorflow/tfjs' +import * as tf from "@tensorflow/tfjs"; -import type { Model, TaskProvider } from '../index.js' -import { models } from '../index.js' +import type { Model, TaskProvider } from "../index.js"; +import { models } from "../index.js"; export const tinderDog: TaskProvider<"image", "federated"> = { getTask() { return Promise.resolve({ - id: 'tinder_dog', + id: "tinder_dog", dataType: "image", displayInformation: { title: "GDHF 2024 | TinderDog", summary: { - preview: 'Which dog is the cutest....or not?', - overview: "Binary classification model for dog cuteness." + preview: "Which dog is the cutest....or not?", + overview: "Binary classification model for dog cuteness.", }, - model: 'The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 64x64 pixels and normalizes values between 0 and 1', - dataFormatInformation: 'Accepted image formats are .png .jpg and .jpeg.', + model: + "The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 64x64 pixels and normalizes values between 0 and 1", + dataFormatInformation: + "Accepted image formats are .png .jpg and .jpeg.", dataExample: "https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog_preview.png", sampleDataset: { @@ -31,59 +33,68 @@ export const tinderDog: TaskProvider<"image", "federated"> = { batchSize: 10, IMAGE_H: 64, IMAGE_W: 64, - LABEL_LIST: ['Cute dogs', 'Less cute dogs'], - scheme: 'federated', - aggregationStrategy: 'mean', + LABEL_LIST: ["Cute dogs", "Less cute dogs"], + scheme: "federated", + aggregationStrategy: "mean", minNbOfParticipants: 3, - tensorBackend: 'tfjs' - } + tensorBackend: "tfjs", + }, }); }, - - async getModel(): Promise> { + async getModel(): Promise> { const task = await this.getTask(); - const seed = 42 // set a seed to ensure reproducibility during GDHF demo - const imageHeight = task.trainingInformation.IMAGE_H - const imageWidth = task.trainingInformation.IMAGE_W - const imageChannels = 3 + const seed = 42; // set a seed to ensure reproducibility during GDHF demo + const imageHeight = task.trainingInformation.IMAGE_H; + const imageWidth = task.trainingInformation.IMAGE_W; + const imageChannels = 3; - const model = tf.sequential() + const model = tf.sequential(); model.add( tf.layers.conv2d({ inputShape: [imageHeight, imageWidth, imageChannels], kernelSize: 5, filters: 8, - activation: 'relu', - kernelInitializer: tf.initializers.heNormal({ seed }) - }) - ) - model.add(tf.layers.conv2d({ - kernelSize: 5, filters: 16, activation: 'relu', - kernelInitializer: tf.initializers.heNormal({ seed }) - })) - model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })) - model.add(tf.layers.dropout({ rate: 0.25, seed })) + activation: "relu", + kernelInitializer: tf.initializers.heNormal({ seed }), + }), + ); + model.add( + tf.layers.conv2d({ + kernelSize: 5, + filters: 16, + activation: "relu", + kernelInitializer: tf.initializers.heNormal({ seed }), + }), + ); + model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 })); + model.add(tf.layers.dropout({ rate: 0.25, seed })); - model.add(tf.layers.flatten()) - model.add(tf.layers.dense({ - units: 32, activation: 'relu', - kernelInitializer: tf.initializers.heNormal({ seed }) - })) - model.add(tf.layers.dropout({rate:0.25, seed})) - model.add(tf.layers.dense({ - units: 2, activation: 'softmax', - kernelInitializer: tf.initializers.heNormal({ seed }) - })) + model.add(tf.layers.flatten()); + model.add( + tf.layers.dense({ + units: 32, + activation: "relu", + kernelInitializer: tf.initializers.heNormal({ seed }), + }), + ); + model.add(tf.layers.dropout({ rate: 0.25, seed })); + model.add( + tf.layers.dense({ + units: 2, + activation: "softmax", + kernelInitializer: tf.initializers.heNormal({ seed }), + }), + ); model.compile({ optimizer: tf.train.adam(0.0005), - loss: 'categoricalCrossentropy', - metrics: ['accuracy'] - }) + loss: "categoricalCrossentropy", + metrics: ["accuracy"], + }); - return Promise.resolve(new models.TFJS('image', model)) - } -} + return Promise.resolve(new models.TFJS("image", model)); + }, +}; diff --git a/discojs/src/default_tasks/titanic.ts b/discojs/src/default_tasks/titanic.ts index 5ec8829d7..4908e21e1 100644 --- a/discojs/src/default_tasks/titanic.ts +++ b/discojs/src/default_tasks/titanic.ts @@ -1,21 +1,25 @@ -import * as tf from '@tensorflow/tfjs' +import * as tf from "@tensorflow/tfjs"; import type { Model, TaskProvider } from "../index.js"; -import { models } from '../index.js' +import { models } from "../index.js"; export const titanic: TaskProvider<"tabular", "federated"> = { getTask() { return Promise.resolve({ - id: 'titanic', + id: "titanic", dataType: "tabular", displayInformation: { - title: 'Titanic Prediction', + title: "Titanic Prediction", summary: { - preview: "The Titanic classification task is one of the main entrypoints into machine learning. Using passenger data (name, age, gender, socio-economic class, etc), the goal is to identify who was more likely to survive the infamous shipwreck.", - overview: "The original competition can be found on Kaggle (https://www.kaggle.com/c/titanic) and a link to the training set can be found here: https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/titanic_train.csv" + preview: + "The Titanic classification task is one of the main entrypoints into machine learning. Using passenger data (name, age, gender, socio-economic class, etc), the goal is to identify who was more likely to survive the infamous shipwreck.", + overview: + "The original competition can be found on Kaggle (https://www.kaggle.com/c/titanic) and a link to the training set can be found here: https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/titanic_train.csv", }, - model: 'The model is a simple 5-layer feedforward network with ReLU activations. The model is optimized with Adam and binary cross-entropy loss. The preprocessing only fills missing value with a placeholder value (0).', - dataFormatInformation: 'The expected format for the tabular dataset is exactly the same as the sample data provided above or in the Kaggle competition. It is a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc. The first line of the CSV contains the header: "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked". Each subsequent row contains passenger data.', + model: + "The model is a simple 5-layer feedforward network with ReLU activations. The model is optimized with Adam and binary cross-entropy loss. The preprocessing only fills missing value with a placeholder value (0).", + dataFormatInformation: + 'The expected format for the tabular dataset is exactly the same as the sample data provided above or in the Kaggle competition. It is a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc. The first line of the CSV contains the header: "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked". Each subsequent row contains passenger data.', dataExample: [ { name: "PassengerId", data: "1" }, { name: "Survived", data: "0" }, @@ -41,43 +45,37 @@ export const titanic: TaskProvider<"tabular", "federated"> = { roundDuration: 2, validationSplit: 0.2, batchSize: 30, - inputColumns: [ - 'Age', - 'SibSp', - 'Parch', - 'Fare', - 'Pclass' - ], - outputColumn: 'Survived', - scheme: 'federated', - aggregationStrategy: 'mean', + inputColumns: ["Age", "SibSp", "Parch", "Fare", "Pclass"], + outputColumn: "Survived", + scheme: "federated", + aggregationStrategy: "mean", minNbOfParticipants: 2, - tensorBackend: 'tfjs' - } + tensorBackend: "tfjs", + }, }); }, - getModel (): Promise> { - const model = tf.sequential() + getModel(): Promise> { + const model = tf.sequential(); model.add( tf.layers.dense({ inputShape: [5], units: 124, - activation: 'relu', - kernelInitializer: 'leCunNormal' - }) - ) - model.add(tf.layers.dense({ units: 64, activation: 'relu' })) - model.add(tf.layers.dense({ units: 32, activation: 'relu' })) - model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' })) + activation: "relu", + kernelInitializer: "leCunNormal", + }), + ); + model.add(tf.layers.dense({ units: 64, activation: "relu" })); + model.add(tf.layers.dense({ units: 32, activation: "relu" })); + model.add(tf.layers.dense({ units: 1, activation: "sigmoid" })); model.compile({ - optimizer: 'adam', - loss: 'binaryCrossentropy', - metrics: ['accuracy'] - }) + optimizer: "adam", + loss: "binaryCrossentropy", + metrics: ["accuracy"], + }); - return Promise.resolve(new models.TFJS('tabular', model)) - } -} + return Promise.resolve(new models.TFJS("tabular", model)); + }, +}; diff --git a/discojs/src/default_tasks/wikitext.ts b/discojs/src/default_tasks/wikitext.ts index d5cf128e7..354cb24d9 100644 --- a/discojs/src/default_tasks/wikitext.ts +++ b/discojs/src/default_tasks/wikitext.ts @@ -4,13 +4,15 @@ import { Tokenizer, models } from "../index.js"; export const wikitext: TaskProvider<"text", "federated"> = { async getTask() { return { - id: 'llm_task', + id: "llm_task", dataType: "text", displayInformation: { - title: "GPT Language Modeling", + title: "GPT Language Modeling", summary: { - preview: 'Train a language model (L)LM in your browser, collaboratively and from scratch.', - overview: "You can train a GPT-2 model in your browser and in a collaborative manner on any textual dataset. As an example, you can try the Wikitext-103 dataset, composed of Wikipedia articles, widely used in natural language modeling. More information on how to connect the dataset at the next step." + preview: + "Train a language model (L)LM in your browser, collaboratively and from scratch.", + overview: + "You can train a GPT-2 model in your browser and in a collaborative manner on any textual dataset. As an example, you can try the Wikitext-103 dataset, composed of Wikipedia articles, widely used in natural language modeling. More information on how to connect the dataset at the next step.", }, model: [ "The model follows the exact GPT-2 architecture and is implemented in TensorFlow.js.", @@ -19,30 +21,31 @@ export const wikitext: TaskProvider<"text", "federated"> = { "It has around 5M parameters.", "To accommodate all devices, the context length is currently kept at 128 and the batch size at 1.", ].join(" "), - dataFormatInformation: 'You can use any natural language (text) dataset you like. For example the Wikitext-103 dataset is organized as a large text file, with each line representing a segment of raw text from Wikipedia articles.', - dataExample: - "For the first twenty years of its existence , the only staged performances of Parsifal took place in the Bayreuth Festspielhaus , the venue for which Wagner conceived the work ( except eight private performances for Ludwig II at Munich in 1884 and 1885 ) .", - sampleDataset: { - link: "https://storage.googleapis.com/deai-313515.appspot.com/wikitext.zip", - instructions: - 'Opening the link should start downloading a zip file. Unzip it and drag and drop the training set named "wiki.train.tokens" in the field below (or use the "Select File" button). Even though the file extension is ".tokens" it is indeed a text file. You can use "wiki.test.tokens" at the evaluation step after training a language model.', - }, + dataFormatInformation: + "You can use any natural language (text) dataset you like. For example the Wikitext-103 dataset is organized as a large text file, with each line representing a segment of raw text from Wikipedia articles.", + dataExample: + "For the first twenty years of its existence , the only staged performances of Parsifal took place in the Bayreuth Festspielhaus , the venue for which Wagner conceived the work ( except eight private performances for Ludwig II at Munich in 1884 and 1885 ) .", + sampleDataset: { + link: "https://storage.googleapis.com/deai-313515.appspot.com/wikitext.zip", + instructions: + 'Opening the link should start downloading a zip file. Unzip it and drag and drop the training set named "wiki.train.tokens" in the field below (or use the "Select File" button). Even though the file extension is ".tokens" it is indeed a text file. You can use "wiki.test.tokens" at the evaluation step after training a language model.', + }, }, trainingInformation: { - scheme: 'federated', - aggregationStrategy: 'mean', + scheme: "federated", + aggregationStrategy: "mean", minNbOfParticipants: 2, epochs: 6, // Unused by wikitext because data already comes split // But if set to 0 then the webapp doesn't display the validation metrics - validationSplit: 0.1, + validationSplit: 0.1, roundDuration: 2, batchSize: 8, // If set too high firefox raises a WebGL error tokenizer: await Tokenizer.from_pretrained("Xenova/gpt2"), contextLength: 64, - tensorBackend: 'gpt' - } - } + tensorBackend: "gpt", + }, + }; }, async getModel() { @@ -52,4 +55,4 @@ export const wikitext: TaskProvider<"text", "federated"> = { contextLength: task.trainingInformation.contextLength, }); }, -} +}; diff --git a/discojs/src/index.ts b/discojs/src/index.ts index 188090b0e..1930aa14c 100644 --- a/discojs/src/index.ts +++ b/discojs/src/index.ts @@ -1,15 +1,20 @@ -export * as data from './dataset/index.js' -export * as serialization from './serialization/index.js' -export * as training from './training/index.js' -export * as privacy from './privacy.js' +export * as data from "./dataset/index.js"; +export * as serialization from "./serialization/index.js"; +export * as training from "./training/index.js"; +export * as privacy from "./privacy.js"; -export * as client from './client/index.js' -export * as aggregator from './aggregator/index.js' +export * as client from "./client/index.js"; +export * as aggregator from "./aggregator/index.js"; -export { WeightsContainer, aggregation } from './weights/index.js' -export { Logger, ConsoleLogger } from './logging/index.js' -export { Disco, RoundLogs, RoundStatus, SummaryLogs } from './training/index.js' -export { Validator } from './validator.js' +export { WeightsContainer, aggregation } from "./weights/index.js"; +export { Logger, ConsoleLogger } from "./logging/index.js"; +export { + Disco, + RoundLogs, + RoundStatus, + SummaryLogs, +} from "./training/index.js"; +export { Validator } from "./validator.js"; export { Model, @@ -18,13 +23,13 @@ export { Tokenizer, ValidationMetrics, } from "./models/index.js"; -export * as models from './models/index.js' +export * as models from "./models/index.js"; -export * from './task/index.js' -export * as defaultTasks from './default_tasks/index.js' +export * from "./task/index.js"; +export * as defaultTasks from "./default_tasks/index.js"; -export * as async_iterator from "./utils/async_iterator.js" -export { EventEmitter } from "./utils/event_emitter.js" +export * as async_iterator from "./utils/async_iterator.js"; +export { EventEmitter } from "./utils/event_emitter.js"; export * from "./dataset/index.js"; export * from "./types/index.js"; diff --git a/discojs/src/logging/index.ts b/discojs/src/logging/index.ts index 37b91c39d..5b43b5751 100644 --- a/discojs/src/logging/index.ts +++ b/discojs/src/logging/index.ts @@ -1,2 +1,2 @@ -export { Logger } from './logger.js' -export { ConsoleLogger } from './console_logger.js' +export { Logger } from "./logger.js"; +export { ConsoleLogger } from "./console_logger.js"; diff --git a/discojs/src/privacy.spec.ts b/discojs/src/privacy.spec.ts index 6fec49a03..6629095b6 100644 --- a/discojs/src/privacy.spec.ts +++ b/discojs/src/privacy.spec.ts @@ -1,7 +1,12 @@ import { describe, expect, it } from "vitest"; import { WeightsContainer } from "./index.js"; -import { frobeniusNorm, clipNorm, addOptimalNoise, getClippingRadius } from "./privacy.js"; +import { + frobeniusNorm, + clipNorm, + addOptimalNoise, + getClippingRadius, +} from "./privacy.js"; import { WeightNormHistory } from "./training/trainer.js"; import * as tf from "@tensorflow/tfjs"; import { List } from "immutable"; @@ -12,38 +17,37 @@ async function WSIntoArrays(ws: WeightsContainer): Promise { ); } - /** Test the frobenius norm computation */ describe("frobeniusNorm", () => { it("computes Frobenius norm", async () => { const t = tf.tensor([3, 4]); const n = await frobeniusNorm(t); expect(n).toBeCloseTo(5, 1e-12); - }) + }); }); describe("clipNorm", () => { it("clips a single-layer vector using single radius value", async () => { const result = await clipNorm(WeightsContainer.of([2]), [1]); expect(await WSIntoArrays(result)).toEqual([[1]]); - }) + }); it("check if it does not change vector when it is already within radius", async () => { // norm is smaller than the clipping radius 10 - const result = await clipNorm(WeightsContainer.of([3, 4]), [10]); - expect(await WSIntoArrays(result)).toEqual([[3, 4]]) - }) + const result = await clipNorm(WeightsContainer.of([3, 4]), [10]); + expect(await WSIntoArrays(result)).toEqual([[3, 4]]); + }); it("applying different clipping radii per layer", async () => { const wc = WeightsContainer.of([3, 4], [0, 6]); const result = await clipNorm(wc, [5, 3]); // apply different clipping radii for each layer - + expect(await WSIntoArrays(result)).toEqual([ - [3, 4], - [0, 3], + [3, 4], + [0, 3], ]); }); -}) +}); describe("addOptimalNoise", () => { it("check if the structure is maintained", async () => { @@ -66,22 +70,26 @@ describe("addOptimalNoise", () => { expect(Number.isFinite(resultArrays[1][0])).toBe(true); expect(Number.isFinite(resultArrays[1][1])).toBe(true); }); -}) +}); describe("getClippingRadius", () => { it("correct average clipping radius and default radius", () => { const weightNormHistory = List([ List([2, 4, 6]), // expected average norm is 4 - List([10]) + List([10]), ]); - expect(getClippingRadius(weightNormHistory as WeightNormHistory, 5)).toEqual([4, 5]); + expect( + getClippingRadius(weightNormHistory as WeightNormHistory, 5), + ).toEqual([4, 5]); }); it("uses smaller window size automatically if needed", () => { const weightNormHistory = List([List([2, 4])]); // Automatically use window size of 2 instead of 10 - expect(getClippingRadius(weightNormHistory as WeightNormHistory, 10)).toEqual([3]); + expect( + getClippingRadius(weightNormHistory as WeightNormHistory, 10), + ).toEqual([3]); }); }); diff --git a/discojs/src/privacy.ts b/discojs/src/privacy.ts index 8ed51ed1c..c91cb0553 100644 --- a/discojs/src/privacy.ts +++ b/discojs/src/privacy.ts @@ -6,22 +6,25 @@ import type { WeightNormHistory } from "./training/trainer.js"; /** Computes the Frobenius norm of the given weights. */ export async function frobeniusNorm(weights: tf.Tensor): Promise { - const squared = await weights.square().sum().data(); - if (squared.length !== 1) throw new Error("unexpected weights shape"); - return Math.sqrt(squared[0]); + const squared = await weights.square().sum().data(); + if (squared.length !== 1) throw new Error("unexpected weights shape"); + return Math.sqrt(squared[0]); } /** ALDP-FL implementation */ // Conditions need to be added for the first three epochs -> get the avg update from all of the available previous updates -export function getClippingRadius(weightNormHistory: WeightNormHistory, defaultClippingRadius:number): number[]{ +export function getClippingRadius( + weightNormHistory: WeightNormHistory, + defaultClippingRadius: number, +): number[] { const WINDOW_SIZE = 3; - const MIN_RADIUS = 1e-12; + const MIN_RADIUS = 1e-12; const radii = weightNormHistory.map((norms) => { const recent = norms.slice(-WINDOW_SIZE); - const avg = recent.reduce((sum, n) => sum+n, 0) / recent.size; + const avg = recent.reduce((sum, n) => sum + n, 0) / recent.size; - return Math.max(MIN_RADIUS, Math.min(avg, defaultClippingRadius)) + return Math.max(MIN_RADIUS, Math.min(avg, defaultClippingRadius)); }); // Convert List to number[] @@ -35,7 +38,7 @@ export async function addOptimalNoise( weightUpdates: WeightsContainer, epsilon: number, delta: number, - clippingRadius: number[], + clippingRadius: number[], ): Promise { /** * In the original paper, the sensitivity is given as 2 * clippingRadius / d, though the meaning of d is unclear. @@ -43,13 +46,15 @@ export async function addOptimalNoise( */ // apply different sensitivity and noise to each of the layer // clippingRadius is now number[] - const sens = clippingRadius.map((r)=>(2*r)); - const sigmas = sens.map((s)=>(s * Math.sqrt(2*Math.log(1.25/delta))/epsilon)); + const sens = clippingRadius.map((r) => 2 * r); + const sigmas = sens.map( + (s) => (s * Math.sqrt(2 * Math.log(1.25 / delta))) / epsilon, + ); const clippedWeights = await clipNorm(weightUpdates, clippingRadius); return clippedWeights.map((w, i) => - w.add(tf.randomNormal(w.shape, 0, sigmas[i])) - ) + w.add(tf.randomNormal(w.shape, 0, sigmas[i])), + ); } /** @@ -65,7 +70,9 @@ export async function clipNorm( */ const layers = weights.weights; if (radius.length !== layers.length) - throw new Error(`radius length mismatch: got ${radius.length}, expected ${layers.length}`); + throw new Error( + `radius length mismatch: got ${radius.length}, expected ${layers.length}`, + ); /** Apply different clipping radius to each layer in the WeightsContainer */ const clipped = await Promise.all( @@ -75,10 +82,10 @@ export async function clipNorm( // Check the invalid radius value if (!Number.isFinite(r) || r <= 0) - throw new Error("Invalid radius value") + throw new Error("Invalid radius value"); const scaling = Math.max(1, norm / r); return l.div(scaling); - }) + }), ); return new WeightsContainer(clipped); diff --git a/discojs/src/processing/index.ts b/discojs/src/processing/index.ts index 4011f40d1..8d6c2110f 100644 --- a/discojs/src/processing/index.ts +++ b/discojs/src/processing/index.ts @@ -3,12 +3,12 @@ import { List } from "immutable"; import type { - Dataset, - DataFormat, - DataType, - Tabular, - Task, - Network, + Dataset, + DataFormat, + DataType, + Tabular, + Task, + Network, } from "../index.js"; import * as processing from "./index.js"; @@ -103,34 +103,34 @@ export function preprocessWithoutLabel( } export function postprocess( - task: Task, - encoded: DataFormat.ModelEncoded[D][1], + task: Task, + encoded: DataFormat.ModelEncoded[D][1], ): DataFormat.Inferred[D] { - switch (task.dataType) { - case "image": { - // cast as typescript doesn't reduce generic type - const index = encoded as DataFormat.ModelEncoded["image"][1]; + switch (task.dataType) { + case "image": { + // cast as typescript doesn't reduce generic type + const index = encoded as DataFormat.ModelEncoded["image"][1]; const labels = List(task.trainingInformation.LABEL_LIST); - const v = labels.get(index); - if (v === undefined) throw new Error("index not found in labels"); - return v as DataFormat.Inferred[D]; - } - case "tabular": { - // cast as typescript doesn't reduce generic type - const v = encoded as DataFormat.ModelEncoded["tabular"][1]; - - return v as DataFormat.Inferred[D]; - } - case "text": { - // cast as typescript doesn't reduce generic type - const token = encoded as DataFormat.ModelEncoded["text"][1]; - - return task.trainingInformation.tokenizer.decode([ - token, - ]) as DataFormat.Inferred[D]; - } - } + const v = labels.get(index); + if (v === undefined) throw new Error("index not found in labels"); + return v as DataFormat.Inferred[D]; + } + case "tabular": { + // cast as typescript doesn't reduce generic type + const v = encoded as DataFormat.ModelEncoded["tabular"][1]; + + return v as DataFormat.Inferred[D]; + } + case "text": { + // cast as typescript doesn't reduce generic type + const token = encoded as DataFormat.ModelEncoded["text"][1]; + + return task.trainingInformation.tokenizer.decode([ + token, + ]) as DataFormat.Inferred[D]; + } + } } function extractToNumbers(columns: Iterable, row: Tabular) { diff --git a/discojs/src/serialization/coder.ts b/discojs/src/serialization/coder.ts index 1aa32bec6..4d78cdb11 100644 --- a/discojs/src/serialization/coder.ts +++ b/discojs/src/serialization/coder.ts @@ -29,9 +29,9 @@ CODEC.register({ type: 0x12, encode(obj: unknown): Uint8Array | null { if (!(obj instanceof Uint8Array)) return null; - return obj + return obj; }, - decode: (raw: Uint8Array): Uint8Array => raw + decode: (raw: Uint8Array): Uint8Array => raw, }); // used by TFJS's weights CODEC.register({ @@ -54,7 +54,7 @@ CODEC.register({ }, decode: (raw: Uint8Array): ArrayBuffer => // need to copy as backing ArrayBuffer might be larger - copy(raw).buffer + copy(raw).buffer, }); type Encodable = diff --git a/discojs/src/serialization/index.ts b/discojs/src/serialization/index.ts index 9411935ee..1dccea6e6 100644 --- a/discojs/src/serialization/index.ts +++ b/discojs/src/serialization/index.ts @@ -1,15 +1,15 @@ -export * as model from './model.js' +export * as model from "./model.js"; export * as task from "./task.js"; -export * as weights from './weights.js' +export * as weights from "./weights.js"; export type { Encoded } from "./coder.js"; export { isEncoded } from "./coder.js"; export type JSON = - | null - | undefined - | boolean - | number - | string - | JSON[] - | { [_: string]: JSON }; + | null + | undefined + | boolean + | number + | string + | JSON[] + | { [_: string]: JSON }; diff --git a/discojs/src/serialization/model.spec.ts b/discojs/src/serialization/model.spec.ts index b08933021..6cf0e3bf8 100644 --- a/discojs/src/serialization/model.spec.ts +++ b/discojs/src/serialization/model.spec.ts @@ -1,38 +1,40 @@ -import * as tf from '@tensorflow/tfjs' +import * as tf from "@tensorflow/tfjs"; import { assert, describe, expect, it } from "vitest"; import type { DataType, Model } from "../index.js"; -import { models, serialization } from '../index.js' -import type { GPTConfig } from '../models/index.js' +import { models, serialization } from "../index.js"; +import type { GPTConfig } from "../models/index.js"; async function getRawWeights( model: Model, ): Promise<[number, Float32Array][]> { return Array.from( - (await Promise.all( - model.weights.weights.map(async (w) => await w.data<'float32'>())) - ).entries() - ) + ( + await Promise.all( + model.weights.weights.map(async (w) => await w.data<"float32">()), + ) + ).entries(), + ); } -describe('serialization', () => { - it('can encode & decode a TFJS model', async () => { +describe("serialization", () => { + it("can encode & decode a TFJS model", async () => { const rawModel = tf.sequential({ layers: [ tf.layers.conv2d({ inputShape: [32, 32, 3], kernelSize: 3, filters: 16, - activation: 'relu' - }) - ] - }) - rawModel.compile({ optimizer: 'sgd', loss: 'hinge' }) - const model = new models.TFJS("image", rawModel) + activation: "relu", + }), + ], + }); + rawModel.compile({ optimizer: "sgd", loss: "hinge" }); + const model = new models.TFJS("image", rawModel); - const encoded = await serialization.model.encode(model) - assert.isTrue(serialization.isEncoded(encoded)) - const decoded = await serialization.model.decode(encoded) + const encoded = await serialization.model.encode(model); + assert.isTrue(serialization.isEncoded(encoded)); + const decoded = await serialization.model.decode(encoded); expect(decoded).to.be.an.instanceof(models.TFJS); expect((decoded as models.TFJS<"image" | "tabular">).datatype).to.equal( @@ -40,31 +42,31 @@ describe('serialization', () => { ); assert.sameDeepOrderedMembers( await getRawWeights(model), - await getRawWeights(decoded) - ) - }) + await getRawWeights(decoded), + ); + }); it("can encode & decode a gpt-tfjs model", { timeout: 20_000 }, async () => { const config: GPTConfig = { - modelType: 'gpt-nano', + modelType: "gpt-nano", lr: 0.01, maxIter: 10, - evaluateEvery:10, + evaluateEvery: 10, maxEvalBatches: 10, contextLength: 8, - } - const model = new models.GPT(config) + }; + const model = new models.GPT(config); + + const encoded = await serialization.model.encode(model); + assert.isTrue(serialization.isEncoded(encoded)); + const decoded = await serialization.model.decode(encoded); - const encoded = await serialization.model.encode(model) - assert.isTrue(serialization.isEncoded(encoded)) - const decoded = await serialization.model.decode(encoded) - - assert.instanceOf(decoded, models.GPT) + assert.instanceOf(decoded, models.GPT); assert.sameDeepOrderedMembers( await getRawWeights(model), - await getRawWeights(decoded) - ) + await getRawWeights(decoded), + ); assert.deepEqual(model.config, decoded.config); - }) -}) + }); +}); diff --git a/discojs/src/serialization/model.ts b/discojs/src/serialization/model.ts index 020d147af..5552bfca1 100644 --- a/discojs/src/serialization/model.ts +++ b/discojs/src/serialization/model.ts @@ -1,16 +1,16 @@ -import type tf from '@tensorflow/tfjs' +import type tf from "@tensorflow/tfjs"; -import type { DataType, Model } from '../index.js' -import { models, serialization } from '../index.js' -import { GPTConfig } from '../models/index.js' +import type { DataType, Model } from "../index.js"; +import { models, serialization } from "../index.js"; +import { GPTConfig } from "../models/index.js"; import * as coder from "./coder.js"; import { Encoded, isEncoded } from "./coder.js"; const Type = { TFJS: 0, - GPT: 1 -} as const + GPT: 1, +} as const; export async function encode(model: Model): Promise { switch (true) { @@ -29,16 +29,20 @@ export async function encode(model: Model): Promise { } export async function decode(encoded: Encoded): Promise> { - const raw = coder.decode(encoded) + const raw = coder.decode(encoded); if (!Array.isArray(raw) || raw.length < 2) { - throw new Error("invalid encoding, encoding isn't an array or doesn't contain enough values") + throw new Error( + "invalid encoding, encoding isn't an array or doesn't contain enough values", + ); } - const type = raw[0] as unknown - if (typeof type !== 'number') { - throw new Error('invalid encoding, first encoding field should be the model type') + const type = raw[0] as unknown; + if (typeof type !== "number") { + throw new Error( + "invalid encoding, first encoding field should be the model type", + ); } - const rawModel = raw[1] as unknown + const rawModel = raw[1] as unknown; switch (type) { case Type.TFJS: { if (raw.length !== 3) @@ -54,9 +58,7 @@ export async function decode(encoded: Encoded): Promise> { datatype = rawDatatype; break; default: - throw new Error( - "invalid TFJS model encoding: invalid DataType", - ); + throw new Error("invalid TFJS model encoding: invalid DataType"); } return await models.TFJS.deserialize([ @@ -65,24 +67,26 @@ export async function decode(encoded: Encoded): Promise> { rawModel as tf.io.ModelArtifacts, ]); } - case Type.GPT: { - let config + case Type.GPT: { + let config; if (raw.length == 2) { - config = undefined + config = undefined; } else if (raw.length == 3) { - config = raw[2] as GPTConfig + config = raw[2] as GPTConfig; } else { - throw new Error('invalid encoding, gpt-tfjs model encoding should be an array of length 2 or 3') + throw new Error( + "invalid encoding, gpt-tfjs model encoding should be an array of length 2 or 3", + ); } if (!isEncoded(rawModel)) throw new Error( "invalid encoding, gpt-tfjs model weights should be an encoding of its weights", ); - const weights = serialization.weights.decode(rawModel) - return models.GPT.deserialize({weights, config}) + const weights = serialization.weights.decode(rawModel); + return models.GPT.deserialize({ weights, config }); } default: - throw new Error('invalid encoding, model type unrecognized') + throw new Error("invalid encoding, model type unrecognized"); } } diff --git a/discojs/src/serialization/task.spec.ts b/discojs/src/serialization/task.spec.ts index 2c42452c5..2c3061f0a 100644 --- a/discojs/src/serialization/task.spec.ts +++ b/discojs/src/serialization/task.spec.ts @@ -3,10 +3,10 @@ import { expect, it } from "vitest"; import { serialization, defaultTasks } from "../index.js"; it("can encode what it decodes", async () => { - const task = await defaultTasks.wikitext.getTask(); + const task = await defaultTasks.wikitext.getTask(); - const serialized = serialization.task.serializeToJSON(task); - const deserialized = await serialization.task.deserializeFromJSON(serialized); + const serialized = serialization.task.serializeToJSON(task); + const deserialized = await serialization.task.deserializeFromJSON(serialized); - expect(deserialized).to.be.deep.equal(task); + expect(deserialized).to.be.deep.equal(task); }); diff --git a/discojs/src/serialization/task.ts b/discojs/src/serialization/task.ts index 28c414246..86c5c6ba1 100644 --- a/discojs/src/serialization/task.ts +++ b/discojs/src/serialization/task.ts @@ -5,37 +5,37 @@ import { Task, Tokenizer } from "../index.js"; import type { JSON } from "./index.js"; export function serializeToJSON(task: Task): JSON { - switch (task.dataType) { - case "image": - case "tabular": - return task; - case "text": { - return { - ...task, - trainingInformation: { - ...task.trainingInformation, - tokenizer: task.trainingInformation.tokenizer.name, - }, - }; - } - } + switch (task.dataType) { + case "image": + case "tabular": + return task; + case "text": { + return { + ...task, + trainingInformation: { + ...task.trainingInformation, + tokenizer: task.trainingInformation.tokenizer.name, + }, + }; + } + } } export async function deserializeFromJSON( - serialized: JSON, + serialized: JSON, ): Promise> { - return await z - .object({ - trainingInformation: z - .object({ - tokenizer: z - .string() - .transform((name) => Tokenizer.from_pretrained(name)) - .optional(), - }) - .passthrough(), - }) - .passthrough() - .pipe(Task.schema) - .parseAsync(serialized); + return await z + .object({ + trainingInformation: z + .object({ + tokenizer: z + .string() + .transform((name) => Tokenizer.from_pretrained(name)) + .optional(), + }) + .passthrough(), + }) + .passthrough() + .pipe(Task.schema) + .parseAsync(serialized); } diff --git a/discojs/src/serialization/weights.spec.ts b/discojs/src/serialization/weights.spec.ts index 4fac46562..b13aea6ad 100644 --- a/discojs/src/serialization/weights.spec.ts +++ b/discojs/src/serialization/weights.spec.ts @@ -1,26 +1,30 @@ import { assert, describe, it } from "vitest"; -import { WeightsContainer, serialization } from '../index.js' +import { WeightsContainer, serialization } from "../index.js"; -describe('weights', () => { - it('can encode what it decodes', async () => { - const weights = WeightsContainer.of([1], [2], [3]) +describe("weights", () => { + it("can encode what it decodes", async () => { + const weights = WeightsContainer.of([1], [2], [3]); - const encoded = await serialization.weights.encode(weights) - assert.isTrue(serialization.isEncoded(encoded)) - const decoded = serialization.weights.decode(encoded) + const encoded = await serialization.weights.encode(weights); + assert.isTrue(serialization.isEncoded(encoded)); + const decoded = serialization.weights.decode(encoded); assert.sameDeepOrderedMembers( Array.from( - (await Promise.all( - decoded.weights.map(async (w) => await w.data<'float32'>())) - ).entries() + ( + await Promise.all( + decoded.weights.map(async (w) => await w.data<"float32">()), + ) + ).entries(), ), Array.from( - (await Promise.all( - weights.weights.map(async (w) => await w.data<'float32'>())) - ).entries() - ) - ) - }) -}) + ( + await Promise.all( + weights.weights.map(async (w) => await w.data<"float32">()), + ) + ).entries(), + ), + ); + }); +}); diff --git a/discojs/src/task/display_information.ts b/discojs/src/task/display_information.ts index c82f9c595..b03b91da2 100644 --- a/discojs/src/task/display_information.ts +++ b/discojs/src/task/display_information.ts @@ -3,39 +3,39 @@ import { z } from "zod"; import type { DataType } from "../types/index.js"; export namespace DisplayInformation { - export const baseSchema = z.object({ - title: z.string(), - summary: z.object({ - preview: z.string(), - overview: z.string(), - }), - dataFormatInformation: z.string().optional(), - model: z.string().optional(), - sampleDataset: z - .object({ - // URL to download a dataset for the task, is displayed in the UI when asking to connect data - link: z.string(), - // Instructions to download, unzip, and connect the right file of the sample dataset - instructions: z.string(), - }) - .optional(), - }); + export const baseSchema = z.object({ + title: z.string(), + summary: z.object({ + preview: z.string(), + overview: z.string(), + }), + dataFormatInformation: z.string().optional(), + model: z.string().optional(), + sampleDataset: z + .object({ + // URL to download a dataset for the task, is displayed in the UI when asking to connect data + link: z.string(), + // Instructions to download, unzip, and connect the right file of the sample dataset + instructions: z.string(), + }) + .optional(), + }); - export const dataTypeToSchema = { - image: z.object({ - // url to an image - dataExample: z.string().optional(), - }), - tabular: z.object({ - dataExample: z - .array(z.object({ name: z.string(), data: z.string() })) - .optional(), - }), - text: z.object({ - dataExample: z.string().optional(), - }), - } satisfies Record; + export const dataTypeToSchema = { + image: z.object({ + // url to an image + dataExample: z.string().optional(), + }), + tabular: z.object({ + dataExample: z + .array(z.object({ name: z.string(), data: z.string() })) + .optional(), + }), + text: z.object({ + dataExample: z.string().optional(), + }), + } satisfies Record; } export type DisplayInformation = - (typeof DisplayInformation.dataTypeToSchema)[D]; + (typeof DisplayInformation.dataTypeToSchema)[D]; diff --git a/discojs/src/task/task.ts b/discojs/src/task/task.ts index d79d26a28..2903239cc 100644 --- a/discojs/src/task/task.ts +++ b/discojs/src/task/task.ts @@ -6,63 +6,63 @@ import { DisplayInformation } from "./display_information.js"; import { TrainingInformation } from "./training_information.js"; export namespace Task { - export type ID = string; + export type ID = string; - export const baseSchema = z.object({ - id: z.string(), - displayInformation: DisplayInformation.baseSchema, - trainingInformation: TrainingInformation.baseSchema, - }); + export const baseSchema = z.object({ + id: z.string(), + displayInformation: DisplayInformation.baseSchema, + trainingInformation: TrainingInformation.baseSchema, + }); - export const dataTypeToSchema = { - image: z.object({ - dataType: z.literal("image"), - displayInformation: DisplayInformation.dataTypeToSchema.image, - trainingInformation: TrainingInformation.dataTypeToSchema.image, - }), - tabular: z.object({ - dataType: z.literal("tabular"), - displayInformation: DisplayInformation.dataTypeToSchema.tabular, - trainingInformation: TrainingInformation.dataTypeToSchema.tabular, - }), - text: z.object({ - dataType: z.literal("text"), - displayInformation: DisplayInformation.dataTypeToSchema.text, - trainingInformation: TrainingInformation.dataTypeToSchema.text, - }), - } satisfies Record; + export const dataTypeToSchema = { + image: z.object({ + dataType: z.literal("image"), + displayInformation: DisplayInformation.dataTypeToSchema.image, + trainingInformation: TrainingInformation.dataTypeToSchema.image, + }), + tabular: z.object({ + dataType: z.literal("tabular"), + displayInformation: DisplayInformation.dataTypeToSchema.tabular, + trainingInformation: TrainingInformation.dataTypeToSchema.tabular, + }), + text: z.object({ + dataType: z.literal("text"), + displayInformation: DisplayInformation.dataTypeToSchema.text, + trainingInformation: TrainingInformation.dataTypeToSchema.text, + }), + } satisfies Record; - export const networkToSchema = { - decentralized: z.object({ - trainingInformation: TrainingInformation.networkToSchema.decentralized, - }), - federated: z.object({ - trainingInformation: TrainingInformation.networkToSchema.federated, - }), - local: z.object({ - trainingInformation: TrainingInformation.networkToSchema.local, - }), - } satisfies Record; + export const networkToSchema = { + decentralized: z.object({ + trainingInformation: TrainingInformation.networkToSchema.decentralized, + }), + federated: z.object({ + trainingInformation: TrainingInformation.networkToSchema.federated, + }), + local: z.object({ + trainingInformation: TrainingInformation.networkToSchema.local, + }), + } satisfies Record; - export const schema = baseSchema - .and( - z.union([ - dataTypeToSchema.image, - dataTypeToSchema.tabular, - dataTypeToSchema.text, - ]), - ) - .and( - z.union([ - networkToSchema.decentralized, - networkToSchema.federated, - networkToSchema.local, - ]), - ); + export const schema = baseSchema + .and( + z.union([ + dataTypeToSchema.image, + dataTypeToSchema.tabular, + dataTypeToSchema.text, + ]), + ) + .and( + z.union([ + networkToSchema.decentralized, + networkToSchema.federated, + networkToSchema.local, + ]), + ); } export type Task = z.infer< - typeof Task.baseSchema + typeof Task.baseSchema > & - z.infer<(typeof Task.dataTypeToSchema)[D]> & - z.infer<(typeof Task.networkToSchema)[N]>; + z.infer<(typeof Task.dataTypeToSchema)[D]> & + z.infer<(typeof Task.networkToSchema)[N]>; diff --git a/discojs/src/task/task_handler.ts b/discojs/src/task/task_handler.ts index 83f35d051..ac199d126 100644 --- a/discojs/src/task/task_handler.ts +++ b/discojs/src/task/task_handler.ts @@ -6,49 +6,49 @@ import { serialization } from "../index.js"; import type { Task } from "./task.js"; function urlToTasks(base: URL): URL { - const ret = new URL(base); - ret.pathname += "tasks"; - return ret; + const ret = new URL(base); + ret.pathname += "tasks"; + return ret; } export async function pushTask( - base: URL, - task: Task, - model: Model, + base: URL, + task: Task, + model: Model, ): Promise { - const response = await fetch(urlToTasks(base), { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ - task: serialization.task.serializeToJSON(task), - model: [...(await serialization.model.encode(model))], - }), - }); - if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`); + const response = await fetch(urlToTasks(base), { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + task: serialization.task.serializeToJSON(task), + model: [...(await serialization.model.encode(model))], + }), + }); + if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`); } export async function fetchTasks( - base: URL, + base: URL, ): Promise>> { - const response = await fetch(urlToTasks(base)); - if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`); - const json = (await response.json()) as serialization.JSON; + const response = await fetch(urlToTasks(base)); + if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`); + const json = (await response.json()) as serialization.JSON; - if (!Array.isArray(json)) - throw new Error("invalid tasks response: expected a JSON array"); - const arr = json; + if (!Array.isArray(json)) + throw new Error("invalid tasks response: expected a JSON array"); + const arr = json; - try { - return Map( - Seq( - await Promise.all( - arr.map((t) => serialization.task.deserializeFromJSON(t)), - ), - ).map((t) => [t.id, t]), - ); - } catch (cause) { - throw new Error("invalid tasks response: unable to parse all tasks", { - cause, - }); - } + try { + return Map( + Seq( + await Promise.all( + arr.map((t) => serialization.task.deserializeFromJSON(t)), + ), + ).map((t) => [t.id, t]), + ); + } catch (cause) { + throw new Error("invalid tasks response: unable to parse all tasks", { + cause, + }); + } } diff --git a/discojs/src/task/task_provider.ts b/discojs/src/task/task_provider.ts index 7bf97c962..939e68120 100644 --- a/discojs/src/task/task_provider.ts +++ b/discojs/src/task/task_provider.ts @@ -1,7 +1,7 @@ import type { DataType, Model, Network, Task } from "../index.js"; export interface TaskProvider { - getTask(): Promise>; - // Create the corresponding model ready for training (compiled) - getModel(): Promise>; + getTask(): Promise>; + // Create the corresponding model ready for training (compiled) + getModel(): Promise>; } diff --git a/discojs/src/task/training_information.ts b/discojs/src/task/training_information.ts index fa01cdf2e..a8f041ebb 100644 --- a/discojs/src/task/training_information.ts +++ b/discojs/src/task/training_information.ts @@ -4,120 +4,124 @@ import type { DataType, Network } from "../index.js"; import { Tokenizer } from "../index.js"; const privacySchema = z.object({ - // reduce training accuracy and improve privacy. - differentialPrivacy: z - .object({ - // maximum weights difference between each epoch, used for differential privacy - clippingRadius: z.number().positive().default(1), - // privacy budget, used to compute the variance of Gaussian noise - epsilon: z.number().positive(), - // small probability that the privacy guarantee may not hold - delta: z.number().gt(0).lt(1), - }) - .optional(), + // reduce training accuracy and improve privacy. + differentialPrivacy: z + .object({ + // maximum weights difference between each epoch, used for differential privacy + clippingRadius: z.number().positive().default(1), + // privacy budget, used to compute the variance of Gaussian noise + epsilon: z.number().positive(), + // small probability that the privacy guarantee may not hold + delta: z.number().gt(0).lt(1), + }) + .optional(), }); const nonLocalNetworkSchema = z - .object({ - // minimum number of participants required to train collaboratively - // In decentralized Learning the default is 3, in federated learning it is 2 - minNbOfParticipants: z.number().positive().int(), - }) - .and( - z.union([ - z.object({ - aggregationStrategy: z.literal("mean"), - privacy: privacySchema - .transform((o) => (o.differentialPrivacy === undefined ? undefined : o)) - .optional(), - }), - z.object({ - aggregationStrategy: z.literal("byzantine"), - privacy: z.object({ - ...privacySchema.shape, - byzantineFaultTolerance: z.object({ - // maximum weights difference between each round - clippingRadius: z.number().positive(), - maxIterations: z.number().int().positive().default(1), - beta: z.number().min(0).max(1).default(0.9), - }), - }), - }), - z.object({ - aggregationStrategy: z.literal("secure"), - privacy: privacySchema - .transform((o) => (o.differentialPrivacy === undefined ? undefined : o)) - .optional(), - // Secure Aggregation: maximum absolute value of a number in a randomly generated share - // default is 100, must be a positive number, check the docs/PRIVACY.md file for more information on significance of maxShareValue selection - maxShareValue: z.number().positive().int().optional().default(100), - }), - ]), - ); + .object({ + // minimum number of participants required to train collaboratively + // In decentralized Learning the default is 3, in federated learning it is 2 + minNbOfParticipants: z.number().positive().int(), + }) + .and( + z.union([ + z.object({ + aggregationStrategy: z.literal("mean"), + privacy: privacySchema + .transform((o) => + o.differentialPrivacy === undefined ? undefined : o, + ) + .optional(), + }), + z.object({ + aggregationStrategy: z.literal("byzantine"), + privacy: z.object({ + ...privacySchema.shape, + byzantineFaultTolerance: z.object({ + // maximum weights difference between each round + clippingRadius: z.number().positive(), + maxIterations: z.number().int().positive().default(1), + beta: z.number().min(0).max(1).default(0.9), + }), + }), + }), + z.object({ + aggregationStrategy: z.literal("secure"), + privacy: privacySchema + .transform((o) => + o.differentialPrivacy === undefined ? undefined : o, + ) + .optional(), + // Secure Aggregation: maximum absolute value of a number in a randomly generated share + // default is 100, must be a positive number, check the docs/PRIVACY.md file for more information on significance of maxShareValue selection + maxShareValue: z.number().positive().int().optional().default(100), + }), + ]), + ); export namespace TrainingInformation { - export const baseSchema = z.object({ - // number of epochs to run training for - epochs: z.number().positive().int(), - // number of epochs between each weight sharing round. - // e.g.if 3 then weights are shared every 3 epochs (in the distributed setting). - roundDuration: z.number().positive().int(), - // fraction of data to keep for validation, note this only works for image data - validationSplit: z.number().min(0).max(1), - // batch size of training data - batchSize: z.number().positive().int(), - // Tensor framework used by the model - tensorBackend: z.enum(["gpt", "tfjs"]), - }); + export const baseSchema = z.object({ + // number of epochs to run training for + epochs: z.number().positive().int(), + // number of epochs between each weight sharing round. + // e.g.if 3 then weights are shared every 3 epochs (in the distributed setting). + roundDuration: z.number().positive().int(), + // fraction of data to keep for validation, note this only works for image data + validationSplit: z.number().min(0).max(1), + // batch size of training data + batchSize: z.number().positive().int(), + // Tensor framework used by the model + tensorBackend: z.enum(["gpt", "tfjs"]), + }); - export const dataTypeToSchema = { - image: z.object({ - // classes, e.g. if two class of images, one with dogs and one with cats, then we would - // define ['dogs', 'cats']. - LABEL_LIST: z.array(z.string()).min(1), - // height of image to resize to - IMAGE_W: z.number().positive().int(), - // width of image to resize to - IMAGE_H: z.number().positive().int(), - }), - tabular: z.object({ - // the columns to be chosen as input data for the model - inputColumns: z.array(z.string()), - // the columns to be predicted by the model - outputColumn: z.string(), - }), - text: z.object({ - // should be set with the name of a Transformers.js pre-trained tokenizer, e.g., 'Xenova/gpt2'. - tokenizer: z.instanceof(Tokenizer), - // the maximum length of a input string used as input to a GPT model. It is used during preprocessing to - // truncate strings to a maximum length. The default value is tokenizer.model_max_length - contextLength: z.number().positive().int(), - }), - } satisfies Record; + export const dataTypeToSchema = { + image: z.object({ + // classes, e.g. if two class of images, one with dogs and one with cats, then we would + // define ['dogs', 'cats']. + LABEL_LIST: z.array(z.string()).min(1), + // height of image to resize to + IMAGE_W: z.number().positive().int(), + // width of image to resize to + IMAGE_H: z.number().positive().int(), + }), + tabular: z.object({ + // the columns to be chosen as input data for the model + inputColumns: z.array(z.string()), + // the columns to be predicted by the model + outputColumn: z.string(), + }), + text: z.object({ + // should be set with the name of a Transformers.js pre-trained tokenizer, e.g., 'Xenova/gpt2'. + tokenizer: z.instanceof(Tokenizer), + // the maximum length of a input string used as input to a GPT model. It is used during preprocessing to + // truncate strings to a maximum length. The default value is tokenizer.model_max_length + contextLength: z.number().positive().int(), + }), + } satisfies Record; - export const networkToSchema = { - decentralized: z - .object({ - scheme: z.literal("decentralized"), - aggregationStrategy: z.literal(["byzantine", "mean", "secure"]), - }) - .and(nonLocalNetworkSchema), - federated: z - .object({ - scheme: z.literal("federated"), - aggregationStrategy: z.literal(["byzantine", "mean"]), - }) - .and(nonLocalNetworkSchema), - local: z.object({ - scheme: z.literal("local"), - aggregationStrategy: z.literal("mean"), - }), - } satisfies Record; + export const networkToSchema = { + decentralized: z + .object({ + scheme: z.literal("decentralized"), + aggregationStrategy: z.literal(["byzantine", "mean", "secure"]), + }) + .and(nonLocalNetworkSchema), + federated: z + .object({ + scheme: z.literal("federated"), + aggregationStrategy: z.literal(["byzantine", "mean"]), + }) + .and(nonLocalNetworkSchema), + local: z.object({ + scheme: z.literal("local"), + aggregationStrategy: z.literal("mean"), + }), + } satisfies Record; } export type TrainingInformation< - D extends DataType, - N extends Network, + D extends DataType, + N extends Network, > = z.infer & - z.infer<(typeof TrainingInformation.dataTypeToSchema)[D]> & - z.infer<(typeof TrainingInformation.networkToSchema)[N]>; + z.infer<(typeof TrainingInformation.dataTypeToSchema)[D]> & + z.infer<(typeof TrainingInformation.networkToSchema)[N]>; diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index 7dd019b2d..ae6a1dc64 100644 --- a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -38,36 +38,42 @@ interface DiscoConfig { } export type SummaryLogs = { - round: number, - epoch: number, - trainingLoss: number, - trainingAccuracy: number, - peakMemory: number, - epochTime: number, - roundValidationLoss?: number, - roundValidationAccuracy?: number, - validationLoss?: number, - validationAccuracy?: number -} - -export type RoundStatus = 'not enough participants' | // Server notification to wait for more participants - 'updating model' | // fetching/aggregating local updates into a global model - 'local training' | // Training the model locally - 'connecting to peers' // for decentralized only, fetch the server's list of participating peers - -function buildSummaryLog(roundNum: number, epochNum: number, roundLogs: RoundLogs, epochLogs: EpochLogs): SummaryLogs { + round: number; + epoch: number; + trainingLoss: number; + trainingAccuracy: number; + peakMemory: number; + epochTime: number; + roundValidationLoss?: number; + roundValidationAccuracy?: number; + validationLoss?: number; + validationAccuracy?: number; +}; + +export type RoundStatus = + | "not enough participants" // Server notification to wait for more participants + | "updating model" // fetching/aggregating local updates into a global model + | "local training" // Training the model locally + | "connecting to peers"; // for decentralized only, fetch the server's list of participating peers + +function buildSummaryLog( + roundNum: number, + epochNum: number, + roundLogs: RoundLogs, + epochLogs: EpochLogs, +): SummaryLogs { return { - round: roundNum, - epoch: epochNum, - trainingLoss: epochLogs.training.loss, - trainingAccuracy: epochLogs.training.accuracy, - peakMemory: epochLogs.peakMemory, - epochTime: epochLogs.epochTime, - roundValidationLoss: roundLogs.preRoundValidation?.loss, - roundValidationAccuracy: roundLogs.preRoundValidation?.accuracy, - validationLoss: epochLogs.validation?.loss, - validationAccuracy: epochLogs.validation?.accuracy, - } + round: roundNum, + epoch: epochNum, + trainingLoss: epochLogs.training.loss, + trainingAccuracy: epochLogs.training.accuracy, + peakMemory: epochLogs.peakMemory, + epochTime: epochLogs.epochTime, + roundValidationLoss: roundLogs.preRoundValidation?.loss, + roundValidationAccuracy: roundLogs.preRoundValidation?.accuracy, + validationLoss: epochLogs.validation?.loss, + validationAccuracy: epochLogs.validation?.accuracy, + }; } /** @@ -77,7 +83,7 @@ function buildSummaryLog(roundNum: number, epochNum: number, roundLogs: RoundLog */ export class Disco extends EventEmitter<{ status: RoundStatus; - participants: number + participants: number; }> { public readonly trainer: Trainer; readonly #client: clients.Client; @@ -94,7 +100,10 @@ export class Disco extends EventEmitter<{ */ constructor( task: Task, - clientConfig: clients.Client | URL | { aggregator: Aggregator; url: URL }, + clientConfig: + | clients.Client + | URL + | { aggregator: Aggregator; url: URL }, config: Partial>, ) { super(); @@ -129,7 +138,9 @@ export class Disco extends EventEmitter<{ this.trainer = new Trainer(task, client); // Simply propagate the training status events emitted by the client this.#client.on("status", (status) => this.emit("status", status)); - this.#client.on("participants", (nbParticipants) => this.emit("participants", nbParticipants)); + this.#client.on("participants", (nbParticipants) => + this.emit("participants", nbParticipants), + ); } /** Train on dataset, yielding logs of every round. */ @@ -164,14 +175,15 @@ export class Disco extends EventEmitter<{ for await (const epoch of round) yield* epoch; } - /** Train on dataset, yielding summary logs */ + /** Train on dataset, yielding summary logs */ async *trainSummary( dataset: Dataset, ): AsyncGenerator { for await (const [roundNum, round] of enumerate(this.train(dataset))) { const [roundGen, roundLogsPromise] = async_iterator.split(round); - const epochResults: Array<{epochNum: number; epochLogs: EpochLogs}> = []; + const epochResults: Array<{ epochNum: number; epochLogs: EpochLogs }> = + []; for await (const [epochNum, epoch] of enumerate(roundGen)) { const [epochGen, epochLogsPromise] = async_iterator.split(epoch); @@ -183,7 +195,7 @@ export class Disco extends EventEmitter<{ const roundLogs = await roundLogsPromise; - for (const {epochNum, epochLogs} of epochResults) { + for (const { epochNum, epochLogs } of epochResults) { yield buildSummaryLog(roundNum, epochNum, roundLogs, epochLogs); } } @@ -220,7 +232,8 @@ export class Disco extends EventEmitter<{ )) { yield async function* (this: Disco) { const [roundGen, roundLogsPromise] = split(round); - const epochResults: Array<{epochNum: number; epochLogs: EpochLogs}> = []; + const epochResults: Array<{ epochNum: number; epochLogs: EpochLogs }> = + []; for await (const [epochNum, epoch] of enumerate(roundGen)) { const [epochGen, epochLogsPromise] = split(epoch); @@ -240,7 +253,7 @@ export class Disco extends EventEmitter<{ ].join("\n"), ); - for (const {epochNum, epochLogs} of epochResults){ + for (const { epochNum, epochLogs } of epochResults) { this.#logger.success( [ `Round: ${roundNum}`, @@ -283,13 +296,12 @@ export class Disco extends EventEmitter<{ let preprocessed = processing.preprocess(this.#task, dataset); - preprocessed = ( - this.#preprocessOnce - ? new Dataset(await arrayFromAsync(preprocessed)) - : preprocessed - ) - if (validationSplit === 0) return [preprocessed.batch(batchSize).cached(), undefined]; - + preprocessed = this.#preprocessOnce + ? new Dataset(await arrayFromAsync(preprocessed)) + : preprocessed; + if (validationSplit === 0) + return [preprocessed.batch(batchSize).cached(), undefined]; + const [training, validation] = preprocessed.split(validationSplit); return [ diff --git a/discojs/src/training/index.ts b/discojs/src/training/index.ts index adce5af68..1c10a3fef 100644 --- a/discojs/src/training/index.ts +++ b/discojs/src/training/index.ts @@ -1,2 +1,2 @@ -export { Disco, RoundStatus, SummaryLogs } from './disco.js' -export { RoundLogs, Trainer } from './trainer.js' +export { Disco, RoundStatus, SummaryLogs } from "./disco.js"; +export { RoundLogs, Trainer } from "./trainer.js"; diff --git a/discojs/src/training/trainer.ts b/discojs/src/training/trainer.ts index 68e716bcc..44c8f8cd3 100644 --- a/discojs/src/training/trainer.ts +++ b/discojs/src/training/trainer.ts @@ -27,7 +27,10 @@ export interface RoundLogs { /** List of weight update norms */ export type WeightNormHistory = List>; -function appendWeightHistory(weightNormHistory: WeightNormHistory, wc: number[]){ +function appendWeightHistory( + weightNormHistory: WeightNormHistory, + wc: number[], +) { return wc.reduce((hist, t, i) => { const arr = hist.get(i, List()); return hist.set(i, arr.push(t)); @@ -40,18 +43,18 @@ export class Trainer { readonly #roundDuration: number; readonly #epochs: number; readonly #privacy: - | Task< - DataType, - "decentralized" | "federated" - >["trainingInformation"]["privacy"] - | undefined; + | Task< + DataType, + "decentralized" | "federated" + >["trainingInformation"]["privacy"] + | undefined; #model: Model | undefined; #training?: AsyncGenerator< AsyncGenerator, RoundLogs>, void >; // Map of weight Index and weight update - #weightNormHistory : WeightNormHistory = List(); + #weightNormHistory: WeightNormHistory = List(); #previousRoundWeights?: WeightsContainer; public get model(): Model { @@ -68,8 +71,8 @@ export class Trainer { this.#client = client; this.#roundDuration = task.trainingInformation.roundDuration; this.#epochs = task.trainingInformation.epochs; - if ("privacy" in task.trainingInformation) - this.#privacy = task.trainingInformation.privacy; + if ("privacy" in task.trainingInformation) + this.#privacy = task.trainingInformation.privacy; if (!Number.isInteger(this.#epochs / this.#roundDuration)) throw new Error( @@ -110,35 +113,40 @@ export class Trainer { > { const totalRound = Math.trunc(this.#epochs / this.#roundDuration); for (let round = 0; round < totalRound; round++) { - await this.#client.onRoundBeginCommunication(); // Store the clean weight before starting the communication - this.#previousRoundWeights = new WeightsContainer(this.model.weights.weights.map(t => t.clone())); + this.#previousRoundWeights = new WeightsContainer( + this.model.weights.weights.map((t) => t.clone()), + ); yield this.#runRound(dataset, validationDataset); let roundWeights = this.model.weights; // Apply differential privacy before sharing the weight updates with other nodes - if (this.#privacy !== undefined){ + if (this.#privacy !== undefined) { const roundUpdate = roundWeights.sub(this.#previousRoundWeights); const updateNorm = await Promise.all( - roundUpdate.weights.map(privacy.frobeniusNorm) + roundUpdate.weights.map(privacy.frobeniusNorm), ); - this.#weightNormHistory = appendWeightHistory(this.#weightNormHistory, updateNorm); - + this.#weightNormHistory = appendWeightHistory( + this.#weightNormHistory, + updateNorm, + ); + roundWeights = await applyOptimalPrivacy( this.#previousRoundWeights, roundWeights, this.#privacy, this.#weightNormHistory, totalRound, - ) + ); } // Get the updated weights - const networkWeights = await this.#client.onRoundEndCommunication(roundWeights); - + const networkWeights = + await this.#client.onRoundEndCommunication(roundWeights); + // Update the local weights this.model.weights = networkWeights; } @@ -151,7 +159,10 @@ export class Trainer { let epochsLogs = List(); // Before starting the training, get the validation of global model - const validation = validationDataset !== undefined ? await this.model.evaluate(validationDataset) : undefined; + const validation = + validationDataset !== undefined + ? await this.model.evaluate(validationDataset) + : undefined; for (let epoch = 0; epoch < this.#roundDuration; epoch++) { const [gen, epochLogs] = async_iterator.split( @@ -161,7 +172,7 @@ export class Trainer { yield gen; epochsLogs = epochsLogs.push(await epochLogs); } - + return { epochs: epochsLogs, participants: this.#client.nbOfParticipants, @@ -172,70 +183,70 @@ export class Trainer { /** ALDP-FL implementation */ async function applyOptimalPrivacy( - previous: WeightsContainer | undefined, - current: WeightsContainer, - options: Exclude< - Task< - DataType, - "decentralized" | "federated" - >["trainingInformation"]["privacy"], - undefined - >, - weightNormHistory: WeightNormHistory, - totalRound: number, + previous: WeightsContainer | undefined, + current: WeightsContainer, + options: Exclude< + Task< + DataType, + "decentralized" | "federated" + >["trainingInformation"]["privacy"], + undefined + >, + weightNormHistory: WeightNormHistory, + totalRound: number, ): Promise { - let ret = current; - - // Clipping radius for BFT - if ("byzantineFaultTolerance" in options) { - // might need to change the variable name - const previousRoundWeights = - previous ?? current.map((w) => tf.zerosLike(w)); - const weightsProgress = current.sub(previousRoundWeights); - ret = previousRoundWeights.add( - await privacy.clipNorm( - weightsProgress, - Repeat(options.byzantineFaultTolerance.clippingRadius) - .take(weightsProgress.weights.length) - .toArray(), - ), - ); - } - - // Adding Gaussian noise for DP - const dpOptions = options.differentialPrivacy; - if (dpOptions !== undefined) { - const dpDefaultRadius = dpOptions.clippingRadius; // options.dpDefaultClippingRadius should be a number - - // Divide privacy budget across all rounds (conservative composition) - const delta = dpOptions.delta / totalRound; - const epsilon = dpOptions.epsilon / totalRound; - - const dpClippingRadius = privacy.getClippingRadius( - weightNormHistory, - dpDefaultRadius, - ); - - const previousEpochWeights = - previous ?? current.map((w) => tf.zerosLike(w)); - const weightsProgress = current.sub(previousEpochWeights); - - /** Need to use tighter clipping radius for noise calibration */ - const effectiveRadius = - "byzantineFaultTolerance" in options - ? dpClippingRadius.map((r) => - Math.min(r, options.byzantineFaultTolerance.clippingRadius), - ) - : dpClippingRadius; - - ret = previousEpochWeights.add( - await privacy.addOptimalNoise( - weightsProgress, - epsilon, - delta, - effectiveRadius, - ), - ); - } - return ret; + let ret = current; + + // Clipping radius for BFT + if ("byzantineFaultTolerance" in options) { + // might need to change the variable name + const previousRoundWeights = + previous ?? current.map((w) => tf.zerosLike(w)); + const weightsProgress = current.sub(previousRoundWeights); + ret = previousRoundWeights.add( + await privacy.clipNorm( + weightsProgress, + Repeat(options.byzantineFaultTolerance.clippingRadius) + .take(weightsProgress.weights.length) + .toArray(), + ), + ); + } + + // Adding Gaussian noise for DP + const dpOptions = options.differentialPrivacy; + if (dpOptions !== undefined) { + const dpDefaultRadius = dpOptions.clippingRadius; // options.dpDefaultClippingRadius should be a number + + // Divide privacy budget across all rounds (conservative composition) + const delta = dpOptions.delta / totalRound; + const epsilon = dpOptions.epsilon / totalRound; + + const dpClippingRadius = privacy.getClippingRadius( + weightNormHistory, + dpDefaultRadius, + ); + + const previousEpochWeights = + previous ?? current.map((w) => tf.zerosLike(w)); + const weightsProgress = current.sub(previousEpochWeights); + + /** Need to use tighter clipping radius for noise calibration */ + const effectiveRadius = + "byzantineFaultTolerance" in options + ? dpClippingRadius.map((r) => + Math.min(r, options.byzantineFaultTolerance.clippingRadius), + ) + : dpClippingRadius; + + ret = previousEpochWeights.add( + await privacy.addOptimalNoise( + weightsProgress, + epsilon, + delta, + effectiveRadius, + ), + ); + } + return ret; } diff --git a/discojs/src/types/data_format.ts b/discojs/src/types/data_format.ts index 681784edf..7047d4be8 100644 --- a/discojs/src/types/data_format.ts +++ b/discojs/src/types/data_format.ts @@ -1,6 +1,12 @@ import { List } from "immutable"; -import type { Image, processing, Tabular, Text, TokenizedText } from "../index.js"; +import type { + Image, + processing, + Tabular, + Text, + TokenizedText, +} from "../index.js"; /** * The data & label format goes through various stages. diff --git a/discojs/src/utils/async_iterator.spec.ts b/discojs/src/utils/async_iterator.spec.ts index 22ce89787..5b7dae9a0 100644 --- a/discojs/src/utils/async_iterator.spec.ts +++ b/discojs/src/utils/async_iterator.spec.ts @@ -39,9 +39,8 @@ describe("split", () => { it("throws returned when iterator throws", async () => { const [gen, ret] = split( - ( // eslint-disable-next-line @typescript-eslint/require-await - async function* () { + (async function* () { throw new Error(); })(), ); diff --git a/discojs/src/utils/event_emitter.ts b/discojs/src/utils/event_emitter.ts index f6ea3c26d..7053c0e8e 100644 --- a/discojs/src/utils/event_emitter.ts +++ b/discojs/src/utils/event_emitter.ts @@ -1,8 +1,8 @@ // inspired by https://danilafe.com/blog/typescript_typesafe_events/ -import { List } from 'immutable' +import { List } from "immutable"; -type Listener = (_: T) => void | Promise +type Listener = (_: T) => void | Promise; /** * Call handlers on given events @@ -13,33 +13,33 @@ export class EventEmitter> { // List of callbacks to run per event #listeners: { [E in keyof I]?: List<[once: boolean, _: Listener]>; - } = {} + } = {}; /** * @param initialListeners object/mapping of event name to listener, as if using `on` on created instance */ - constructor ( + constructor( initialListeners: { [E in keyof I]?: Listener; - } = {} + } = {}, ) { for (const event in initialListeners) { - const listener = initialListeners[event] + const listener = initialListeners[event]; if (listener !== undefined) { - this.on(event, listener) + this.on(event, listener); } } } /** - * Register listener to call on event. + * Register listener to call on event. * * @param event event name to listen to * @param listener handler to call */ on(event: E, listener: Listener): void { - const eventListeners = this.#listeners[event] ?? List() - this.#listeners[event] = eventListeners.push([false, listener]) + const eventListeners = this.#listeners[event] ?? List(); + this.#listeners[event] = eventListeners.push([false, listener]); } /** @@ -49,8 +49,8 @@ export class EventEmitter> { * @param listener handler to call next time */ once(event: E, listener: Listener): void { - const eventListeners = this.#listeners[event] ?? List() - this.#listeners[event] = eventListeners.push([true, listener]) + const eventListeners = this.#listeners[event] ?? List(); + this.#listeners[event] = eventListeners.push([true, listener]); } /** @@ -60,10 +60,12 @@ export class EventEmitter> { * @param value what to call listeners with */ emit(event: E, value: I[E]): void { - const eventListeners = this.#listeners[event] ?? List() - this.#listeners[event] = eventListeners.filterNot(([once]) => once) + const eventListeners = this.#listeners[event] ?? List(); + this.#listeners[event] = eventListeners.filterNot(([once]) => once); - eventListeners.forEach(async ([_, listener]) => { await listener(value) }) + eventListeners.forEach(async ([_, listener]) => { + await listener(value); + }); } } diff --git a/discojs/src/validator.ts b/discojs/src/validator.ts index e5e59fccc..a1bd01c39 100644 --- a/discojs/src/validator.ts +++ b/discojs/src/validator.ts @@ -1,10 +1,10 @@ import type { - Dataset, - DataFormat, - DataType, - Model, - Task, - Network, + Dataset, + DataFormat, + DataType, + Model, + Task, + Network, } from "./index.js"; import { processing } from "./index.js"; @@ -19,25 +19,25 @@ export class Validator { } /** infer every line of the dataset and check that it is as labelled */ - test( - dataset: Dataset, - ): Dataset> { - const preprocessed = processing.preprocess(this.task, dataset); - const batched = preprocessed.batch(this.task.trainingInformation.batchSize); - - const predictionWithTruth = batched - .map(async (batch) => - (await this.#model.predict(batch.map(([inputs, _]) => inputs))).zip( - batch.map(([_, outputs]) => outputs), - ), - ) - .flatten(); - - return predictionWithTruth.map(([predicted, truth]) => ({ - predicted: processing.postprocess(this.task, predicted), - truth: processing.postprocess(this.task, truth), - })); - } + test( + dataset: Dataset, + ): Dataset> { + const preprocessed = processing.preprocess(this.task, dataset); + const batched = preprocessed.batch(this.task.trainingInformation.batchSize); + + const predictionWithTruth = batched + .map(async (batch) => + (await this.#model.predict(batch.map(([inputs, _]) => inputs))).zip( + batch.map(([_, outputs]) => outputs), + ), + ) + .flatten(); + + return predictionWithTruth.map(([predicted, truth]) => ({ + predicted: processing.postprocess(this.task, predicted), + truth: processing.postprocess(this.task, truth), + })); + } /** use the model to predict every line of the dataset */ async *infer( @@ -49,9 +49,9 @@ export class Validator { .map((batch) => this.#model.predict(batch)) .flatten(); - const predictions = modelPredictions.map((prediction) => - processing.postprocess(this.task, prediction), - ); + const predictions = modelPredictions.map((prediction) => + processing.postprocess(this.task, prediction), + ); for await (const e of predictions) yield e; } diff --git a/discojs/src/weights/aggregation.spec.ts b/discojs/src/weights/aggregation.spec.ts index 799703025..b9757be71 100644 --- a/discojs/src/weights/aggregation.spec.ts +++ b/discojs/src/weights/aggregation.spec.ts @@ -1,35 +1,41 @@ import { assert, describe, it } from "vitest"; import { WeightsContainer, aggregation } from "./index.js"; -describe('weights aggregation', () => { - it('avg of weights with two operands', () => { +describe("weights aggregation", () => { + it("avg of weights with two operands", () => { const actual = aggregation.avg([ WeightsContainer.of([1, 2, 3, -1], [-5, 6]), WeightsContainer.of([2, 3, 7, 1], [-10, 5]), - WeightsContainer.of([3, 1, 5, 3], [-15, 19]) - ]) - const expected = WeightsContainer.of([2, 2, 5, 1], [-10, 10]) + WeightsContainer.of([3, 1, 5, 3], [-15, 19]), + ]); + const expected = WeightsContainer.of([2, 2, 5, 1], [-10, 10]); - assert.isTrue(actual.equals(expected)) - }) + assert.isTrue(actual.equals(expected)); + }); - it('sum of weights with two operands', () => { + it("sum of weights with two operands", () => { const actual = aggregation.sum([ [[3, -4], [9]], - [[2, 13], [0]] - ]) - const expected = WeightsContainer.of([5, 9], [9]) + [[2, 13], [0]], + ]); + const expected = WeightsContainer.of([5, 9], [9]); - assert.isTrue(actual.equals(expected)) - }) + assert.isTrue(actual.equals(expected)); + }); - it('diff of weights with two operands', () => { + it("diff of weights with two operands", () => { const actual = aggregation.diff([ - [[3, -4, 5], [9, 1]], - [[2, 13, 4], [0, 1]] - ]) - const expected = WeightsContainer.of([1, -17, 1], [9, 0]) + [ + [3, -4, 5], + [9, 1], + ], + [ + [2, 13, 4], + [0, 1], + ], + ]); + const expected = WeightsContainer.of([1, -17, 1], [9, 0]); - assert.isTrue(actual.equals(expected)) - }) -}) + assert.isTrue(actual.equals(expected)); + }); +}); diff --git a/discojs/src/weights/aggregation.ts b/discojs/src/weights/aggregation.ts index e5b904a83..868d23b8a 100644 --- a/discojs/src/weights/aggregation.ts +++ b/discojs/src/weights/aggregation.ts @@ -1,35 +1,41 @@ -import { List } from 'immutable' -import * as tf from '@tensorflow/tfjs' +import { List } from "immutable"; +import * as tf from "@tensorflow/tfjs"; -import type { TensorLike } from './weights_container.js' -import { WeightsContainer } from './weights_container.js' +import type { TensorLike } from "./weights_container.js"; +import { WeightsContainer } from "./weights_container.js"; -type WeightsLike = Iterable +type WeightsLike = Iterable; -function parseWeights (weights: Iterable): List { +function parseWeights( + weights: Iterable, +): List { const r = List(weights).map((w) => - w instanceof WeightsContainer ? w : new WeightsContainer(w)) - const size = r.first()?.weights.length + w instanceof WeightsContainer ? w : new WeightsContainer(w), + ); + const size = r.first()?.weights.length; if (size === undefined) { - throw new Error('no weights to work with') + throw new Error("no weights to work with"); } r.rest().forEach((w) => { - const actual = w.weights.length + const actual = w.weights.length; if (actual !== size) { - throw new Error(`weights dimensions are different for some of the operands: expected ${size} but found ${actual}`) + throw new Error( + `weights dimensions are different for some of the operands: expected ${size} but found ${actual}`, + ); } - }) + }); - return r + return r; } -function reduce ( +function reduce( weights: Iterable, - fn: (a: tf.Tensor, b: tf.Tensor) => tf.Tensor + fn: (a: tf.Tensor, b: tf.Tensor) => tf.Tensor, ): WeightsContainer { - return parseWeights(weights).reduce((acc: WeightsContainer, ws: WeightsContainer) => - acc.mapWith(ws, fn)) + return parseWeights(weights).reduce( + (acc: WeightsContainer, ws: WeightsContainer) => acc.mapWith(ws, fn), + ); } /** @@ -37,16 +43,20 @@ function reduce ( * @param weights The list of weights to sum * @returns The summed weights */ -export function sum (weights: Iterable): WeightsContainer { - return reduce(weights, tf.add) +export function sum( + weights: Iterable, +): WeightsContainer { + return reduce(weights, tf.add); } /** * Computes the successive entry-wise difference between the weights of the given iterable. * The operation is not commutative w.r.t. the iterable's ordering. */ -export function diff (weights: Iterable): WeightsContainer { - return reduce(weights, tf.sub) +export function diff( + weights: Iterable, +): WeightsContainer { + return reduce(weights, tf.sub); } /** @@ -54,7 +64,9 @@ export function diff (weights: Iterable): Weight * @param weights The list of weights to average * @returns The averaged weights */ -export function avg (weights: Iterable): WeightsContainer { - const ws = List(weights) - return sum(ws).map((w) => w.div(ws.size)) +export function avg( + weights: Iterable, +): WeightsContainer { + const ws = List(weights); + return sum(ws).map((w) => w.div(ws.size)); } diff --git a/discojs/src/weights/index.ts b/discojs/src/weights/index.ts index 84a0df956..b44515510 100644 --- a/discojs/src/weights/index.ts +++ b/discojs/src/weights/index.ts @@ -1,2 +1,2 @@ -export { WeightsContainer } from './weights_container.js' -export * as aggregation from './aggregation.js' +export { WeightsContainer } from "./weights_container.js"; +export * as aggregation from "./aggregation.js"; diff --git a/discojs/src/weights/weights_container.ts b/discojs/src/weights/weights_container.ts index 9d28d2fd5..4a146198e 100644 --- a/discojs/src/weights/weights_container.ts +++ b/discojs/src/weights/weights_container.ts @@ -1,15 +1,15 @@ -import { List } from 'immutable' -import * as tf from '@tensorflow/tfjs' +import { List } from "immutable"; +import * as tf from "@tensorflow/tfjs"; -type Weights = tf.Tensor[] +type Weights = tf.Tensor[]; -export type TensorLike = tf.Tensor | ArrayLike +export type TensorLike = tf.Tensor | ArrayLike; /** * Convenient wrapper object representing an immutable list of TF.js tensors. */ export class WeightsContainer { - private readonly _weights: List + private readonly _weights: List; /** * Constructs a weights container based on the given weights iterable. @@ -18,11 +18,12 @@ export class WeightsContainer { */ constructor(weights: Iterable) { this._weights = List(weights).map((w) => - w instanceof tf.Tensor ? w : tf.tensor(w)) + w instanceof tf.Tensor ? w : tf.tensor(w), + ); } get weights(): Weights { - return this._weights.toArray() + return this._weights.toArray(); } /** @@ -32,7 +33,7 @@ export class WeightsContainer { * @returns A new subtracted weights container */ add(other: WeightsContainer): WeightsContainer { - return this.mapWith(other, tf.add) + return this.mapWith(other, tf.add); } /** @@ -42,7 +43,7 @@ export class WeightsContainer { * @returns A new subtracted weights container */ sub(other: WeightsContainer): WeightsContainer { - return this.mapWith(other, tf.sub) + return this.mapWith(other, tf.sub); } /** @@ -52,10 +53,7 @@ export class WeightsContainer { * @returns A new multiplied weights container */ mul(other: TensorLike | number): WeightsContainer { - return new WeightsContainer( - this._weights - .map(w => w.mul(other)) - ) + return new WeightsContainer(this._weights.map((w) => w.mul(other))); } /** @@ -65,22 +63,29 @@ export class WeightsContainer { * @param fn The binary operator * @returns The mapping's result */ - mapWith(other: WeightsContainer, fn: (a: tf.Tensor, b: tf.Tensor) => tf.Tensor): WeightsContainer { + mapWith( + other: WeightsContainer, + fn: (a: tf.Tensor, b: tf.Tensor) => tf.Tensor, + ): WeightsContainer { return new WeightsContainer( this._weights .zip(other._weights) - .map(([w1, w2]) => fn(w1, w2 as tf.Tensor)) - ) + .map(([w1, w2]) => fn(w1, w2 as tf.Tensor)), + ); } - map(fn: (t: tf.Tensor, i: number) => tf.Tensor): WeightsContainer - map(fn: (t: tf.Tensor) => tf.Tensor): WeightsContainer - map(fn: ((t: tf.Tensor) => tf.Tensor) | ((t: tf.Tensor, i: number) => tf.Tensor)): WeightsContainer { - return new WeightsContainer(this._weights.map(fn)) + map(fn: (t: tf.Tensor, i: number) => tf.Tensor): WeightsContainer; + map(fn: (t: tf.Tensor) => tf.Tensor): WeightsContainer; + map( + fn: + | ((t: tf.Tensor) => tf.Tensor) + | ((t: tf.Tensor, i: number) => tf.Tensor), + ): WeightsContainer { + return new WeightsContainer(this._weights.map(fn)); } reduce(fn: (acc: tf.Tensor, t: tf.Tensor) => tf.Tensor): tf.Tensor { - return this._weights.reduce(fn) + return this._weights.reduce(fn); } /** @@ -89,24 +94,24 @@ export class WeightsContainer { * @returns The tensor located at the index */ get(index: number): tf.Tensor | undefined { - return this._weights.get(index) + return this._weights.get(index); } concat(other: WeightsContainer): WeightsContainer { - return WeightsContainer.of( - ...this.weights, - ...other.weights - ) + return WeightsContainer.of(...this.weights, ...other.weights); } equals(other: WeightsContainer, margin = 0): boolean { return this._weights .zip(other._weights) - .every(([w1, w2]) => w1.sub(w2).abs().lessEqual(margin).all().dataSync()[0] === 1) + .every( + ([w1, w2]) => + w1.sub(w2).abs().lessEqual(margin).all().dataSync()[0] === 1, + ); } - + dispose(): void { - this._weights.forEach(w => w.dispose()); + this._weights.forEach((w) => w.dispose()); } /** @@ -115,6 +120,6 @@ export class WeightsContainer { * @returns The instantiated weights container */ static of(...weights: TensorLike[]): WeightsContainer { - return new this(weights) + return new this(weights); } } diff --git a/discojs/tsconfig.json b/discojs/tsconfig.json index 748edd7e2..3a4772a71 100644 --- a/discojs/tsconfig.json +++ b/discojs/tsconfig.json @@ -11,4 +11,4 @@ "path": "tsconfig.vitest.json" } ] -} \ No newline at end of file +} diff --git a/discojs/tsconfig.lib.json b/discojs/tsconfig.lib.json index 81e28b0ba..8c8a0ef5d 100644 --- a/discojs/tsconfig.lib.json +++ b/discojs/tsconfig.lib.json @@ -1,6 +1,6 @@ { - "extends": "../tsconfig.base.json", - "compilerOptions": { "outDir": "dist" }, - "include": ["src"], - "exclude": ["**/*.spec.ts"] + "extends": "../tsconfig.base.json", + "compilerOptions": { "outDir": "dist" }, + "include": ["src"], + "exclude": ["**/*.spec.ts"] } diff --git a/discojs/tsconfig.vitest.json b/discojs/tsconfig.vitest.json index c4a3913ec..63288c889 100644 --- a/discojs/tsconfig.vitest.json +++ b/discojs/tsconfig.vitest.json @@ -1,5 +1,5 @@ { - "extends": "../tsconfig.base.json", - "compilerOptions": { "noEmit": true }, - "include": ["src"] + "extends": "../tsconfig.base.json", + "compilerOptions": { "noEmit": true }, + "include": ["src"] } diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 4d92159f3..1109fc233 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -105,9 +105,11 @@ npm -w webapp test The webapp tests rely on `cypress` and the test suite is located in the `webapp/cypress` folder. Note that you can also run test interactively in the browser of your choice. To do so, run + ``` VITE_SERVER_URL=http://server npx -w webapp start-server-and-test start http://localhost:8081 'cypress open --e2e' ``` + which should open the Cypress UI and let you choose the browser you wand to use and which tests to run. More information on [the Cypress docs](https://docs.cypress.io/app/get-started/open-the-app). #### Cypress and Github Actions @@ -116,12 +118,16 @@ It is possible to record the cypress tests ran in the Github Actions CI and visu 1. A Disco project has been created in the Cypress Cloud and you need to be added to the project to be able to visualize the recordings. -2. In case a new Cypress project is now being used, make sure that the settings are correct: +2. In case a new Cypress project is now being used, make sure that the settings are correct: + - In `webapp/cypress.config.ts` make sure the correct project ID has been set, It currently is: + ```js - projectId: "aps8et" +projectId: "aps8et"; ``` + - The github workflow `.github/workflows/record-cypress.yml` relies on `CYPRESS_RECORD_KEY` which is a github repository secret. + 3. Finally, you can trigger the `record-cypress` workflow manually from github as described in the [documentation](https://docs.github.com/en/actions/managing-workflow-runs-and-deployments/managing-workflow-runs/manually-running-a-workflow#running-a-workflow) ### Contributing to `discojs` @@ -167,31 +173,36 @@ Similarly to the server, any file ending with `.spec.ts` will be ran in the test Currently, the `discojs-node` project is available as the `@epfml/discojs-node` NPM package, which can be installed with `npm i @epfml/discojs-node` and the `discojs-web` as the `@epfml/discojs-web`. - ### Debugging > [!TIP] > If your code changes don't seem to be effective, close everything, rebuild everything and restart. For example, changes in `discojs/src/default_tasks` requires rebuilding `discojs` and restarting the `server` to be effective. In Disco, we rely on the widely used [`debug` library](https://github.com/debug-js/debug). To use it, we first import debug and instantiate the debug object: + ```js import createDebug from "debug"; const debug = createDebug("discojs:models:gpt:model"); // use nested namespaces -const logs = { loss: 0.01, accuracy: 0.56} -debug("Here are the GPT logs: %o", logs) +const logs = { loss: 0.01, accuracy: 0.56 }; +debug("Here are the GPT logs: %o", logs); ``` #### In the terminal + To visualize the logs in the command line, we need to set the `DEBUG` environment variable to choose the namespaces from which you want to see the debug statements. For example: + ```bash DEBUG='discojs:models:gpt*' npm -w cli run benchmark_gpt ``` + will print the debug statement from above. Similarly if we set `DEBUG='*'`. The server debug statements are visualized the same way, for example: + ```bash DEBUG='server*,discojs*' npm -w server start ``` + shows the debug statements from anywhere in the server and in discojs. #### Webapp @@ -199,9 +210,13 @@ shows the debug statements from anywhere in the server and in discojs. To visualize debug statements in the browser, you need to open the console (Inspect element > Console) and set the `localStorage.debug` to the namespace of your choice, for example `localStorage.debug='webapp*,discojs*'` to visualize both the debug statements from anywhere in the webapp and in discojs. Note that you may need to refresh the page for changes to localStorage to be effective. To get debug statements in the Cypress tests you need to modify `webapp/cypress/support/e2e.ts` and add: + ```js -beforeEach(() => { localStorage.debug = "discojs*,webapp*" }); +beforeEach(() => { + localStorage.debug = "discojs*,webapp*"; +}); ``` + We need to set the `localStorage` before each test because it is reset between each unit tests. ## Contributing conventions diff --git a/docs/FAQ.md b/docs/FAQ.md index 03902a79b..f4887af18 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -44,11 +44,12 @@ You may not be able to open the editor from the repo root level without VSCode r ### npm failing to download or install packages -Firewall settings can sometimes block npm's network requests, preventing it from downloading packages. +Firewall settings can sometimes block npm's network requests, preventing it from downloading packages. **Troubleshooting:** -1. Temporarily disable the firewall. If npm works, the firewall is the cause. -2. Identify the specific firewall rule that is blocking npm and adjust it. + +1. Temporarily disable the firewall. If npm works, the firewall is the cause. +2. Identify the specific firewall rule that is blocking npm and adjust it. 3. Make sure your firewall allows outbound HTTPS (port 443) connections to `registry.npmjs.org`. ### Custom `nvm` setup on EPFL's RCP Cluster @@ -65,35 +66,49 @@ This guide explains how to install and use `nvm` on EPFL's RCP cluster environme --- #### Step 1 — Set install location + ``` export NVM_DIR="/mloscratch/homes/USERNAME/.nvm" ``` + #### Step 2 — Install nvm + ``` curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.7/install.sh | bash ``` + #### Step 3 — Load nvm (test it) + ``` source "$NVM_DIR/nvm.sh" nvm --version ``` + #### Step 4 — Configure zsh (IMPORTANT) + Open zsh config file: + ``` vi /mloscratch/homes/USERNAME/.shell/.zshrc ``` + Add this near the top or after plugins: + ``` export NVM_DIR="/mloscratch/homes/USERNAME/.nvm" [ -s "$NVM_DIR/nvm.sh" ] && source "$NVM_DIR/nvm.sh" ``` + #### Step 5 — Reload shell + ``` source /mloscratch/homes/USERNAME/.shell/.zshrc ``` + #### Step 6 — Verify everything + ``` which node nvm ls node -v -``` \ No newline at end of file +``` diff --git a/docs/PRIVACY.md b/docs/PRIVACY.md index d1e6a5d61..a4103d0ce 100644 --- a/docs/PRIVACY.md +++ b/docs/PRIVACY.md @@ -1,10 +1,12 @@ # Privacy protection measures In federated and decentralized learning, a client's data is never sent to another machine. However, some information could be inferred about a client's data set even when data isn't shared. For instance, summary statistics or even the existence of a specific data point can be inferred from sources such as: + 1. the weights of the public and collaborative model ([Carlini et al., 2019](https://www.usenix.org/conference/usenixsecurity19/presentation/carlini)); 2. the model updates shared by the client ([Bonawitz et al., 2017](https://doi.org/10.1145/3133956.3133982)). -In addition to the intrinsic security of federated and decentralized learning, DISCO ensures privacy and security via two different and complementary methods: +In addition to the intrinsic security of federated and decentralized learning, DISCO ensures privacy and security via two different and complementary methods: + 1. Differential privacy ([McMahan et al., 2018](http://arxiv.org/abs/1710.06963) and [Abadi et al., 2016](https://doi.org/10.1145/2976749.2978318)), and 2. Secure aggregation of model updates ([Bonawitz et al., 2017](https://doi.org/10.1145/3133956.3133982)). @@ -15,6 +17,7 @@ Differential privacy methods protect any dataset(s) used in the training of a ma The respective parameters `epsilon`, `delta`, and `clippingRadius` are available in the [task configuration](TASK.md). ### What is Differential Privacy? + Differential privacy (DP) is a rigorous privacy framework that provides a privacy guarantee by ensuring that an algorithm's output does not significantly change when a single data point in the dataset is modified. This protection is achieved by adding carefully calibrated random values (called "noise") to the data or model updates. In DISCO, differential privacy ensures privacy by making sure that the weight updates produced by one client do not significantly change when a single data point in that client's dataset is modified. This is called local differential privacy (LDP). Before sharing weight updates with the server, random noise is added to these updates. By examining only the weight updates that each client sends to the server, no party, including the server, can infer who generated a specific update or which datasets particular clients have. @@ -22,23 +25,29 @@ In DISCO, differential privacy ensures privacy by making sure that the weight up Differntial privacy has an important parameter, epsilon($\epsilon$), which indicates the privacy level applied to the learning process. It is also called the "privacy budget." ### Parameter Explanations + Differential privacy is achieved by adding noise. To guarantee your desired privacy level, you need to specify several parameters: `epsilon` + - This is the privacy budget. The smaller the $\epsilon$ value, the stronger the privacy protection. In DISCO, this $\epsilon$ value indicates the privacy guarantee for a single client. -`delta` +`delta` + - This parameter indicates the failure pobability of the privacy guarantee. It is used in approximate differential privacy, which DISCO implemented. `clipping radius` + - This parameter sets the maximum bound for the adaptive clipping radius. ### Privacy-utility trade-off + The utility degradation that follows from improving privacy is an inherent feature of differential privacy, so you must consider this when choosing your $\epsilon$ value. When $\epsilon$ equals 0, this guarantees perfect privacy but zero utility. As $\epsilon$ approaches infinity, privacy becomes zero and full utility is recovered. As $\epsilon$ decreases, utility degrades gradually. When we repetitively run the same private algorithm, the privacy budget accumulates, resulting in a larger final privacy budget that indicates a weaker privacy guarantee. This is called "composition" of privacy budget. This applies to DP in DISCO: since we add noise to weight updates at every epoch, the privacy budget accumulates with each epoch. The accumulation rate is determined by the total number of epochs defined in the task configuration. ### What is the best $\epsilon$ value? + Choosing an appropriate $\epsilon$ value depends on your specific use case and requires careful consideration of the privacy-utility trade-off. - For local differential privacy (LDP), which DISCO implements, meaningful utility typically requires larger $\epsilon$ values compared to central differential privacy. In practice, LDP implementations often use $\epsilon$ values ranging from 5 to 20. Some implementations may use higher values, though this comes with weaker privacy guarantee. @@ -46,59 +55,62 @@ Choosing an appropriate $\epsilon$ value depends on your specific use case and r - Higher $\epsilon$ values (above 20) may provide better model performance but offer weaker privacy protection. - The approapriate $\epsilon$ value for your task depends on several factors as below. - - Your acceptable level of model accuracy degradation - - The number of rounds (due to privacy budget composition over rounds) + - Your acceptable level of model accuracy degradation + - The number of rounds (due to privacy budget composition over rounds) - To provide context, here are examples of $\epsilon$ values used in real-world deployments: - - Apple's local differential privacy implementation for iOS and macOS uses $\epsilon$ = 16 for QuickType suggestions, with a privacy unit of user per day ([Apple Differential Privacy Overview](https://www.apple.com/privacy/docs/Differential_Privacy_Overview.pdf)) - - Microsoft's Windows telemetry collection uses local differential privacy with $\epsilon$ = 1.672, with a privacy unit of user per 6 hours ([Ding et al., 2017](https://www.microsoft.com/en-us/research/publication/collecting-telemetry-data-privately/)) + - Apple's local differential privacy implementation for iOS and macOS uses $\epsilon$ = 16 for QuickType suggestions, with a privacy unit of user per day ([Apple Differential Privacy Overview](https://www.apple.com/privacy/docs/Differential_Privacy_Overview.pdf)) + - Microsoft's Windows telemetry collection uses local differential privacy with $\epsilon$ = 1.672, with a privacy unit of user per 6 hours ([Ding et al., 2017](https://www.microsoft.com/en-us/research/publication/collecting-telemetry-data-privately/)) ### DISCO's Differential Privacy Implementation -Since model weights are shared for aggregation to converge to a final model in DISCO, we add DP noise to weight updates before sharing them with server or other clients. This noise is calibrated with an interaction between $\epsilon$, $\delta$, and `clipping_radius`. +Since model weights are shared for aggregation to converge to a final model in DISCO, we add DP noise to weight updates before sharing them with server or other clients. This noise is calibrated with an interaction between $\epsilon$, $\delta$, and `clipping_radius`. To carefully calibrate the smallest possible noise for a given privacy guarantee, we implement window-based adaptive local differential privacy(ALDP). The ALDP process works as follows. - 1. Each round, before sharing the weight update with the server, we calibrate the noise using $\epsilon$, $\delta$, and a new adaptive clipping radius, which is the mean value of the three previous weight updates. This helps us find the optimal clipping radius that avoids over-calibrating the noise needed for the privacy guarantee. - 2. We add the calibrated noise to the current weight update and share it with the server. - 3. We store the weight update before noise addition to use for calibrating the clipping radius in the next round. + +1. Each round, before sharing the weight update with the server, we calibrate the noise using $\epsilon$, $\delta$, and a new adaptive clipping radius, which is the mean value of the three previous weight updates. This helps us find the optimal clipping radius that avoids over-calibrating the noise needed for the privacy guarantee. +2. We add the calibrated noise to the current weight update and share it with the server. +3. We store the weight update before noise addition to use for calibrating the clipping radius in the next round. ## Secure aggregation through MPC DISCO protects the clients' data from inference attacks based on the model updates shared by the clients, by ensuring that an individual client's model update is never revealed. This is achieved by secure update aggregation, where multiple clients use secure multiparty computation (MPC) to jointly compute the sum of their model updates without revealing the summands. -In DISCO, we rely on secure aggregation of models / model updates, in each communication round, in order to fully protect the privacy of each user. +In DISCO, we rely on secure aggregation of models / model updates, in each communication round, in order to fully protect the privacy of each user. ### Concept: Private data - Public model -We guarantee input privacy of each personal update and each client's data. +We guarantee input privacy of each personal update and each client's data. The model resulting from training is considered public, both in the federated and decentralized schemes. ### Set-up -Our secure aggregation mechanism is implemented in each communication round, within small aggregation groups of a minimum size, which are formed from clients available to exchange model updates. +Our secure aggregation mechanism is implemented in each communication round, within small aggregation groups of a minimum size, which are formed from clients available to exchange model updates. ### Algorithm description **Orchestration via client-server communication:** + 1. The server keeps track of which clients are ready to share model weights with each other, in order to let them know when enough clients are ready. -Thus, before the aggregation begins, there is a preliminary communication step between the clients and the server: + Thus, before the aggregation begins, there is a preliminary communication step between the clients and the server: 1. Whenever a client finishes a round of local updates, it sends a "ready message" to the server to signal that it is ready to exchange model updates. 2. Once enough clients are ready, the server sends them the list of clients to aggregate with. 3. If the client receives the list of ready peers within a certain time frame after sending its "ready message", it begins the secure aggregation procedure. The **secure aggregation procedure** consists of two rounds of all-to-all communication. In other words, two messages are sent from each member of the list of ready clients, to each member of the list of ready clients: -2. The client generates *n* additive secret shares from their own model update and sends them to the other clients. - 1. *n* is the number of clients participating in the aggregation procedure. Hence `len(list of ready clients) = n`. - 2. Each share has the same shape as the model weights. The *n* shares are generated at random under the constraint that their element-wise sum must be equal to the client's model update. - 3. Once it has generated *n* additive secret shares, the client sends one share to each client who is participating in the aggregation procedure (including to itself). Note that each individual client is unable to reconstruct any other client's model update, because the latter is independent from any strict subset of the set of all *n* shares generated from the model update. - 4. The client expects to receive *n* shares (one from each client on the list). If it receives all expected shares within a certain time frame, it moves on to the next step of the procedure. +2. The client generates _n_ additive secret shares from their own model update and sends them to the other clients. + 1. _n_ is the number of clients participating in the aggregation procedure. Hence `len(list of ready clients) = n`. + 2. Each share has the same shape as the model weights. The _n_ shares are generated at random under the constraint that their element-wise sum must be equal to the client's model update. + 3. Once it has generated _n_ additive secret shares, the client sends one share to each client who is participating in the aggregation procedure (including to itself). Note that each individual client is unable to reconstruct any other client's model update, because the latter is independent from any strict subset of the set of all _n_ shares generated from the model update. + 4. The client expects to receive _n_ shares (one from each client on the list). If it receives all expected shares within a certain time frame, it moves on to the next step of the procedure. 3. The client has received one share from each client on the list. 1. The client computes the sum of the received shares. We call it the _partial sum_. Note that the _partial sums_ computed by all of the clients add up to the sum of the clients' model updates, because every share is accounted for exactly once. 2. The client then sends this _partial sum_ to all clients on the list. - 3. If the client receives all *n* partial sums within a certain time frame, it reconstructs the sum of the model updates of all clients on the list by computing the sum of the partial sums. + 3. If the client receives all _n_ partial sums within a certain time frame, it reconstructs the sum of the model updates of all clients on the list by computing the sum of the partial sums. **Return value:** + - At steps 1.3, 2.4, or 3.3, if the client does not receive all expected message(s) by the end of the time frame, the aggregation round is considered to have failed and the algorithm returns the client's own model update (this value is returned to a local routine of the client). - Otherwise, it returns the reconstructed sum of the model updates of all clients who participated in the aggregation procedure. @@ -108,6 +120,7 @@ DISCO secure aggregation guarantees input privacy for a client's model updates. It is worth noting that due to the current use of floating point arithmetic instead of finite fields implies an effect on the quantization of models. Alternatively quantized integer model weights (with scaling) can be used. Currently, the additive shares generated at step 2 are filled with floating-point values drawn uniformly at random from the interval `[-maxShareValue, +maxShareValue)`. + - If `maxShareValue` is too small, privacy is lost because larger random numbers better obfuscate the secret. Indeed, for each client A, there is one other client who can construct a (1-e)-confidence-interval of size `maxShareValue*2`, where e is also monotonically increasing in `maxShareValue`. - However, at the same time, as `maxShareValue` increases, the reconstruction problem becomes more ill-conditioned (subtraction of very large numbers to obtain a relatively small number). Indeed, given the finite precision of floating point number representations, every order of magnitude increase in `maxShareValue` increases the expected reconstruction error by one order of magnitude. diff --git a/docs/examples/README.md b/docs/examples/README.md index 9d5b0cffc..f836c2ca6 100644 --- a/docs/examples/README.md +++ b/docs/examples/README.md @@ -23,8 +23,7 @@ As you can see in `training.ts` a client is represented by a `Disco` object: ```js const disco = new Disco(task, { url, scheme: "federated" }); for await (const round of disco.fit(dataset)) - for await (const epoch of round) - for await (const batch of epoch); + for await (const epoch of round) for await (const batch of epoch); await disco.close(); ``` diff --git a/docs/examples/custom_task.ts b/docs/examples/custom_task.ts index deb3c08a3..706c465c5 100644 --- a/docs/examples/custom_task.ts +++ b/docs/examples/custom_task.ts @@ -1,63 +1,61 @@ -import tf from '@tensorflow/tfjs-node' +import tf from "@tensorflow/tfjs-node"; -import type { TaskProvider } from '@epfml/discojs' -import { defaultTasks, models } from '@epfml/discojs' -import { Server as DiscoServer } from 'server' +import type { TaskProvider } from "@epfml/discojs"; +import { defaultTasks, models } from "@epfml/discojs"; +import { Server as DiscoServer } from "server"; // Define your own task provider (task definition + model) const customTask: TaskProvider<"tabular", "federated"> = { - getTask () { + getTask() { return Promise.resolve({ - id: 'custom-task', + id: "custom-task", dataType: "tabular", displayInformation: { - title: 'Custom task', + title: "Custom task", summary: { - preview: 'task preview', - overview: 'task overview' - } + preview: "task preview", + overview: "task overview", + }, }, trainingInformation: { epochs: 5, roundDuration: 10, validationSplit: 0, batchSize: 30, - inputColumns: [ - 'Age' - ], - outputColumn: 'Output', - scheme: 'federated', + inputColumns: ["Age"], + outputColumn: "Output", + scheme: "federated", aggregationStrategy: "mean", minNbOfParticipants: 2, - tensorBackend: 'tfjs', + tensorBackend: "tfjs", privacy: undefined, - } + }, }); }, - getModel () { - const model = tf.sequential() + getModel() { + const model = tf.sequential(); model.add( tf.layers.dense({ inputShape: [1], units: 124, - activation: 'relu', - kernelInitializer: 'leCunNormal' - }) - ) - model.add(tf.layers.dense({ units: 32, activation: 'relu' })) - model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' })) + activation: "relu", + kernelInitializer: "leCunNormal", + }), + ); + model.add(tf.layers.dense({ units: 32, activation: "relu" })); + model.add(tf.layers.dense({ units: 1, activation: "sigmoid" })); model.compile({ - optimizer: 'rmsprop', - loss: 'binaryCrossentropy', - metrics: ['accuracy'] - }) + optimizer: "rmsprop", + loss: "binaryCrossentropy", + metrics: ["accuracy"], + }); - return Promise.resolve(new models.TFJS('tabular', model)) - } -} + return Promise.resolve(new models.TFJS("tabular", model)); + }, +}; async function runServer(): Promise { // Create a server @@ -70,4 +68,4 @@ async function runServer(): Promise { await server.serve(8080); } -runServer().catch(console.error) +runServer().catch(console.error); diff --git a/docs/examples/package.json b/docs/examples/package.json index e4562f64c..a45631786 100644 --- a/docs/examples/package.json +++ b/docs/examples/package.json @@ -1,23 +1,23 @@ { - "name": "examples", - "private": true, - "type": "module", - "description": "In `training.ts` we give a brief example of discojs, in it we run two clients training jointly via federated learning. It trains on a few examples of the [face task](https://www.kaggle.com/datasets/frabbisw/facial-age), the samples are already stored in the repo and so it is not necessary to download any additional data.", - "scripts": { - "train": "npm run build && node dist/training.js", - "custom_task": "npm run build && node dist/custom_task.js", - "language_model": "npm run build && node dist/wikitext.js", - "benchmark_gpt": "npm run build && node dist/benchmark_gpt.js", - "build": "tsc", - "test": "npm run train" - }, - "license": "ISC", - "dependencies": { - "server": "*", - "@epfml/discojs": "*", - "@epfml/discojs-node": "*" - }, - "devDependencies": { - "typescript": "6" - } + "name": "examples", + "private": true, + "type": "module", + "description": "In `training.ts` we give a brief example of discojs, in it we run two clients training jointly via federated learning. It trains on a few examples of the [face task](https://www.kaggle.com/datasets/frabbisw/facial-age), the samples are already stored in the repo and so it is not necessary to download any additional data.", + "scripts": { + "train": "npm run build && node dist/training.js", + "custom_task": "npm run build && node dist/custom_task.js", + "language_model": "npm run build && node dist/wikitext.js", + "benchmark_gpt": "npm run build && node dist/benchmark_gpt.js", + "build": "tsc", + "test": "npm run train" + }, + "license": "ISC", + "dependencies": { + "server": "*", + "@epfml/discojs": "*", + "@epfml/discojs-node": "*" + }, + "devDependencies": { + "typescript": "6" + } } diff --git a/docs/examples/training.ts b/docs/examples/training.ts index af5b8e666..c93568192 100644 --- a/docs/examples/training.ts +++ b/docs/examples/training.ts @@ -1,6 +1,6 @@ -import { Repeat } from 'immutable' -import * as path from 'node:path' -import '@tensorflow/tfjs-node' +import { Repeat } from "immutable"; +import * as path from "node:path"; +import "@tensorflow/tfjs-node"; import type { Dataset, @@ -9,9 +9,9 @@ import type { Image, Task, } from "@epfml/discojs"; -import { Disco, fetchTasks, defaultTasks } from '@epfml/discojs' -import { loadCSV, loadImagesInDir } from '@epfml/discojs-node' -import { Server } from 'server' +import { Disco, fetchTasks, defaultTasks } from "@epfml/discojs"; +import { loadCSV, loadImagesInDir } from "@epfml/discojs-node"; +import { Server } from "server"; /** * Example of discojs API, we load data, build the appropriate loggers, the disco object @@ -23,13 +23,13 @@ async function runUser( dataset: Dataset, ): Promise { // Create Disco object associated with the server url, the training scheme - const disco = new Disco(task, url, { scheme: 'federated' }) + const disco = new Disco(task, url, { scheme: "federated" }); // Run training on the dataset await disco.trainFully(dataset); // Disconnect from the remote server - await disco.close() + await disco.close(); } type TaskAndDataset = [ @@ -37,9 +37,9 @@ type TaskAndDataset = [ Dataset, ]; -async function main (): Promise { +async function main(): Promise { // Arbitrary chosen Task ID - const NAME: string = 'titanic' + const NAME: string = "titanic"; // Launch a server instance const server = await Server.with( @@ -49,26 +49,30 @@ async function main (): Promise { const [handle, url] = await server.serve(); // Get all pre-defined tasks - const tasks = await fetchTasks(url) + const tasks = await fetchTasks(url); // Choose the task and load local data // Make sure you first ran ./get_training_data - let taskAndDataset: TaskAndDataset<'image' | 'tabular'> + let taskAndDataset: TaskAndDataset<"image" | "tabular">; switch (NAME) { case "titanic": { - const task = tasks.get("titanic") as | Task<"tabular", "federated"> | undefined; + const task = tasks.get("titanic") as + | Task<"tabular", "federated"> + | undefined; if (task === undefined) throw new Error("task not found"); taskAndDataset = [task, loadCSV("../../datasets/titanic_train.csv")]; break; } case "simple_face": { - const task = tasks.get("simple_face") as | Task<"image", "federated"> | undefined; + const task = tasks.get("simple_face") as + | Task<"image", "federated"> + | undefined; if (task === undefined) throw new Error("task not found"); taskAndDataset = [task, await loadSimpleFaceData()]; break; } default: - throw new Error('task id not found') + throw new Error("task id not found"); } // Add more users to the list to simulate more than 3 clients @@ -76,7 +80,7 @@ async function main (): Promise { runUser(url, ...taskAndDataset), runUser(url, ...taskAndDataset), runUser(url, ...taskAndDataset), - ]) + ]); // Close server await new Promise((resolve, reject) => { @@ -97,4 +101,4 @@ async function loadSimpleFaceData(): Promise> { } // You can run this example with "npm run train" from this folder -main().catch(console.error) +main().catch(console.error); diff --git a/docs/examples/wikitext.ts b/docs/examples/wikitext.ts index 25d2f6de0..a6e22cc5f 100644 --- a/docs/examples/wikitext.ts +++ b/docs/examples/wikitext.ts @@ -1,45 +1,55 @@ -import "@tensorflow/tfjs-node" +import "@tensorflow/tfjs-node"; import { Disco, fetchTasks, models, Task } from "@epfml/discojs"; -import { saveModelToDisk, loadModelFromDisk, loadText } from '@epfml/discojs-node' -import { List } from "immutable" +import { + saveModelToDisk, + loadModelFromDisk, + loadText, +} from "@epfml/discojs-node"; +import { List } from "immutable"; -async function main(): Promise { +async function main(): Promise { // Launch a server instance - const url = new URL('http://localhost:8080') - + const url = new URL("http://localhost:8080"); + // Fetch the wikitext task from the server - const tasks = await fetchTasks(url) + const tasks = await fetchTasks(url); const task = tasks.get("llm_task") as Task<"text", "federated"> | undefined; - if (task === undefined) { throw new Error('task not found') } - + if (task === undefined) { + throw new Error("task not found"); + } + let model; - const modelFolder = './models' - const modelFileName = 'model_random.json' + const modelFolder = "./models"; + const modelFileName = "model_random.json"; // Toggle TRAIN_MODEL to either train and save a new model from scratch or load an existing model - const TRAIN_MODEL = true + const TRAIN_MODEL = true; if (TRAIN_MODEL) { // Load the wikitext dataset from the `datasets` folder - const dataset = loadText("../../datasets/wikitext/wiki.train.tokens") - .chain(loadText("../../datasets/wikitext/wiki.valid.tokens")); - + const dataset = loadText("../../datasets/wikitext/wiki.train.tokens").chain( + loadText("../../datasets/wikitext/wiki.valid.tokens"), + ); + // Initialize a Disco instance and start training a language model - const disco = new Disco(task, url, { scheme: 'federated' }) + const disco = new Disco(task, url, { scheme: "federated" }); await disco.trainFully(dataset); - + // Get the model and save the trained model - model = disco.trainer.model as models.GPT - await saveModelToDisk(model, modelFolder, modelFileName) - await disco.close() + model = disco.trainer.model as models.GPT; + await saveModelToDisk(model, modelFolder, modelFileName); + await disco.close(); } else { // Load the trained model - model = await loadModelFromDisk(`${modelFolder}/${modelFileName}`) as models.GPT + model = (await loadModelFromDisk( + `${modelFolder}/${modelFileName}`, + )) as models.GPT; } // Preprocess prompt - const prompt = 'The game began development in 2010 , carrying over a large portion' + const prompt = + "The game began development in 2010 , carrying over a large portion"; const { tokenizer } = task.trainingInformation; let tokens = tokenizer.tokenize(prompt); @@ -47,11 +57,11 @@ async function main(): Promise { const numberOfTokens = 10; for (let i = 0; i < numberOfTokens; i++) { const next = (await model.predict(List.of(tokens))).first(); - if (next === undefined) throw new Error("no prediction"); - tokens = tokens.push(next) + if (next === undefined) throw new Error("no prediction"); + tokens = tokens.push(next); } console.log(tokenizer.decode(tokens.toArray())); } // You can run this example with "npm start" from this folder -main().catch(console.error) +main().catch(console.error); diff --git a/eslint.config.js b/eslint.config.js index b9a97bd1b..33be2464f 100644 --- a/eslint.config.js +++ b/eslint.config.js @@ -3,67 +3,67 @@ import pluginVitest from "@vitest/eslint-plugin"; import skipFormatting from "@vue/eslint-config-prettier/skip-formatting"; import { - defineConfigWithVueTs, - vueTsConfigs, + defineConfigWithVueTs, + vueTsConfigs, } from "@vue/eslint-config-typescript"; import pluginCypress from "eslint-plugin-cypress"; import pluginVue from "eslint-plugin-vue"; export default defineConfigWithVueTs( - pluginVue.configs["flat/recommended"], - vueTsConfigs.recommendedTypeChecked, - { - languageOptions: { - parserOptions: { - projectService: { - allowDefaultProject: [ - "eslint.config.js", - "isomorphic-wrtc/{{browser,node}.js,types.d.ts}", - "testSetupImportTFJSNode.ts", - "vitest.config.ts", - ], - }, - }, - }, - }, - { - rules: { - // taken from https://typescript-eslint.io/rules/no-unused-vars/ - "@typescript-eslint/no-unused-vars": [ - "error", - { - args: "all", - argsIgnorePattern: "^_", - caughtErrors: "all", - caughtErrorsIgnorePattern: "^_", - destructuredArrayIgnorePattern: "^_", - varsIgnorePattern: "^_", - ignoreRestSiblings: true, - }, - ], - // allow biome formatting - "no-mixed-spaces-and-tabs": "off", - // allow for nicer names - "@typescript-eslint/no-namespace": "off", - }, - }, - { - ...pluginVitest.configs.recommended, - files: ["**/*.spec.ts"], - rules: { - // broken w/ BDD https://github.com/vitest-dev/eslint-plugin-vitest/issues/675 - "vitest/valid-expect": "off", - "@typescript-eslint/no-unused-expressions": "off", - }, - }, - { - ...pluginCypress.configs.recommended, - files: ["webapp/cypress/**/*.ts"], - }, - { ignores: ["**/dist/*"] }, - { ignores: ["docs/examples/**"] }, - { ignores: ["**/src/protobuf/"] }, - - // don't use linter for formatting - skipFormatting, + pluginVue.configs["flat/recommended"], + vueTsConfigs.recommendedTypeChecked, + { + languageOptions: { + parserOptions: { + projectService: { + allowDefaultProject: [ + "eslint.config.js", + "isomorphic-wrtc/{{browser,node}.js,types.d.ts}", + "testSetupImportTFJSNode.ts", + "vitest.config.ts", + ], + }, + }, + }, + }, + { + rules: { + // taken from https://typescript-eslint.io/rules/no-unused-vars/ + "@typescript-eslint/no-unused-vars": [ + "error", + { + args: "all", + argsIgnorePattern: "^_", + caughtErrors: "all", + caughtErrorsIgnorePattern: "^_", + destructuredArrayIgnorePattern: "^_", + varsIgnorePattern: "^_", + ignoreRestSiblings: true, + }, + ], + // allow biome formatting + "no-mixed-spaces-and-tabs": "off", + // allow for nicer names + "@typescript-eslint/no-namespace": "off", + }, + }, + { + ...pluginVitest.configs.recommended, + files: ["**/*.spec.ts"], + rules: { + // broken w/ BDD https://github.com/vitest-dev/eslint-plugin-vitest/issues/675 + "vitest/valid-expect": "off", + "@typescript-eslint/no-unused-expressions": "off", + }, + }, + { + ...pluginCypress.configs.recommended, + files: ["webapp/cypress/**/*.ts"], + }, + { ignores: ["**/dist/*"] }, + { ignores: ["docs/examples/**"] }, + { ignores: ["**/src/protobuf/"] }, + + // don't use linter for formatting + skipFormatting, ); diff --git a/isomorphic-wrtc/README.md b/isomorphic-wrtc/README.md index b36e15802..97b55fb0c 100644 --- a/isomorphic-wrtc/README.md +++ b/isomorphic-wrtc/README.md @@ -2,8 +2,8 @@ Allow to load a different WebRTC implementation depending on the platform. -* on node, load @roamhq/wrtc, a C++ plugin -* in browser, simply exposes the available WebRTC implementation +- on node, load @roamhq/wrtc, a C++ plugin +- in browser, simply exposes the available WebRTC implementation It allows to simply `import wrct from 'isomorphic-wrtc'` and get the same coding experience. diff --git a/isomorphic-wrtc/browser.js b/isomorphic-wrtc/browser.js index b66e77adb..0deaeb7f2 100644 --- a/isomorphic-wrtc/browser.js +++ b/isomorphic-wrtc/browser.js @@ -1,2 +1,2 @@ -import getBrowserRTC from 'get-browser-rtc' -export default getBrowserRTC() +import getBrowserRTC from "get-browser-rtc"; +export default getBrowserRTC(); diff --git a/isomorphic-wrtc/node.js b/isomorphic-wrtc/node.js index e139790ae..658a14af8 100644 --- a/isomorphic-wrtc/node.js +++ b/isomorphic-wrtc/node.js @@ -1,2 +1,2 @@ -import wrtc from "@roamhq/wrtc" -export default wrtc +import wrtc from "@roamhq/wrtc"; +export default wrtc; diff --git a/isomorphic-wrtc/package.json b/isomorphic-wrtc/package.json index 0568006c5..f4fd215cb 100644 --- a/isomorphic-wrtc/package.json +++ b/isomorphic-wrtc/package.json @@ -1,19 +1,19 @@ { - "name": "@epfml/isomorphic-wrtc", - "version": "1.0.0", - "description": "Isomorphic implementation of WebRTC", - "type": "module", - "main": "node.js", - "browser": "browser.js", - "types": "types.d.ts", - "author": "", - "license": "MIT", - "scripts": { - "build": ": nothing", - "lint": ": nothing", - "test": ": nothing" - }, - "peerDependencies": { - "@roamhq/wrtc": "*" - } + "name": "@epfml/isomorphic-wrtc", + "version": "1.0.0", + "description": "Isomorphic implementation of WebRTC", + "type": "module", + "main": "node.js", + "browser": "browser.js", + "types": "types.d.ts", + "author": "", + "license": "MIT", + "scripts": { + "build": ": nothing", + "lint": ": nothing", + "test": ": nothing" + }, + "peerDependencies": { + "@roamhq/wrtc": "*" + } } diff --git a/onnx-converter/README.md b/onnx-converter/README.md index d3f9dff82..4741ec66b 100644 --- a/onnx-converter/README.md +++ b/onnx-converter/README.md @@ -5,6 +5,7 @@ This workspace is currently used to convert ONNX [GPT-2 model](https://huggingfa Therefore, we want to convert pretrained models such as GPT-2 from ONNX format to Tensorflow.js to further fine-tune them. You generate a TF.js `model.json` by running `npm run convert_onnx` in this workspace. What the script does is: + 1. Read the ONNX GPT-2 model from [Xenova's repository](https://huggingface.co/Xenova/gpt2) 2. Use the ONNX protobuf definition to read the file and iterate through the model layers. The ONNX JavaScript protobuf comes from [this repository](https://github.com/microsoft/onnxruntime/blob/main/js/web/lib/onnxjs/). 3. Convert all weights to TF.js tensors @@ -15,7 +16,8 @@ Running `npm run convert_onnx` creates a GPT-tfjs `model.json` file in the `./as ## ONNX JS protobuf The ONNX specification has limited support in JavaScript. We found an old JS implementation in the [ONNX Runtime Web repository](https://github.com/microsoft/onnxruntime/tree/main/js/web/lib/onnxjs/ort-schema/protobuf). We had to adapt their files as follows to be compatible with our newer environment: + 1. Copy `onnx.js` and `onnx.d.ts` from [the repository](https://github.com/microsoft/onnxruntime/tree/main/js/web/lib/onnxjs/ort-schema/protobuf) in the `./onnx-converter/src/protobuf` folder. 2. Rename `onnx.js` to `onnx.cjs` 3. Create [`onnx-proto.js`](./src/protobuf/onnx-proto.js) as a wrapper around the protobuf definition. -4. Create [`onnx-proto.d.ts`](./src/protobuf/onnx-proto.d.ts) with the matching TypeScript definition. \ No newline at end of file +4. Create [`onnx-proto.d.ts`](./src/protobuf/onnx-proto.d.ts) with the matching TypeScript definition. diff --git a/onnx-converter/package.json b/onnx-converter/package.json index 98e8d75a5..59d33823e 100644 --- a/onnx-converter/package.json +++ b/onnx-converter/package.json @@ -5,7 +5,7 @@ "main": "dist/gpt2_from_onnx.js", "scripts": { "convert_onnx": "npm run build && node dist/convert_onnx.js", - "watch": "nodemon --ext ts --ignore dist --exec npm run", + "watch": "nodemon --ext ts --ignore dist --exec npm run", "build": "tsc --build && cp -r src/protobuf dist", "lint": "npx eslint .", "test": ": nothing" diff --git a/onnx-converter/src/convert_onnx.ts b/onnx-converter/src/convert_onnx.ts index cf45a76ed..fb00d0edd 100644 --- a/onnx-converter/src/convert_onnx.ts +++ b/onnx-converter/src/convert_onnx.ts @@ -1,102 +1,113 @@ -import { onnx } from './protobuf/onnx-proto.js'; -import { Map, Range } from 'immutable'; -import fsPromise from 'node:fs/promises'; -import * as tf from '@tensorflow/tfjs-node'; +import { onnx } from "./protobuf/onnx-proto.js"; +import { Map, Range } from "immutable"; +import fsPromise from "node:fs/promises"; +import * as tf from "@tensorflow/tfjs-node"; import { models, serialization } from "@epfml/discojs"; const OUTPUT_FILENAME = "model.json"; const GPT2_N_LAYER = 12; -const ONNX_URL = "https://huggingface.co/Xenova/gpt2/resolve/main/onnx/decoder_model.onnx?download=true" - +const ONNX_URL = + "https://huggingface.co/Xenova/gpt2/resolve/main/onnx/decoder_model.onnx?download=true"; async function main() { console.log(`Downloading ONNX model from ${ONNX_URL}...`); const response = await fetch(ONNX_URL); if (!response.ok) - throw new Error(`Failed to fetch ONNX model from ${ONNX_URL}: ${response.statusText}`); + throw new Error( + `Failed to fetch ONNX model from ${ONNX_URL}: ${response.statusText}`, + ); const arrayBuffer = await response.arrayBuffer(); const data = new Uint8Array(arrayBuffer); - console.log(`Download complete (${(data.length / 1024 / 1024).toFixed(2)} MB).`); + console.log( + `Download complete (${(data.length / 1024 / 1024).toFixed(2)} MB).`, + ); console.log(`Decoding protobuf...`); - const onnxModel = onnx.ModelProto.decode(data) - + const onnxModel = onnx.ModelProto.decode(data); + if (!onnxModel.graph || !onnxModel.graph.initializer) throw new Error("No graph or tensors found in the ONNX model."); - console.log('ONNX model loaded successfully'); - - + console.log("ONNX model loaded successfully"); + // Init empty TF.js model // Context length value from https://huggingface.co/Xenova/gpt2/blob/main/config.json - const gptModel = new models.GPT({ modelType: 'gpt2', contextLength: 1024 }); + const gptModel = new models.GPT({ modelType: "gpt2", contextLength: 1024 }); if (gptModel.config.nLayer != GPT2_N_LAYER) - throw new Error(`ONNX conversion only supports GPT-2 with 12 layers, instead found ${gptModel.config.nLayer}.`); + throw new Error( + `ONNX conversion only supports GPT-2 with 12 layers, instead found ${gptModel.config.nLayer}.`, + ); const gptLayersModel = gptModel.extract(); - console.log("Converting ONNX tensors to TF.js tensors") + console.log("Converting ONNX tensors to TF.js tensors"); // Layer name mapping between ONNX and TF.js const onnxTfjsMapping = createWeightNameMap(); // Create a mapping between layer name and TF.js weight tensors let preTrainedWeights = Map(); // layer name to weight tensor for (const tensor of onnxModel.graph.initializer) { if (tensor.name === undefined || tensor.name === null) - throw new Error("Undefined layer named") - + throw new Error("Undefined layer named"); + const tfjsName = onnxTfjsMapping.get(tensor.name); if (tfjsName === undefined) throw new Error(`Missing ONNX weight in layer mapping: ${tensor.name}`); if (preTrainedWeights.get(tfjsName)) - throw new Error(`Duplicate weight name found: ${tfjsName}`); - + throw new Error(`Duplicate weight name found: ${tfjsName}`); + if (tensor.dims === undefined || tensor.dims === null) - throw new Error(`Undefined layer dimensions for ${tensor.name}`) + throw new Error(`Undefined layer dimensions for ${tensor.name}`); const dims = tensor.dims.map((d) => Number(d)); const flatData = parseTensorData(tensor); - const tfTensor = tf.tensor(flatData).reshape(dims) + const tfTensor = tf.tensor(flatData).reshape(dims); preTrainedWeights = preTrainedWeights.set(tfjsName, tfTensor); } - console.log("Initializing a new TFJS GPT-2 model...") + console.log("Initializing a new TFJS GPT-2 model..."); if (preTrainedWeights.size !== onnxTfjsMapping.size) - throw new Error(`Expected to load ${onnxTfjsMapping.size} weights, but loaded ${preTrainedWeights.size}.`); - + throw new Error( + `Expected to load ${onnxTfjsMapping.size} weights, but loaded ${preTrainedWeights.size}.`, + ); + // Overwrite the GPT-TF.js model weights with the ONNX weights - if (gptLayersModel.weights.length !== onnxTfjsMapping.size) + if (gptLayersModel.weights.length !== onnxTfjsMapping.size) throw new Error(`Mismatch between TFJS and ONNX weight mapping weights.`); - - const finalWeights = gptLayersModel.weights.map(weight => { + + const finalWeights = gptLayersModel.weights.map((weight) => { const newTensor = preTrainedWeights.get(weight.name); if (newTensor === undefined) throw new Error(`Missing ${weight.name} in the ONNX weights`); return newTensor; }); - + gptLayersModel.setWeights(finalWeights); // shape or transpose mismatch will throw here - - const encoded = await serialization.model.encode(gptModel) - await fsPromise.writeFile(OUTPUT_FILENAME, encoded) - console.log(`GPT-TFJS model saved to ${OUTPUT_FILENAME}`) + + const encoded = await serialization.model.encode(gptModel); + await fsPromise.writeFile(OUTPUT_FILENAME, encoded); + console.log(`GPT-TFJS model saved to ${OUTPUT_FILENAME}`); } /** * Converts protobuf's tensors to float 32 arrays. */ function parseTensorData(tensor: onnx.ITensorProto): Float32Array { - // Check for raw data (common in larger models) - if (tensor.rawData && tensor.rawData.length > 0) { - const buffer = tensor.rawData.buffer.slice( - tensor.rawData.byteOffset, - tensor.rawData.byteOffset + tensor.rawData.byteLength - ); - if (tensor.dataType != onnx.TensorProto.DataType.FLOAT) { - throw new Error("found protobuf data type different from expected float 32.") - } - return new Float32Array(buffer); + // Check for raw data (common in larger models) + if (tensor.rawData && tensor.rawData.length > 0) { + const buffer = tensor.rawData.buffer.slice( + tensor.rawData.byteOffset, + tensor.rawData.byteOffset + tensor.rawData.byteLength, + ); + if (tensor.dataType != onnx.TensorProto.DataType.FLOAT) { + throw new Error( + "found protobuf data type different from expected float 32.", + ); } + return new Float32Array(buffer); + } - throw new Error("Protobuf's `rawData` is empty. Potentially check `floatData`.") + throw new Error( + "Protobuf's `rawData` is empty. Potentially check `floatData`.", + ); } /** @@ -106,25 +117,46 @@ function parseTensorData(tensor: onnx.ITensorProto): Float32Array { */ function createWeightNameMap(): Map { let map = Map(); - + map = map.set(`transformer.wte.weight`, `transformer/wte/embedding`); map = map.set(`transformer.wpe.weight`, `transformer/wpe/embeddings`); - - Range(0, GPT2_N_LAYER).forEach(i => { + + Range(0, GPT2_N_LAYER).forEach((i) => { const onnxPrefix = `transformer.h.${i}`; const tfjsPrefix = `transformer/h${i}`; map = map.set(`${onnxPrefix}.ln_1.weight`, `${tfjsPrefix}/ln_1/gamma`); map = map.set(`${onnxPrefix}.ln_1.bias`, `${tfjsPrefix}/ln_1/beta`); - map = map.set(`${onnxPrefix}.attn.c_attn.weight`, `${tfjsPrefix}/attn/c_attn/kernel`); - map = map.set(`${onnxPrefix}.attn.c_attn.bias`, `${tfjsPrefix}/attn/c_attn/bias`); - map = map.set(`${onnxPrefix}.attn.c_proj.weight`, `${tfjsPrefix}/attn/c_proj/kernel`); - map = map.set(`${onnxPrefix}.attn.c_proj.bias`, `${tfjsPrefix}/attn/c_proj/bias`); + map = map.set( + `${onnxPrefix}.attn.c_attn.weight`, + `${tfjsPrefix}/attn/c_attn/kernel`, + ); + map = map.set( + `${onnxPrefix}.attn.c_attn.bias`, + `${tfjsPrefix}/attn/c_attn/bias`, + ); + map = map.set( + `${onnxPrefix}.attn.c_proj.weight`, + `${tfjsPrefix}/attn/c_proj/kernel`, + ); + map = map.set( + `${onnxPrefix}.attn.c_proj.bias`, + `${tfjsPrefix}/attn/c_proj/bias`, + ); map = map.set(`${onnxPrefix}.ln_2.weight`, `${tfjsPrefix}/ln_2/gamma`); map = map.set(`${onnxPrefix}.ln_2.bias`, `${tfjsPrefix}/ln_2/beta`); - map = map.set(`${onnxPrefix}.mlp.c_fc.weight`, `${tfjsPrefix}/mlp/c_fc/kernel`); + map = map.set( + `${onnxPrefix}.mlp.c_fc.weight`, + `${tfjsPrefix}/mlp/c_fc/kernel`, + ); map = map.set(`${onnxPrefix}.mlp.c_fc.bias`, `${tfjsPrefix}/mlp/c_fc/bias`); - map = map.set(`${onnxPrefix}.mlp.c_proj.weight`, `${tfjsPrefix}/mlp/c_proj/kernel`); - map = map.set(`${onnxPrefix}.mlp.c_proj.bias`, `${tfjsPrefix}/mlp/c_proj/bias`); + map = map.set( + `${onnxPrefix}.mlp.c_proj.weight`, + `${tfjsPrefix}/mlp/c_proj/kernel`, + ); + map = map.set( + `${onnxPrefix}.mlp.c_proj.bias`, + `${tfjsPrefix}/mlp/c_proj/bias`, + ); }); map = map.set(`transformer.ln_f.weight`, `transformer/ln_f/gamma`); @@ -132,5 +164,4 @@ function createWeightNameMap(): Map { return map; } - -await main().catch(console.error); \ No newline at end of file +await main().catch(console.error); diff --git a/onnx-converter/src/protobuf/onnx-proto.d.ts b/onnx-converter/src/protobuf/onnx-proto.d.ts index ea32e77ad..c942d2b8c 100644 --- a/onnx-converter/src/protobuf/onnx-proto.d.ts +++ b/onnx-converter/src/protobuf/onnx-proto.d.ts @@ -1,5 +1,5 @@ -export { onnx } from './onnx.js'; +export { onnx } from "./onnx.js"; declare const onnxModule: { - onnx: typeof import('./onnx.js').onnx; + onnx: typeof import("./onnx.js").onnx; }; -export default onnxModule; \ No newline at end of file +export default onnxModule; diff --git a/onnx-converter/src/protobuf/onnx-proto.js b/onnx-converter/src/protobuf/onnx-proto.js index e4d15dc50..53c48e373 100644 --- a/onnx-converter/src/protobuf/onnx-proto.js +++ b/onnx-converter/src/protobuf/onnx-proto.js @@ -1,6 +1,6 @@ -import { createRequire } from 'module'; +import { createRequire } from "module"; const require = createRequire(import.meta.url); -const onnxModule = require('./onnx.cjs'); +const onnxModule = require("./onnx.cjs"); export const onnx = onnxModule.onnx; -export default onnxModule; \ No newline at end of file +export default onnxModule; diff --git a/onnx-converter/src/protobuf/onnx.cjs b/onnx-converter/src/protobuf/onnx.cjs index d7ca95149..52333c7b7 100644 --- a/onnx-converter/src/protobuf/onnx.cjs +++ b/onnx-converter/src/protobuf/onnx.cjs @@ -2,9 +2,9 @@ // LICENSE: MIT /*eslint-disable block-scoped-var, id-length, no-control-regex, no-magic-numbers, no-prototype-builtins, no-redeclare, no-shadow, no-var, sort-vars*/ -'use strict'; +"use strict"; -var $protobuf = require('protobufjs/minimal'); +var $protobuf = require("protobufjs/minimal"); // Common aliases var $Reader = $protobuf.Reader, @@ -12,7 +12,7 @@ var $Reader = $protobuf.Reader, $util = $protobuf.util; // Exported root namespace -var $root = $protobuf.roots['default'] || ($protobuf.roots['default'] = {}); +var $root = $protobuf.roots["default"] || ($protobuf.roots["default"] = {}); $root.onnx = (function () { /** @@ -40,16 +40,16 @@ $root.onnx = (function () { onnx.Version = (function () { var valuesById = {}, values = Object.create(valuesById); - values[(valuesById[0] = '_START_VERSION')] = 0; - values[(valuesById[1] = 'IR_VERSION_2017_10_10')] = 1; - values[(valuesById[2] = 'IR_VERSION_2017_10_30')] = 2; - values[(valuesById[3] = 'IR_VERSION_2017_11_3')] = 3; - values[(valuesById[4] = 'IR_VERSION_2019_1_22')] = 4; - values[(valuesById[5] = 'IR_VERSION_2019_3_18')] = 5; - values[(valuesById[6] = 'IR_VERSION_2019_9_19')] = 6; - values[(valuesById[7] = 'IR_VERSION_2020_5_8')] = 7; - values[(valuesById[8] = 'IR_VERSION_2021_7_30')] = 8; - values[(valuesById[9] = 'IR_VERSION')] = 9; + values[(valuesById[0] = "_START_VERSION")] = 0; + values[(valuesById[1] = "IR_VERSION_2017_10_10")] = 1; + values[(valuesById[2] = "IR_VERSION_2017_10_30")] = 2; + values[(valuesById[3] = "IR_VERSION_2017_11_3")] = 3; + values[(valuesById[4] = "IR_VERSION_2019_1_22")] = 4; + values[(valuesById[5] = "IR_VERSION_2019_3_18")] = 5; + values[(valuesById[6] = "IR_VERSION_2019_9_19")] = 6; + values[(valuesById[7] = "IR_VERSION_2020_5_8")] = 7; + values[(valuesById[8] = "IR_VERSION_2021_7_30")] = 8; + values[(valuesById[9] = "IR_VERSION")] = 9; return values; })(); @@ -105,7 +105,7 @@ $root.onnx = (function () { * @memberof onnx.AttributeProto * @instance */ - AttributeProto.prototype.name = ''; + AttributeProto.prototype.name = ""; /** * AttributeProto refAttrName. @@ -113,7 +113,7 @@ $root.onnx = (function () { * @memberof onnx.AttributeProto * @instance */ - AttributeProto.prototype.refAttrName = ''; + AttributeProto.prototype.refAttrName = ""; /** * AttributeProto docString. @@ -121,7 +121,7 @@ $root.onnx = (function () { * @memberof onnx.AttributeProto * @instance */ - AttributeProto.prototype.docString = ''; + AttributeProto.prototype.docString = ""; /** * AttributeProto type. @@ -145,7 +145,9 @@ $root.onnx = (function () { * @memberof onnx.AttributeProto * @instance */ - AttributeProto.prototype.i = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + AttributeProto.prototype.i = $util.Long + ? $util.Long.fromBits(0, 0, false) + : 0; /** * AttributeProto s. @@ -266,26 +268,34 @@ $root.onnx = (function () { */ AttributeProto.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + if (message.name != null && Object.hasOwnProperty.call(message, "name")) writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.name); - if (message.f != null && Object.hasOwnProperty.call(message, 'f')) + if (message.f != null && Object.hasOwnProperty.call(message, "f")) writer.uint32(/* id 2, wireType 5 =*/ 21).float(message.f); - if (message.i != null && Object.hasOwnProperty.call(message, 'i')) + if (message.i != null && Object.hasOwnProperty.call(message, "i")) writer.uint32(/* id 3, wireType 0 =*/ 24).int64(message.i); - if (message.s != null && Object.hasOwnProperty.call(message, 's')) + if (message.s != null && Object.hasOwnProperty.call(message, "s")) writer.uint32(/* id 4, wireType 2 =*/ 34).bytes(message.s); - if (message.t != null && Object.hasOwnProperty.call(message, 't')) - $root.onnx.TensorProto.encode(message.t, writer.uint32(/* id 5, wireType 2 =*/ 42).fork()).ldelim(); - if (message.g != null && Object.hasOwnProperty.call(message, 'g')) - $root.onnx.GraphProto.encode(message.g, writer.uint32(/* id 6, wireType 2 =*/ 50).fork()).ldelim(); + if (message.t != null && Object.hasOwnProperty.call(message, "t")) + $root.onnx.TensorProto.encode( + message.t, + writer.uint32(/* id 5, wireType 2 =*/ 42).fork(), + ).ldelim(); + if (message.g != null && Object.hasOwnProperty.call(message, "g")) + $root.onnx.GraphProto.encode( + message.g, + writer.uint32(/* id 6, wireType 2 =*/ 50).fork(), + ).ldelim(); if (message.floats != null && message.floats.length) { writer.uint32(/* id 7, wireType 2 =*/ 58).fork(); - for (var i = 0; i < message.floats.length; ++i) writer.float(message.floats[i]); + for (var i = 0; i < message.floats.length; ++i) + writer.float(message.floats[i]); writer.ldelim(); } if (message.ints != null && message.ints.length) { writer.uint32(/* id 8, wireType 2 =*/ 66).fork(); - for (var i = 0; i < message.ints.length; ++i) writer.int64(message.ints[i]); + for (var i = 0; i < message.ints.length; ++i) + writer.int64(message.ints[i]); writer.ldelim(); } if (message.strings != null && message.strings.length) @@ -293,25 +303,43 @@ $root.onnx = (function () { writer.uint32(/* id 9, wireType 2 =*/ 74).bytes(message.strings[i]); if (message.tensors != null && message.tensors.length) for (var i = 0; i < message.tensors.length; ++i) - $root.onnx.TensorProto.encode(message.tensors[i], writer.uint32(/* id 10, wireType 2 =*/ 82).fork()).ldelim(); + $root.onnx.TensorProto.encode( + message.tensors[i], + writer.uint32(/* id 10, wireType 2 =*/ 82).fork(), + ).ldelim(); if (message.graphs != null && message.graphs.length) for (var i = 0; i < message.graphs.length; ++i) - $root.onnx.GraphProto.encode(message.graphs[i], writer.uint32(/* id 11, wireType 2 =*/ 90).fork()).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + $root.onnx.GraphProto.encode( + message.graphs[i], + writer.uint32(/* id 11, wireType 2 =*/ 90).fork(), + ).ldelim(); + if ( + message.docString != null && + Object.hasOwnProperty.call(message, "docString") + ) writer.uint32(/* id 13, wireType 2 =*/ 106).string(message.docString); - if (message.tp != null && Object.hasOwnProperty.call(message, 'tp')) - $root.onnx.TypeProto.encode(message.tp, writer.uint32(/* id 14, wireType 2 =*/ 114).fork()).ldelim(); + if (message.tp != null && Object.hasOwnProperty.call(message, "tp")) + $root.onnx.TypeProto.encode( + message.tp, + writer.uint32(/* id 14, wireType 2 =*/ 114).fork(), + ).ldelim(); if (message.typeProtos != null && message.typeProtos.length) for (var i = 0; i < message.typeProtos.length; ++i) $root.onnx.TypeProto.encode( message.typeProtos[i], writer.uint32(/* id 15, wireType 2 =*/ 122).fork(), ).ldelim(); - if (message.type != null && Object.hasOwnProperty.call(message, 'type')) + if (message.type != null && Object.hasOwnProperty.call(message, "type")) writer.uint32(/* id 20, wireType 0 =*/ 160).int32(message.type); - if (message.refAttrName != null && Object.hasOwnProperty.call(message, 'refAttrName')) + if ( + message.refAttrName != null && + Object.hasOwnProperty.call(message, "refAttrName") + ) writer.uint32(/* id 21, wireType 2 =*/ 170).string(message.refAttrName); - if (message.sparseTensor != null && Object.hasOwnProperty.call(message, 'sparseTensor')) + if ( + message.sparseTensor != null && + Object.hasOwnProperty.call(message, "sparseTensor") + ) $root.onnx.SparseTensorProto.encode( message.sparseTensor, writer.uint32(/* id 22, wireType 2 =*/ 178).fork(), @@ -393,7 +421,10 @@ $root.onnx = (function () { break; } case 22: { - message.sparseTensor = $root.onnx.SparseTensorProto.decode(reader, reader.uint32()); + message.sparseTensor = $root.onnx.SparseTensorProto.decode( + reader, + reader.uint32(), + ); break; } case 14: { @@ -417,28 +448,40 @@ $root.onnx = (function () { break; } case 9: { - if (!(message.strings && message.strings.length)) message.strings = []; + if (!(message.strings && message.strings.length)) + message.strings = []; message.strings.push(reader.bytes()); break; } case 10: { - if (!(message.tensors && message.tensors.length)) message.tensors = []; - message.tensors.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + if (!(message.tensors && message.tensors.length)) + message.tensors = []; + message.tensors.push( + $root.onnx.TensorProto.decode(reader, reader.uint32()), + ); break; } case 11: { if (!(message.graphs && message.graphs.length)) message.graphs = []; - message.graphs.push($root.onnx.GraphProto.decode(reader, reader.uint32())); + message.graphs.push( + $root.onnx.GraphProto.decode(reader, reader.uint32()), + ); break; } case 23: { - if (!(message.sparseTensors && message.sparseTensors.length)) message.sparseTensors = []; - message.sparseTensors.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + if (!(message.sparseTensors && message.sparseTensors.length)) + message.sparseTensors = []; + message.sparseTensors.push( + $root.onnx.SparseTensorProto.decode(reader, reader.uint32()), + ); break; } case 15: { - if (!(message.typeProtos && message.typeProtos.length)) message.typeProtos = []; - message.typeProtos.push($root.onnx.TypeProto.decode(reader, reader.uint32())); + if (!(message.typeProtos && message.typeProtos.length)) + message.typeProtos = []; + message.typeProtos.push( + $root.onnx.TypeProto.decode(reader, reader.uint32()), + ); break; } default: @@ -473,17 +516,20 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ AttributeProto.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.name != null && message.hasOwnProperty('name')) - if (!$util.isString(message.name)) return 'name: string expected'; - if (message.refAttrName != null && message.hasOwnProperty('refAttrName')) - if (!$util.isString(message.refAttrName)) return 'refAttrName: string expected'; - if (message.docString != null && message.hasOwnProperty('docString')) - if (!$util.isString(message.docString)) return 'docString: string expected'; - if (message.type != null && message.hasOwnProperty('type')) + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) return "name: string expected"; + if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) + if (!$util.isString(message.refAttrName)) + return "refAttrName: string expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.type != null && message.hasOwnProperty("type")) switch (message.type) { default: - return 'type: enum value expected'; + return "type: enum value expected"; case 0: case 1: case 2: @@ -501,84 +547,109 @@ $root.onnx = (function () { case 14: break; } - if (message.f != null && message.hasOwnProperty('f')) - if (typeof message.f !== 'number') return 'f: number expected'; - if (message.i != null && message.hasOwnProperty('i')) + if (message.f != null && message.hasOwnProperty("f")) + if (typeof message.f !== "number") return "f: number expected"; + if (message.i != null && message.hasOwnProperty("i")) if ( !$util.isInteger(message.i) && - !(message.i && $util.isInteger(message.i.low) && $util.isInteger(message.i.high)) + !( + message.i && + $util.isInteger(message.i.low) && + $util.isInteger(message.i.high) + ) ) - return 'i: integer|Long expected'; - if (message.s != null && message.hasOwnProperty('s')) - if (!((message.s && typeof message.s.length === 'number') || $util.isString(message.s))) - return 's: buffer expected'; - if (message.t != null && message.hasOwnProperty('t')) { + return "i: integer|Long expected"; + if (message.s != null && message.hasOwnProperty("s")) + if ( + !( + (message.s && typeof message.s.length === "number") || + $util.isString(message.s) + ) + ) + return "s: buffer expected"; + if (message.t != null && message.hasOwnProperty("t")) { var error = $root.onnx.TensorProto.verify(message.t); - if (error) return 't.' + error; + if (error) return "t." + error; } - if (message.g != null && message.hasOwnProperty('g')) { + if (message.g != null && message.hasOwnProperty("g")) { var error = $root.onnx.GraphProto.verify(message.g); - if (error) return 'g.' + error; + if (error) return "g." + error; } - if (message.sparseTensor != null && message.hasOwnProperty('sparseTensor')) { + if ( + message.sparseTensor != null && + message.hasOwnProperty("sparseTensor") + ) { var error = $root.onnx.SparseTensorProto.verify(message.sparseTensor); - if (error) return 'sparseTensor.' + error; + if (error) return "sparseTensor." + error; } - if (message.tp != null && message.hasOwnProperty('tp')) { + if (message.tp != null && message.hasOwnProperty("tp")) { var error = $root.onnx.TypeProto.verify(message.tp); - if (error) return 'tp.' + error; + if (error) return "tp." + error; } - if (message.floats != null && message.hasOwnProperty('floats')) { - if (!Array.isArray(message.floats)) return 'floats: array expected'; + if (message.floats != null && message.hasOwnProperty("floats")) { + if (!Array.isArray(message.floats)) return "floats: array expected"; for (var i = 0; i < message.floats.length; ++i) - if (typeof message.floats[i] !== 'number') return 'floats: number[] expected'; + if (typeof message.floats[i] !== "number") + return "floats: number[] expected"; } - if (message.ints != null && message.hasOwnProperty('ints')) { - if (!Array.isArray(message.ints)) return 'ints: array expected'; + if (message.ints != null && message.hasOwnProperty("ints")) { + if (!Array.isArray(message.ints)) return "ints: array expected"; for (var i = 0; i < message.ints.length; ++i) if ( !$util.isInteger(message.ints[i]) && - !(message.ints[i] && $util.isInteger(message.ints[i].low) && $util.isInteger(message.ints[i].high)) + !( + message.ints[i] && + $util.isInteger(message.ints[i].low) && + $util.isInteger(message.ints[i].high) + ) ) - return 'ints: integer|Long[] expected'; + return "ints: integer|Long[] expected"; } - if (message.strings != null && message.hasOwnProperty('strings')) { - if (!Array.isArray(message.strings)) return 'strings: array expected'; + if (message.strings != null && message.hasOwnProperty("strings")) { + if (!Array.isArray(message.strings)) return "strings: array expected"; for (var i = 0; i < message.strings.length; ++i) if ( !( - (message.strings[i] && typeof message.strings[i].length === 'number') || + (message.strings[i] && + typeof message.strings[i].length === "number") || $util.isString(message.strings[i]) ) ) - return 'strings: buffer[] expected'; + return "strings: buffer[] expected"; } - if (message.tensors != null && message.hasOwnProperty('tensors')) { - if (!Array.isArray(message.tensors)) return 'tensors: array expected'; + if (message.tensors != null && message.hasOwnProperty("tensors")) { + if (!Array.isArray(message.tensors)) return "tensors: array expected"; for (var i = 0; i < message.tensors.length; ++i) { var error = $root.onnx.TensorProto.verify(message.tensors[i]); - if (error) return 'tensors.' + error; + if (error) return "tensors." + error; } } - if (message.graphs != null && message.hasOwnProperty('graphs')) { - if (!Array.isArray(message.graphs)) return 'graphs: array expected'; + if (message.graphs != null && message.hasOwnProperty("graphs")) { + if (!Array.isArray(message.graphs)) return "graphs: array expected"; for (var i = 0; i < message.graphs.length; ++i) { var error = $root.onnx.GraphProto.verify(message.graphs[i]); - if (error) return 'graphs.' + error; + if (error) return "graphs." + error; } } - if (message.sparseTensors != null && message.hasOwnProperty('sparseTensors')) { - if (!Array.isArray(message.sparseTensors)) return 'sparseTensors: array expected'; + if ( + message.sparseTensors != null && + message.hasOwnProperty("sparseTensors") + ) { + if (!Array.isArray(message.sparseTensors)) + return "sparseTensors: array expected"; for (var i = 0; i < message.sparseTensors.length; ++i) { - var error = $root.onnx.SparseTensorProto.verify(message.sparseTensors[i]); - if (error) return 'sparseTensors.' + error; + var error = $root.onnx.SparseTensorProto.verify( + message.sparseTensors[i], + ); + if (error) return "sparseTensors." + error; } } - if (message.typeProtos != null && message.hasOwnProperty('typeProtos')) { - if (!Array.isArray(message.typeProtos)) return 'typeProtos: array expected'; + if (message.typeProtos != null && message.hasOwnProperty("typeProtos")) { + if (!Array.isArray(message.typeProtos)) + return "typeProtos: array expected"; for (var i = 0; i < message.typeProtos.length; ++i) { var error = $root.onnx.TypeProto.verify(message.typeProtos[i]); - if (error) return 'typeProtos.' + error; + if (error) return "typeProtos." + error; } } return null; @@ -596,163 +667,209 @@ $root.onnx = (function () { if (object instanceof $root.onnx.AttributeProto) return object; var message = new $root.onnx.AttributeProto(); if (object.name != null) message.name = String(object.name); - if (object.refAttrName != null) message.refAttrName = String(object.refAttrName); - if (object.docString != null) message.docString = String(object.docString); + if (object.refAttrName != null) + message.refAttrName = String(object.refAttrName); + if (object.docString != null) + message.docString = String(object.docString); switch (object.type) { default: - if (typeof object.type === 'number') { + if (typeof object.type === "number") { message.type = object.type; break; } break; - case 'UNDEFINED': + case "UNDEFINED": case 0: message.type = 0; break; - case 'FLOAT': + case "FLOAT": case 1: message.type = 1; break; - case 'INT': + case "INT": case 2: message.type = 2; break; - case 'STRING': + case "STRING": case 3: message.type = 3; break; - case 'TENSOR': + case "TENSOR": case 4: message.type = 4; break; - case 'GRAPH': + case "GRAPH": case 5: message.type = 5; break; - case 'SPARSE_TENSOR': + case "SPARSE_TENSOR": case 11: message.type = 11; break; - case 'TYPE_PROTO': + case "TYPE_PROTO": case 13: message.type = 13; break; - case 'FLOATS': + case "FLOATS": case 6: message.type = 6; break; - case 'INTS': + case "INTS": case 7: message.type = 7; break; - case 'STRINGS': + case "STRINGS": case 8: message.type = 8; break; - case 'TENSORS': + case "TENSORS": case 9: message.type = 9; break; - case 'GRAPHS': + case "GRAPHS": case 10: message.type = 10; break; - case 'SPARSE_TENSORS': + case "SPARSE_TENSORS": case 12: message.type = 12; break; - case 'TYPE_PROTOS': + case "TYPE_PROTOS": case 14: message.type = 14; break; } if (object.f != null) message.f = Number(object.f); if (object.i != null) - if ($util.Long) (message.i = $util.Long.fromValue(object.i)).unsigned = false; - else if (typeof object.i === 'string') message.i = parseInt(object.i, 10); - else if (typeof object.i === 'number') message.i = object.i; - else if (typeof object.i === 'object') - message.i = new $util.LongBits(object.i.low >>> 0, object.i.high >>> 0).toNumber(); + if ($util.Long) + (message.i = $util.Long.fromValue(object.i)).unsigned = false; + else if (typeof object.i === "string") + message.i = parseInt(object.i, 10); + else if (typeof object.i === "number") message.i = object.i; + else if (typeof object.i === "object") + message.i = new $util.LongBits( + object.i.low >>> 0, + object.i.high >>> 0, + ).toNumber(); if (object.s != null) - if (typeof object.s === 'string') - $util.base64.decode(object.s, (message.s = $util.newBuffer($util.base64.length(object.s))), 0); + if (typeof object.s === "string") + $util.base64.decode( + object.s, + (message.s = $util.newBuffer($util.base64.length(object.s))), + 0, + ); else if (object.s.length >= 0) message.s = object.s; if (object.t != null) { - if (typeof object.t !== 'object') throw TypeError('.onnx.AttributeProto.t: object expected'); + if (typeof object.t !== "object") + throw TypeError(".onnx.AttributeProto.t: object expected"); message.t = $root.onnx.TensorProto.fromObject(object.t); } if (object.g != null) { - if (typeof object.g !== 'object') throw TypeError('.onnx.AttributeProto.g: object expected'); + if (typeof object.g !== "object") + throw TypeError(".onnx.AttributeProto.g: object expected"); message.g = $root.onnx.GraphProto.fromObject(object.g); } if (object.sparseTensor != null) { - if (typeof object.sparseTensor !== 'object') - throw TypeError('.onnx.AttributeProto.sparseTensor: object expected'); - message.sparseTensor = $root.onnx.SparseTensorProto.fromObject(object.sparseTensor); + if (typeof object.sparseTensor !== "object") + throw TypeError(".onnx.AttributeProto.sparseTensor: object expected"); + message.sparseTensor = $root.onnx.SparseTensorProto.fromObject( + object.sparseTensor, + ); } if (object.tp != null) { - if (typeof object.tp !== 'object') throw TypeError('.onnx.AttributeProto.tp: object expected'); + if (typeof object.tp !== "object") + throw TypeError(".onnx.AttributeProto.tp: object expected"); message.tp = $root.onnx.TypeProto.fromObject(object.tp); } if (object.floats) { - if (!Array.isArray(object.floats)) throw TypeError('.onnx.AttributeProto.floats: array expected'); + if (!Array.isArray(object.floats)) + throw TypeError(".onnx.AttributeProto.floats: array expected"); message.floats = []; - for (var i = 0; i < object.floats.length; ++i) message.floats[i] = Number(object.floats[i]); + for (var i = 0; i < object.floats.length; ++i) + message.floats[i] = Number(object.floats[i]); } if (object.ints) { - if (!Array.isArray(object.ints)) throw TypeError('.onnx.AttributeProto.ints: array expected'); + if (!Array.isArray(object.ints)) + throw TypeError(".onnx.AttributeProto.ints: array expected"); message.ints = []; for (var i = 0; i < object.ints.length; ++i) - if ($util.Long) (message.ints[i] = $util.Long.fromValue(object.ints[i])).unsigned = false; - else if (typeof object.ints[i] === 'string') message.ints[i] = parseInt(object.ints[i], 10); - else if (typeof object.ints[i] === 'number') message.ints[i] = object.ints[i]; - else if (typeof object.ints[i] === 'object') - message.ints[i] = new $util.LongBits(object.ints[i].low >>> 0, object.ints[i].high >>> 0).toNumber(); + if ($util.Long) + (message.ints[i] = $util.Long.fromValue(object.ints[i])).unsigned = + false; + else if (typeof object.ints[i] === "string") + message.ints[i] = parseInt(object.ints[i], 10); + else if (typeof object.ints[i] === "number") + message.ints[i] = object.ints[i]; + else if (typeof object.ints[i] === "object") + message.ints[i] = new $util.LongBits( + object.ints[i].low >>> 0, + object.ints[i].high >>> 0, + ).toNumber(); } if (object.strings) { - if (!Array.isArray(object.strings)) throw TypeError('.onnx.AttributeProto.strings: array expected'); + if (!Array.isArray(object.strings)) + throw TypeError(".onnx.AttributeProto.strings: array expected"); message.strings = []; for (var i = 0; i < object.strings.length; ++i) - if (typeof object.strings[i] === 'string') + if (typeof object.strings[i] === "string") $util.base64.decode( object.strings[i], - (message.strings[i] = $util.newBuffer($util.base64.length(object.strings[i]))), + (message.strings[i] = $util.newBuffer( + $util.base64.length(object.strings[i]), + )), 0, ); - else if (object.strings[i].length >= 0) message.strings[i] = object.strings[i]; + else if (object.strings[i].length >= 0) + message.strings[i] = object.strings[i]; } if (object.tensors) { - if (!Array.isArray(object.tensors)) throw TypeError('.onnx.AttributeProto.tensors: array expected'); + if (!Array.isArray(object.tensors)) + throw TypeError(".onnx.AttributeProto.tensors: array expected"); message.tensors = []; for (var i = 0; i < object.tensors.length; ++i) { - if (typeof object.tensors[i] !== 'object') throw TypeError('.onnx.AttributeProto.tensors: object expected'); - message.tensors[i] = $root.onnx.TensorProto.fromObject(object.tensors[i]); + if (typeof object.tensors[i] !== "object") + throw TypeError(".onnx.AttributeProto.tensors: object expected"); + message.tensors[i] = $root.onnx.TensorProto.fromObject( + object.tensors[i], + ); } } if (object.graphs) { - if (!Array.isArray(object.graphs)) throw TypeError('.onnx.AttributeProto.graphs: array expected'); + if (!Array.isArray(object.graphs)) + throw TypeError(".onnx.AttributeProto.graphs: array expected"); message.graphs = []; for (var i = 0; i < object.graphs.length; ++i) { - if (typeof object.graphs[i] !== 'object') throw TypeError('.onnx.AttributeProto.graphs: object expected'); - message.graphs[i] = $root.onnx.GraphProto.fromObject(object.graphs[i]); + if (typeof object.graphs[i] !== "object") + throw TypeError(".onnx.AttributeProto.graphs: object expected"); + message.graphs[i] = $root.onnx.GraphProto.fromObject( + object.graphs[i], + ); } } if (object.sparseTensors) { - if (!Array.isArray(object.sparseTensors)) throw TypeError('.onnx.AttributeProto.sparseTensors: array expected'); + if (!Array.isArray(object.sparseTensors)) + throw TypeError(".onnx.AttributeProto.sparseTensors: array expected"); message.sparseTensors = []; for (var i = 0; i < object.sparseTensors.length; ++i) { - if (typeof object.sparseTensors[i] !== 'object') - throw TypeError('.onnx.AttributeProto.sparseTensors: object expected'); - message.sparseTensors[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseTensors[i]); + if (typeof object.sparseTensors[i] !== "object") + throw TypeError( + ".onnx.AttributeProto.sparseTensors: object expected", + ); + message.sparseTensors[i] = $root.onnx.SparseTensorProto.fromObject( + object.sparseTensors[i], + ); } } if (object.typeProtos) { - if (!Array.isArray(object.typeProtos)) throw TypeError('.onnx.AttributeProto.typeProtos: array expected'); + if (!Array.isArray(object.typeProtos)) + throw TypeError(".onnx.AttributeProto.typeProtos: array expected"); message.typeProtos = []; for (var i = 0; i < object.typeProtos.length; ++i) { - if (typeof object.typeProtos[i] !== 'object') - throw TypeError('.onnx.AttributeProto.typeProtos: object expected'); - message.typeProtos[i] = $root.onnx.TypeProto.fromObject(object.typeProtos[i]); + if (typeof object.typeProtos[i] !== "object") + throw TypeError(".onnx.AttributeProto.typeProtos: object expected"); + message.typeProtos[i] = $root.onnx.TypeProto.fromObject( + object.typeProtos[i], + ); } } return message; @@ -780,65 +897,84 @@ $root.onnx = (function () { object.sparseTensors = []; } if (options.defaults) { - object.name = ''; + object.name = ""; object.f = 0; if ($util.Long) { var long = new $util.Long(0, 0, false); - object.i = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else object.i = options.longs === String ? '0' : 0; - if (options.bytes === String) object.s = ''; + object.i = + options.longs === String + ? long.toString() + : options.longs === Number + ? long.toNumber() + : long; + } else object.i = options.longs === String ? "0" : 0; + if (options.bytes === String) object.s = ""; else { object.s = []; if (options.bytes !== Array) object.s = $util.newBuffer(object.s); } object.t = null; object.g = null; - object.docString = ''; + object.docString = ""; object.tp = null; - object.type = options.enums === String ? 'UNDEFINED' : 0; - object.refAttrName = ''; + object.type = options.enums === String ? "UNDEFINED" : 0; + object.refAttrName = ""; object.sparseTensor = null; } - if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; - if (message.f != null && message.hasOwnProperty('f')) - object.f = options.json && !isFinite(message.f) ? String(message.f) : message.f; - if (message.i != null && message.hasOwnProperty('i')) - if (typeof message.i === 'number') object.i = options.longs === String ? String(message.i) : message.i; + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.f != null && message.hasOwnProperty("f")) + object.f = + options.json && !isFinite(message.f) ? String(message.f) : message.f; + if (message.i != null && message.hasOwnProperty("i")) + if (typeof message.i === "number") + object.i = options.longs === String ? String(message.i) : message.i; else object.i = options.longs === String ? $util.Long.prototype.toString.call(message.i) : options.longs === Number - ? new $util.LongBits(message.i.low >>> 0, message.i.high >>> 0).toNumber() + ? new $util.LongBits( + message.i.low >>> 0, + message.i.high >>> 0, + ).toNumber() : message.i; - if (message.s != null && message.hasOwnProperty('s')) + if (message.s != null && message.hasOwnProperty("s")) object.s = options.bytes === String ? $util.base64.encode(message.s, 0, message.s.length) : options.bytes === Array ? Array.prototype.slice.call(message.s) : message.s; - if (message.t != null && message.hasOwnProperty('t')) + if (message.t != null && message.hasOwnProperty("t")) object.t = $root.onnx.TensorProto.toObject(message.t, options); - if (message.g != null && message.hasOwnProperty('g')) + if (message.g != null && message.hasOwnProperty("g")) object.g = $root.onnx.GraphProto.toObject(message.g, options); if (message.floats && message.floats.length) { object.floats = []; for (var j = 0; j < message.floats.length; ++j) object.floats[j] = - options.json && !isFinite(message.floats[j]) ? String(message.floats[j]) : message.floats[j]; + options.json && !isFinite(message.floats[j]) + ? String(message.floats[j]) + : message.floats[j]; } if (message.ints && message.ints.length) { object.ints = []; for (var j = 0; j < message.ints.length; ++j) - if (typeof message.ints[j] === 'number') - object.ints[j] = options.longs === String ? String(message.ints[j]) : message.ints[j]; + if (typeof message.ints[j] === "number") + object.ints[j] = + options.longs === String + ? String(message.ints[j]) + : message.ints[j]; else object.ints[j] = options.longs === String ? $util.Long.prototype.toString.call(message.ints[j]) : options.longs === Number - ? new $util.LongBits(message.ints[j].low >>> 0, message.ints[j].high >>> 0).toNumber() + ? new $util.LongBits( + message.ints[j].low >>> 0, + message.ints[j].high >>> 0, + ).toNumber() : message.ints[j]; } if (message.strings && message.strings.length) { @@ -846,7 +982,11 @@ $root.onnx = (function () { for (var j = 0; j < message.strings.length; ++j) object.strings[j] = options.bytes === String - ? $util.base64.encode(message.strings[j], 0, message.strings[j].length) + ? $util.base64.encode( + message.strings[j], + 0, + message.strings[j].length, + ) : options.bytes === Array ? Array.prototype.slice.call(message.strings[j]) : message.strings[j]; @@ -854,36 +994,56 @@ $root.onnx = (function () { if (message.tensors && message.tensors.length) { object.tensors = []; for (var j = 0; j < message.tensors.length; ++j) - object.tensors[j] = $root.onnx.TensorProto.toObject(message.tensors[j], options); + object.tensors[j] = $root.onnx.TensorProto.toObject( + message.tensors[j], + options, + ); } if (message.graphs && message.graphs.length) { object.graphs = []; for (var j = 0; j < message.graphs.length; ++j) - object.graphs[j] = $root.onnx.GraphProto.toObject(message.graphs[j], options); + object.graphs[j] = $root.onnx.GraphProto.toObject( + message.graphs[j], + options, + ); } - if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; - if (message.tp != null && message.hasOwnProperty('tp')) + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.tp != null && message.hasOwnProperty("tp")) object.tp = $root.onnx.TypeProto.toObject(message.tp, options); if (message.typeProtos && message.typeProtos.length) { object.typeProtos = []; for (var j = 0; j < message.typeProtos.length; ++j) - object.typeProtos[j] = $root.onnx.TypeProto.toObject(message.typeProtos[j], options); + object.typeProtos[j] = $root.onnx.TypeProto.toObject( + message.typeProtos[j], + options, + ); } - if (message.type != null && message.hasOwnProperty('type')) + if (message.type != null && message.hasOwnProperty("type")) object.type = options.enums === String - ? $root.onnx.AttributeProto.AttributeType[message.type] === undefined + ? $root.onnx.AttributeProto.AttributeType[message.type] === + undefined ? message.type : $root.onnx.AttributeProto.AttributeType[message.type] : message.type; - if (message.refAttrName != null && message.hasOwnProperty('refAttrName')) + if (message.refAttrName != null && message.hasOwnProperty("refAttrName")) object.refAttrName = message.refAttrName; - if (message.sparseTensor != null && message.hasOwnProperty('sparseTensor')) - object.sparseTensor = $root.onnx.SparseTensorProto.toObject(message.sparseTensor, options); + if ( + message.sparseTensor != null && + message.hasOwnProperty("sparseTensor") + ) + object.sparseTensor = $root.onnx.SparseTensorProto.toObject( + message.sparseTensor, + options, + ); if (message.sparseTensors && message.sparseTensors.length) { object.sparseTensors = []; for (var j = 0; j < message.sparseTensors.length; ++j) - object.sparseTensors[j] = $root.onnx.SparseTensorProto.toObject(message.sparseTensors[j], options); + object.sparseTensors[j] = $root.onnx.SparseTensorProto.toObject( + message.sparseTensors[j], + options, + ); } return object; }; @@ -909,9 +1069,9 @@ $root.onnx = (function () { */ AttributeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.AttributeProto'; + return typeUrlPrefix + "/onnx.AttributeProto"; }; /** @@ -937,21 +1097,21 @@ $root.onnx = (function () { AttributeProto.AttributeType = (function () { var valuesById = {}, values = Object.create(valuesById); - values[(valuesById[0] = 'UNDEFINED')] = 0; - values[(valuesById[1] = 'FLOAT')] = 1; - values[(valuesById[2] = 'INT')] = 2; - values[(valuesById[3] = 'STRING')] = 3; - values[(valuesById[4] = 'TENSOR')] = 4; - values[(valuesById[5] = 'GRAPH')] = 5; - values[(valuesById[11] = 'SPARSE_TENSOR')] = 11; - values[(valuesById[13] = 'TYPE_PROTO')] = 13; - values[(valuesById[6] = 'FLOATS')] = 6; - values[(valuesById[7] = 'INTS')] = 7; - values[(valuesById[8] = 'STRINGS')] = 8; - values[(valuesById[9] = 'TENSORS')] = 9; - values[(valuesById[10] = 'GRAPHS')] = 10; - values[(valuesById[12] = 'SPARSE_TENSORS')] = 12; - values[(valuesById[14] = 'TYPE_PROTOS')] = 14; + values[(valuesById[0] = "UNDEFINED")] = 0; + values[(valuesById[1] = "FLOAT")] = 1; + values[(valuesById[2] = "INT")] = 2; + values[(valuesById[3] = "STRING")] = 3; + values[(valuesById[4] = "TENSOR")] = 4; + values[(valuesById[5] = "GRAPH")] = 5; + values[(valuesById[11] = "SPARSE_TENSOR")] = 11; + values[(valuesById[13] = "TYPE_PROTO")] = 13; + values[(valuesById[6] = "FLOATS")] = 6; + values[(valuesById[7] = "INTS")] = 7; + values[(valuesById[8] = "STRINGS")] = 8; + values[(valuesById[9] = "TENSORS")] = 9; + values[(valuesById[10] = "GRAPHS")] = 10; + values[(valuesById[12] = "SPARSE_TENSORS")] = 12; + values[(valuesById[14] = "TYPE_PROTOS")] = 14; return values; })(); @@ -988,7 +1148,7 @@ $root.onnx = (function () { * @memberof onnx.ValueInfoProto * @instance */ - ValueInfoProto.prototype.name = ''; + ValueInfoProto.prototype.name = ""; /** * ValueInfoProto type. @@ -1004,7 +1164,7 @@ $root.onnx = (function () { * @memberof onnx.ValueInfoProto * @instance */ - ValueInfoProto.prototype.docString = ''; + ValueInfoProto.prototype.docString = ""; /** * Creates a new ValueInfoProto instance using the specified properties. @@ -1029,11 +1189,17 @@ $root.onnx = (function () { */ ValueInfoProto.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + if (message.name != null && Object.hasOwnProperty.call(message, "name")) writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.name); - if (message.type != null && Object.hasOwnProperty.call(message, 'type')) - $root.onnx.TypeProto.encode(message.type, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + if (message.type != null && Object.hasOwnProperty.call(message, "type")) + $root.onnx.TypeProto.encode( + message.type, + writer.uint32(/* id 2, wireType 2 =*/ 18).fork(), + ).ldelim(); + if ( + message.docString != null && + Object.hasOwnProperty.call(message, "docString") + ) writer.uint32(/* id 3, wireType 2 =*/ 26).string(message.docString); return writer; }; @@ -1113,15 +1279,17 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ ValueInfoProto.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.name != null && message.hasOwnProperty('name')) - if (!$util.isString(message.name)) return 'name: string expected'; - if (message.type != null && message.hasOwnProperty('type')) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) return "name: string expected"; + if (message.type != null && message.hasOwnProperty("type")) { var error = $root.onnx.TypeProto.verify(message.type); - if (error) return 'type.' + error; + if (error) return "type." + error; } - if (message.docString != null && message.hasOwnProperty('docString')) - if (!$util.isString(message.docString)) return 'docString: string expected'; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; return null; }; @@ -1138,10 +1306,12 @@ $root.onnx = (function () { var message = new $root.onnx.ValueInfoProto(); if (object.name != null) message.name = String(object.name); if (object.type != null) { - if (typeof object.type !== 'object') throw TypeError('.onnx.ValueInfoProto.type: object expected'); + if (typeof object.type !== "object") + throw TypeError(".onnx.ValueInfoProto.type: object expected"); message.type = $root.onnx.TypeProto.fromObject(object.type); } - if (object.docString != null) message.docString = String(object.docString); + if (object.docString != null) + message.docString = String(object.docString); return message; }; @@ -1158,14 +1328,16 @@ $root.onnx = (function () { if (!options) options = {}; var object = {}; if (options.defaults) { - object.name = ''; + object.name = ""; object.type = null; - object.docString = ''; + object.docString = ""; } - if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; - if (message.type != null && message.hasOwnProperty('type')) + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.type != null && message.hasOwnProperty("type")) object.type = $root.onnx.TypeProto.toObject(message.type, options); - if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; return object; }; @@ -1190,9 +1362,9 @@ $root.onnx = (function () { */ ValueInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.ValueInfoProto'; + return typeUrlPrefix + "/onnx.ValueInfoProto"; }; return ValueInfoProto; @@ -1251,7 +1423,7 @@ $root.onnx = (function () { * @memberof onnx.NodeProto * @instance */ - NodeProto.prototype.name = ''; + NodeProto.prototype.name = ""; /** * NodeProto opType. @@ -1259,7 +1431,7 @@ $root.onnx = (function () { * @memberof onnx.NodeProto * @instance */ - NodeProto.prototype.opType = ''; + NodeProto.prototype.opType = ""; /** * NodeProto domain. @@ -1267,7 +1439,7 @@ $root.onnx = (function () { * @memberof onnx.NodeProto * @instance */ - NodeProto.prototype.domain = ''; + NodeProto.prototype.domain = ""; /** * NodeProto attribute. @@ -1283,7 +1455,7 @@ $root.onnx = (function () { * @memberof onnx.NodeProto * @instance */ - NodeProto.prototype.docString = ''; + NodeProto.prototype.docString = ""; /** * Creates a new NodeProto instance using the specified properties. @@ -1314,9 +1486,12 @@ $root.onnx = (function () { if (message.output != null && message.output.length) for (var i = 0; i < message.output.length; ++i) writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.output[i]); - if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + if (message.name != null && Object.hasOwnProperty.call(message, "name")) writer.uint32(/* id 3, wireType 2 =*/ 26).string(message.name); - if (message.opType != null && Object.hasOwnProperty.call(message, 'opType')) + if ( + message.opType != null && + Object.hasOwnProperty.call(message, "opType") + ) writer.uint32(/* id 4, wireType 2 =*/ 34).string(message.opType); if (message.attribute != null && message.attribute.length) for (var i = 0; i < message.attribute.length; ++i) @@ -1324,9 +1499,15 @@ $root.onnx = (function () { message.attribute[i], writer.uint32(/* id 5, wireType 2 =*/ 42).fork(), ).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + if ( + message.docString != null && + Object.hasOwnProperty.call(message, "docString") + ) writer.uint32(/* id 6, wireType 2 =*/ 50).string(message.docString); - if (message.domain != null && Object.hasOwnProperty.call(message, 'domain')) + if ( + message.domain != null && + Object.hasOwnProperty.call(message, "domain") + ) writer.uint32(/* id 7, wireType 2 =*/ 58).string(message.domain); return writer; }; @@ -1385,8 +1566,11 @@ $root.onnx = (function () { break; } case 5: { - if (!(message.attribute && message.attribute.length)) message.attribute = []; - message.attribute.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + if (!(message.attribute && message.attribute.length)) + message.attribute = []; + message.attribute.push( + $root.onnx.AttributeProto.decode(reader, reader.uint32()), + ); break; } case 6: { @@ -1425,32 +1609,37 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ NodeProto.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.input != null && message.hasOwnProperty('input')) { - if (!Array.isArray(message.input)) return 'input: array expected'; + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) return "input: array expected"; for (var i = 0; i < message.input.length; ++i) - if (!$util.isString(message.input[i])) return 'input: string[] expected'; + if (!$util.isString(message.input[i])) + return "input: string[] expected"; } - if (message.output != null && message.hasOwnProperty('output')) { - if (!Array.isArray(message.output)) return 'output: array expected'; + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) return "output: array expected"; for (var i = 0; i < message.output.length; ++i) - if (!$util.isString(message.output[i])) return 'output: string[] expected'; - } - if (message.name != null && message.hasOwnProperty('name')) - if (!$util.isString(message.name)) return 'name: string expected'; - if (message.opType != null && message.hasOwnProperty('opType')) - if (!$util.isString(message.opType)) return 'opType: string expected'; - if (message.domain != null && message.hasOwnProperty('domain')) - if (!$util.isString(message.domain)) return 'domain: string expected'; - if (message.attribute != null && message.hasOwnProperty('attribute')) { - if (!Array.isArray(message.attribute)) return 'attribute: array expected'; + if (!$util.isString(message.output[i])) + return "output: string[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) return "name: string expected"; + if (message.opType != null && message.hasOwnProperty("opType")) + if (!$util.isString(message.opType)) return "opType: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) return "domain: string expected"; + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; for (var i = 0; i < message.attribute.length; ++i) { var error = $root.onnx.AttributeProto.verify(message.attribute[i]); - if (error) return 'attribute.' + error; + if (error) return "attribute." + error; } } - if (message.docString != null && message.hasOwnProperty('docString')) - if (!$util.isString(message.docString)) return 'docString: string expected'; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; return null; }; @@ -1466,27 +1655,36 @@ $root.onnx = (function () { if (object instanceof $root.onnx.NodeProto) return object; var message = new $root.onnx.NodeProto(); if (object.input) { - if (!Array.isArray(object.input)) throw TypeError('.onnx.NodeProto.input: array expected'); + if (!Array.isArray(object.input)) + throw TypeError(".onnx.NodeProto.input: array expected"); message.input = []; - for (var i = 0; i < object.input.length; ++i) message.input[i] = String(object.input[i]); + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); } if (object.output) { - if (!Array.isArray(object.output)) throw TypeError('.onnx.NodeProto.output: array expected'); + if (!Array.isArray(object.output)) + throw TypeError(".onnx.NodeProto.output: array expected"); message.output = []; - for (var i = 0; i < object.output.length; ++i) message.output[i] = String(object.output[i]); + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); } if (object.name != null) message.name = String(object.name); if (object.opType != null) message.opType = String(object.opType); if (object.domain != null) message.domain = String(object.domain); if (object.attribute) { - if (!Array.isArray(object.attribute)) throw TypeError('.onnx.NodeProto.attribute: array expected'); + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.NodeProto.attribute: array expected"); message.attribute = []; for (var i = 0; i < object.attribute.length; ++i) { - if (typeof object.attribute[i] !== 'object') throw TypeError('.onnx.NodeProto.attribute: object expected'); - message.attribute[i] = $root.onnx.AttributeProto.fromObject(object.attribute[i]); + if (typeof object.attribute[i] !== "object") + throw TypeError(".onnx.NodeProto.attribute: object expected"); + message.attribute[i] = $root.onnx.AttributeProto.fromObject( + object.attribute[i], + ); } } - if (object.docString != null) message.docString = String(object.docString); + if (object.docString != null) + message.docString = String(object.docString); return message; }; @@ -1508,28 +1706,37 @@ $root.onnx = (function () { object.attribute = []; } if (options.defaults) { - object.name = ''; - object.opType = ''; - object.docString = ''; - object.domain = ''; + object.name = ""; + object.opType = ""; + object.docString = ""; + object.domain = ""; } if (message.input && message.input.length) { object.input = []; - for (var j = 0; j < message.input.length; ++j) object.input[j] = message.input[j]; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; } if (message.output && message.output.length) { object.output = []; - for (var j = 0; j < message.output.length; ++j) object.output[j] = message.output[j]; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; } - if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; - if (message.opType != null && message.hasOwnProperty('opType')) object.opType = message.opType; + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.opType != null && message.hasOwnProperty("opType")) + object.opType = message.opType; if (message.attribute && message.attribute.length) { object.attribute = []; for (var j = 0; j < message.attribute.length; ++j) - object.attribute[j] = $root.onnx.AttributeProto.toObject(message.attribute[j], options); + object.attribute[j] = $root.onnx.AttributeProto.toObject( + message.attribute[j], + options, + ); } - if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; - if (message.domain != null && message.hasOwnProperty('domain')) object.domain = message.domain; + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; return object; }; @@ -1554,9 +1761,9 @@ $root.onnx = (function () { */ NodeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.NodeProto'; + return typeUrlPrefix + "/onnx.NodeProto"; }; return NodeProto; @@ -1644,11 +1851,26 @@ $root.onnx = (function () { */ TrainingInfoProto.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.initialization != null && Object.hasOwnProperty.call(message, 'initialization')) - $root.onnx.GraphProto.encode(message.initialization, writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); - if (message.algorithm != null && Object.hasOwnProperty.call(message, 'algorithm')) - $root.onnx.GraphProto.encode(message.algorithm, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); - if (message.initializationBinding != null && message.initializationBinding.length) + if ( + message.initialization != null && + Object.hasOwnProperty.call(message, "initialization") + ) + $root.onnx.GraphProto.encode( + message.initialization, + writer.uint32(/* id 1, wireType 2 =*/ 10).fork(), + ).ldelim(); + if ( + message.algorithm != null && + Object.hasOwnProperty.call(message, "algorithm") + ) + $root.onnx.GraphProto.encode( + message.algorithm, + writer.uint32(/* id 2, wireType 2 =*/ 18).fork(), + ).ldelim(); + if ( + message.initializationBinding != null && + message.initializationBinding.length + ) for (var i = 0; i < message.initializationBinding.length; ++i) $root.onnx.StringStringEntryProto.encode( message.initializationBinding[i], @@ -1672,7 +1894,10 @@ $root.onnx = (function () { * @param {$protobuf.Writer} [writer] Writer to encode to * @returns {$protobuf.Writer} Writer */ - TrainingInfoProto.encodeDelimited = function encodeDelimited(message, writer) { + TrainingInfoProto.encodeDelimited = function encodeDelimited( + message, + writer, + ) { return this.encode(message, writer).ldelim(); }; @@ -1695,22 +1920,38 @@ $root.onnx = (function () { var tag = reader.uint32(); switch (tag >>> 3) { case 1: { - message.initialization = $root.onnx.GraphProto.decode(reader, reader.uint32()); + message.initialization = $root.onnx.GraphProto.decode( + reader, + reader.uint32(), + ); break; } case 2: { - message.algorithm = $root.onnx.GraphProto.decode(reader, reader.uint32()); + message.algorithm = $root.onnx.GraphProto.decode( + reader, + reader.uint32(), + ); break; } case 3: { - if (!(message.initializationBinding && message.initializationBinding.length)) + if ( + !( + message.initializationBinding && + message.initializationBinding.length + ) + ) message.initializationBinding = []; - message.initializationBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + message.initializationBinding.push( + $root.onnx.StringStringEntryProto.decode(reader, reader.uint32()), + ); break; } case 4: { - if (!(message.updateBinding && message.updateBinding.length)) message.updateBinding = []; - message.updateBinding.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + if (!(message.updateBinding && message.updateBinding.length)) + message.updateBinding = []; + message.updateBinding.push( + $root.onnx.StringStringEntryProto.decode(reader, reader.uint32()), + ); break; } default: @@ -1745,27 +1986,43 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ TrainingInfoProto.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.initialization != null && message.hasOwnProperty('initialization')) { + if (typeof message !== "object" || message === null) + return "object expected"; + if ( + message.initialization != null && + message.hasOwnProperty("initialization") + ) { var error = $root.onnx.GraphProto.verify(message.initialization); - if (error) return 'initialization.' + error; + if (error) return "initialization." + error; } - if (message.algorithm != null && message.hasOwnProperty('algorithm')) { + if (message.algorithm != null && message.hasOwnProperty("algorithm")) { var error = $root.onnx.GraphProto.verify(message.algorithm); - if (error) return 'algorithm.' + error; - } - if (message.initializationBinding != null && message.hasOwnProperty('initializationBinding')) { - if (!Array.isArray(message.initializationBinding)) return 'initializationBinding: array expected'; + if (error) return "algorithm." + error; + } + if ( + message.initializationBinding != null && + message.hasOwnProperty("initializationBinding") + ) { + if (!Array.isArray(message.initializationBinding)) + return "initializationBinding: array expected"; for (var i = 0; i < message.initializationBinding.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.initializationBinding[i]); - if (error) return 'initializationBinding.' + error; + var error = $root.onnx.StringStringEntryProto.verify( + message.initializationBinding[i], + ); + if (error) return "initializationBinding." + error; } } - if (message.updateBinding != null && message.hasOwnProperty('updateBinding')) { - if (!Array.isArray(message.updateBinding)) return 'updateBinding: array expected'; + if ( + message.updateBinding != null && + message.hasOwnProperty("updateBinding") + ) { + if (!Array.isArray(message.updateBinding)) + return "updateBinding: array expected"; for (var i = 0; i < message.updateBinding.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.updateBinding[i]); - if (error) return 'updateBinding.' + error; + var error = $root.onnx.StringStringEntryProto.verify( + message.updateBinding[i], + ); + if (error) return "updateBinding." + error; } } return null; @@ -1783,34 +2040,51 @@ $root.onnx = (function () { if (object instanceof $root.onnx.TrainingInfoProto) return object; var message = new $root.onnx.TrainingInfoProto(); if (object.initialization != null) { - if (typeof object.initialization !== 'object') - throw TypeError('.onnx.TrainingInfoProto.initialization: object expected'); - message.initialization = $root.onnx.GraphProto.fromObject(object.initialization); + if (typeof object.initialization !== "object") + throw TypeError( + ".onnx.TrainingInfoProto.initialization: object expected", + ); + message.initialization = $root.onnx.GraphProto.fromObject( + object.initialization, + ); } if (object.algorithm != null) { - if (typeof object.algorithm !== 'object') throw TypeError('.onnx.TrainingInfoProto.algorithm: object expected'); + if (typeof object.algorithm !== "object") + throw TypeError(".onnx.TrainingInfoProto.algorithm: object expected"); message.algorithm = $root.onnx.GraphProto.fromObject(object.algorithm); } if (object.initializationBinding) { if (!Array.isArray(object.initializationBinding)) - throw TypeError('.onnx.TrainingInfoProto.initializationBinding: array expected'); + throw TypeError( + ".onnx.TrainingInfoProto.initializationBinding: array expected", + ); message.initializationBinding = []; for (var i = 0; i < object.initializationBinding.length; ++i) { - if (typeof object.initializationBinding[i] !== 'object') - throw TypeError('.onnx.TrainingInfoProto.initializationBinding: object expected'); - message.initializationBinding[i] = $root.onnx.StringStringEntryProto.fromObject( - object.initializationBinding[i], - ); + if (typeof object.initializationBinding[i] !== "object") + throw TypeError( + ".onnx.TrainingInfoProto.initializationBinding: object expected", + ); + message.initializationBinding[i] = + $root.onnx.StringStringEntryProto.fromObject( + object.initializationBinding[i], + ); } } if (object.updateBinding) { if (!Array.isArray(object.updateBinding)) - throw TypeError('.onnx.TrainingInfoProto.updateBinding: array expected'); + throw TypeError( + ".onnx.TrainingInfoProto.updateBinding: array expected", + ); message.updateBinding = []; for (var i = 0; i < object.updateBinding.length; ++i) { - if (typeof object.updateBinding[i] !== 'object') - throw TypeError('.onnx.TrainingInfoProto.updateBinding: object expected'); - message.updateBinding[i] = $root.onnx.StringStringEntryProto.fromObject(object.updateBinding[i]); + if (typeof object.updateBinding[i] !== "object") + throw TypeError( + ".onnx.TrainingInfoProto.updateBinding: object expected", + ); + message.updateBinding[i] = + $root.onnx.StringStringEntryProto.fromObject( + object.updateBinding[i], + ); } } return message; @@ -1836,22 +2110,38 @@ $root.onnx = (function () { object.initialization = null; object.algorithm = null; } - if (message.initialization != null && message.hasOwnProperty('initialization')) - object.initialization = $root.onnx.GraphProto.toObject(message.initialization, options); - if (message.algorithm != null && message.hasOwnProperty('algorithm')) - object.algorithm = $root.onnx.GraphProto.toObject(message.algorithm, options); - if (message.initializationBinding && message.initializationBinding.length) { + if ( + message.initialization != null && + message.hasOwnProperty("initialization") + ) + object.initialization = $root.onnx.GraphProto.toObject( + message.initialization, + options, + ); + if (message.algorithm != null && message.hasOwnProperty("algorithm")) + object.algorithm = $root.onnx.GraphProto.toObject( + message.algorithm, + options, + ); + if ( + message.initializationBinding && + message.initializationBinding.length + ) { object.initializationBinding = []; for (var j = 0; j < message.initializationBinding.length; ++j) - object.initializationBinding[j] = $root.onnx.StringStringEntryProto.toObject( - message.initializationBinding[j], - options, - ); + object.initializationBinding[j] = + $root.onnx.StringStringEntryProto.toObject( + message.initializationBinding[j], + options, + ); } if (message.updateBinding && message.updateBinding.length) { object.updateBinding = []; for (var j = 0; j < message.updateBinding.length; ++j) - object.updateBinding[j] = $root.onnx.StringStringEntryProto.toObject(message.updateBinding[j], options); + object.updateBinding[j] = $root.onnx.StringStringEntryProto.toObject( + message.updateBinding[j], + options, + ); } return object; }; @@ -1877,9 +2167,9 @@ $root.onnx = (function () { */ TrainingInfoProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.TrainingInfoProto'; + return typeUrlPrefix + "/onnx.TrainingInfoProto"; }; return TrainingInfoProto; @@ -1927,7 +2217,9 @@ $root.onnx = (function () { * @memberof onnx.ModelProto * @instance */ - ModelProto.prototype.irVersion = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + ModelProto.prototype.irVersion = $util.Long + ? $util.Long.fromBits(0, 0, false) + : 0; /** * ModelProto opsetImport. @@ -1943,7 +2235,7 @@ $root.onnx = (function () { * @memberof onnx.ModelProto * @instance */ - ModelProto.prototype.producerName = ''; + ModelProto.prototype.producerName = ""; /** * ModelProto producerVersion. @@ -1951,7 +2243,7 @@ $root.onnx = (function () { * @memberof onnx.ModelProto * @instance */ - ModelProto.prototype.producerVersion = ''; + ModelProto.prototype.producerVersion = ""; /** * ModelProto domain. @@ -1959,7 +2251,7 @@ $root.onnx = (function () { * @memberof onnx.ModelProto * @instance */ - ModelProto.prototype.domain = ''; + ModelProto.prototype.domain = ""; /** * ModelProto modelVersion. @@ -1967,7 +2259,9 @@ $root.onnx = (function () { * @memberof onnx.ModelProto * @instance */ - ModelProto.prototype.modelVersion = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + ModelProto.prototype.modelVersion = $util.Long + ? $util.Long.fromBits(0, 0, false) + : 0; /** * ModelProto docString. @@ -1975,7 +2269,7 @@ $root.onnx = (function () { * @memberof onnx.ModelProto * @instance */ - ModelProto.prototype.docString = ''; + ModelProto.prototype.docString = ""; /** * ModelProto graph. @@ -2032,20 +2326,43 @@ $root.onnx = (function () { */ ModelProto.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.irVersion != null && Object.hasOwnProperty.call(message, 'irVersion')) + if ( + message.irVersion != null && + Object.hasOwnProperty.call(message, "irVersion") + ) writer.uint32(/* id 1, wireType 0 =*/ 8).int64(message.irVersion); - if (message.producerName != null && Object.hasOwnProperty.call(message, 'producerName')) + if ( + message.producerName != null && + Object.hasOwnProperty.call(message, "producerName") + ) writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.producerName); - if (message.producerVersion != null && Object.hasOwnProperty.call(message, 'producerVersion')) - writer.uint32(/* id 3, wireType 2 =*/ 26).string(message.producerVersion); - if (message.domain != null && Object.hasOwnProperty.call(message, 'domain')) + if ( + message.producerVersion != null && + Object.hasOwnProperty.call(message, "producerVersion") + ) + writer + .uint32(/* id 3, wireType 2 =*/ 26) + .string(message.producerVersion); + if ( + message.domain != null && + Object.hasOwnProperty.call(message, "domain") + ) writer.uint32(/* id 4, wireType 2 =*/ 34).string(message.domain); - if (message.modelVersion != null && Object.hasOwnProperty.call(message, 'modelVersion')) + if ( + message.modelVersion != null && + Object.hasOwnProperty.call(message, "modelVersion") + ) writer.uint32(/* id 5, wireType 0 =*/ 40).int64(message.modelVersion); - if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + if ( + message.docString != null && + Object.hasOwnProperty.call(message, "docString") + ) writer.uint32(/* id 6, wireType 2 =*/ 50).string(message.docString); - if (message.graph != null && Object.hasOwnProperty.call(message, 'graph')) - $root.onnx.GraphProto.encode(message.graph, writer.uint32(/* id 7, wireType 2 =*/ 58).fork()).ldelim(); + if (message.graph != null && Object.hasOwnProperty.call(message, "graph")) + $root.onnx.GraphProto.encode( + message.graph, + writer.uint32(/* id 7, wireType 2 =*/ 58).fork(), + ).ldelim(); if (message.opsetImport != null && message.opsetImport.length) for (var i = 0; i < message.opsetImport.length; ++i) $root.onnx.OperatorSetIdProto.encode( @@ -2109,8 +2426,11 @@ $root.onnx = (function () { break; } case 8: { - if (!(message.opsetImport && message.opsetImport.length)) message.opsetImport = []; - message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push( + $root.onnx.OperatorSetIdProto.decode(reader, reader.uint32()), + ); break; } case 2: { @@ -2134,22 +2454,34 @@ $root.onnx = (function () { break; } case 7: { - message.graph = $root.onnx.GraphProto.decode(reader, reader.uint32()); + message.graph = $root.onnx.GraphProto.decode( + reader, + reader.uint32(), + ); break; } case 14: { - if (!(message.metadataProps && message.metadataProps.length)) message.metadataProps = []; - message.metadataProps.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + if (!(message.metadataProps && message.metadataProps.length)) + message.metadataProps = []; + message.metadataProps.push( + $root.onnx.StringStringEntryProto.decode(reader, reader.uint32()), + ); break; } case 20: { - if (!(message.trainingInfo && message.trainingInfo.length)) message.trainingInfo = []; - message.trainingInfo.push($root.onnx.TrainingInfoProto.decode(reader, reader.uint32())); + if (!(message.trainingInfo && message.trainingInfo.length)) + message.trainingInfo = []; + message.trainingInfo.push( + $root.onnx.TrainingInfoProto.decode(reader, reader.uint32()), + ); break; } case 25: { - if (!(message.functions && message.functions.length)) message.functions = []; - message.functions.push($root.onnx.FunctionProto.decode(reader, reader.uint32())); + if (!(message.functions && message.functions.length)) + message.functions = []; + message.functions.push( + $root.onnx.FunctionProto.decode(reader, reader.uint32()), + ); break; } default: @@ -2184,27 +2516,49 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ ModelProto.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.irVersion != null && message.hasOwnProperty('irVersion')) + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.irVersion != null && message.hasOwnProperty("irVersion")) if ( !$util.isInteger(message.irVersion) && - !(message.irVersion && $util.isInteger(message.irVersion.low) && $util.isInteger(message.irVersion.high)) + !( + message.irVersion && + $util.isInteger(message.irVersion.low) && + $util.isInteger(message.irVersion.high) + ) ) - return 'irVersion: integer|Long expected'; - if (message.opsetImport != null && message.hasOwnProperty('opsetImport')) { - if (!Array.isArray(message.opsetImport)) return 'opsetImport: array expected'; + return "irVersion: integer|Long expected"; + if ( + message.opsetImport != null && + message.hasOwnProperty("opsetImport") + ) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; for (var i = 0; i < message.opsetImport.length; ++i) { - var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); - if (error) return 'opsetImport.' + error; + var error = $root.onnx.OperatorSetIdProto.verify( + message.opsetImport[i], + ); + if (error) return "opsetImport." + error; } } - if (message.producerName != null && message.hasOwnProperty('producerName')) - if (!$util.isString(message.producerName)) return 'producerName: string expected'; - if (message.producerVersion != null && message.hasOwnProperty('producerVersion')) - if (!$util.isString(message.producerVersion)) return 'producerVersion: string expected'; - if (message.domain != null && message.hasOwnProperty('domain')) - if (!$util.isString(message.domain)) return 'domain: string expected'; - if (message.modelVersion != null && message.hasOwnProperty('modelVersion')) + if ( + message.producerName != null && + message.hasOwnProperty("producerName") + ) + if (!$util.isString(message.producerName)) + return "producerName: string expected"; + if ( + message.producerVersion != null && + message.hasOwnProperty("producerVersion") + ) + if (!$util.isString(message.producerVersion)) + return "producerVersion: string expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) return "domain: string expected"; + if ( + message.modelVersion != null && + message.hasOwnProperty("modelVersion") + ) if ( !$util.isInteger(message.modelVersion) && !( @@ -2213,32 +2567,46 @@ $root.onnx = (function () { $util.isInteger(message.modelVersion.high) ) ) - return 'modelVersion: integer|Long expected'; - if (message.docString != null && message.hasOwnProperty('docString')) - if (!$util.isString(message.docString)) return 'docString: string expected'; - if (message.graph != null && message.hasOwnProperty('graph')) { + return "modelVersion: integer|Long expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.graph != null && message.hasOwnProperty("graph")) { var error = $root.onnx.GraphProto.verify(message.graph); - if (error) return 'graph.' + error; - } - if (message.metadataProps != null && message.hasOwnProperty('metadataProps')) { - if (!Array.isArray(message.metadataProps)) return 'metadataProps: array expected'; + if (error) return "graph." + error; + } + if ( + message.metadataProps != null && + message.hasOwnProperty("metadataProps") + ) { + if (!Array.isArray(message.metadataProps)) + return "metadataProps: array expected"; for (var i = 0; i < message.metadataProps.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.metadataProps[i]); - if (error) return 'metadataProps.' + error; + var error = $root.onnx.StringStringEntryProto.verify( + message.metadataProps[i], + ); + if (error) return "metadataProps." + error; } } - if (message.trainingInfo != null && message.hasOwnProperty('trainingInfo')) { - if (!Array.isArray(message.trainingInfo)) return 'trainingInfo: array expected'; + if ( + message.trainingInfo != null && + message.hasOwnProperty("trainingInfo") + ) { + if (!Array.isArray(message.trainingInfo)) + return "trainingInfo: array expected"; for (var i = 0; i < message.trainingInfo.length; ++i) { - var error = $root.onnx.TrainingInfoProto.verify(message.trainingInfo[i]); - if (error) return 'trainingInfo.' + error; + var error = $root.onnx.TrainingInfoProto.verify( + message.trainingInfo[i], + ); + if (error) return "trainingInfo." + error; } } - if (message.functions != null && message.hasOwnProperty('functions')) { - if (!Array.isArray(message.functions)) return 'functions: array expected'; + if (message.functions != null && message.hasOwnProperty("functions")) { + if (!Array.isArray(message.functions)) + return "functions: array expected"; for (var i = 0; i < message.functions.length; ++i) { var error = $root.onnx.FunctionProto.verify(message.functions[i]); - if (error) return 'functions.' + error; + if (error) return "functions." + error; } } return null; @@ -2256,61 +2624,92 @@ $root.onnx = (function () { if (object instanceof $root.onnx.ModelProto) return object; var message = new $root.onnx.ModelProto(); if (object.irVersion != null) - if ($util.Long) (message.irVersion = $util.Long.fromValue(object.irVersion)).unsigned = false; - else if (typeof object.irVersion === 'string') message.irVersion = parseInt(object.irVersion, 10); - else if (typeof object.irVersion === 'number') message.irVersion = object.irVersion; - else if (typeof object.irVersion === 'object') - message.irVersion = new $util.LongBits(object.irVersion.low >>> 0, object.irVersion.high >>> 0).toNumber(); + if ($util.Long) + (message.irVersion = $util.Long.fromValue( + object.irVersion, + )).unsigned = false; + else if (typeof object.irVersion === "string") + message.irVersion = parseInt(object.irVersion, 10); + else if (typeof object.irVersion === "number") + message.irVersion = object.irVersion; + else if (typeof object.irVersion === "object") + message.irVersion = new $util.LongBits( + object.irVersion.low >>> 0, + object.irVersion.high >>> 0, + ).toNumber(); if (object.opsetImport) { - if (!Array.isArray(object.opsetImport)) throw TypeError('.onnx.ModelProto.opsetImport: array expected'); + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.ModelProto.opsetImport: array expected"); message.opsetImport = []; for (var i = 0; i < object.opsetImport.length; ++i) { - if (typeof object.opsetImport[i] !== 'object') - throw TypeError('.onnx.ModelProto.opsetImport: object expected'); - message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.ModelProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject( + object.opsetImport[i], + ); } } - if (object.producerName != null) message.producerName = String(object.producerName); - if (object.producerVersion != null) message.producerVersion = String(object.producerVersion); + if (object.producerName != null) + message.producerName = String(object.producerName); + if (object.producerVersion != null) + message.producerVersion = String(object.producerVersion); if (object.domain != null) message.domain = String(object.domain); if (object.modelVersion != null) - if ($util.Long) (message.modelVersion = $util.Long.fromValue(object.modelVersion)).unsigned = false; - else if (typeof object.modelVersion === 'string') message.modelVersion = parseInt(object.modelVersion, 10); - else if (typeof object.modelVersion === 'number') message.modelVersion = object.modelVersion; - else if (typeof object.modelVersion === 'object') + if ($util.Long) + (message.modelVersion = $util.Long.fromValue( + object.modelVersion, + )).unsigned = false; + else if (typeof object.modelVersion === "string") + message.modelVersion = parseInt(object.modelVersion, 10); + else if (typeof object.modelVersion === "number") + message.modelVersion = object.modelVersion; + else if (typeof object.modelVersion === "object") message.modelVersion = new $util.LongBits( object.modelVersion.low >>> 0, object.modelVersion.high >>> 0, ).toNumber(); - if (object.docString != null) message.docString = String(object.docString); + if (object.docString != null) + message.docString = String(object.docString); if (object.graph != null) { - if (typeof object.graph !== 'object') throw TypeError('.onnx.ModelProto.graph: object expected'); + if (typeof object.graph !== "object") + throw TypeError(".onnx.ModelProto.graph: object expected"); message.graph = $root.onnx.GraphProto.fromObject(object.graph); } if (object.metadataProps) { - if (!Array.isArray(object.metadataProps)) throw TypeError('.onnx.ModelProto.metadataProps: array expected'); + if (!Array.isArray(object.metadataProps)) + throw TypeError(".onnx.ModelProto.metadataProps: array expected"); message.metadataProps = []; for (var i = 0; i < object.metadataProps.length; ++i) { - if (typeof object.metadataProps[i] !== 'object') - throw TypeError('.onnx.ModelProto.metadataProps: object expected'); - message.metadataProps[i] = $root.onnx.StringStringEntryProto.fromObject(object.metadataProps[i]); + if (typeof object.metadataProps[i] !== "object") + throw TypeError(".onnx.ModelProto.metadataProps: object expected"); + message.metadataProps[i] = + $root.onnx.StringStringEntryProto.fromObject( + object.metadataProps[i], + ); } } if (object.trainingInfo) { - if (!Array.isArray(object.trainingInfo)) throw TypeError('.onnx.ModelProto.trainingInfo: array expected'); + if (!Array.isArray(object.trainingInfo)) + throw TypeError(".onnx.ModelProto.trainingInfo: array expected"); message.trainingInfo = []; for (var i = 0; i < object.trainingInfo.length; ++i) { - if (typeof object.trainingInfo[i] !== 'object') - throw TypeError('.onnx.ModelProto.trainingInfo: object expected'); - message.trainingInfo[i] = $root.onnx.TrainingInfoProto.fromObject(object.trainingInfo[i]); + if (typeof object.trainingInfo[i] !== "object") + throw TypeError(".onnx.ModelProto.trainingInfo: object expected"); + message.trainingInfo[i] = $root.onnx.TrainingInfoProto.fromObject( + object.trainingInfo[i], + ); } } if (object.functions) { - if (!Array.isArray(object.functions)) throw TypeError('.onnx.ModelProto.functions: array expected'); + if (!Array.isArray(object.functions)) + throw TypeError(".onnx.ModelProto.functions: array expected"); message.functions = []; for (var i = 0; i < object.functions.length; ++i) { - if (typeof object.functions[i] !== 'object') throw TypeError('.onnx.ModelProto.functions: object expected'); - message.functions[i] = $root.onnx.FunctionProto.fromObject(object.functions[i]); + if (typeof object.functions[i] !== "object") + throw TypeError(".onnx.ModelProto.functions: object expected"); + message.functions[i] = $root.onnx.FunctionProto.fromObject( + object.functions[i], + ); } } return message; @@ -2338,66 +2737,109 @@ $root.onnx = (function () { if ($util.Long) { var long = new $util.Long(0, 0, false); object.irVersion = - options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else object.irVersion = options.longs === String ? '0' : 0; - object.producerName = ''; - object.producerVersion = ''; - object.domain = ''; + options.longs === String + ? long.toString() + : options.longs === Number + ? long.toNumber() + : long; + } else object.irVersion = options.longs === String ? "0" : 0; + object.producerName = ""; + object.producerVersion = ""; + object.domain = ""; if ($util.Long) { var long = new $util.Long(0, 0, false); object.modelVersion = - options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else object.modelVersion = options.longs === String ? '0' : 0; - object.docString = ''; + options.longs === String + ? long.toString() + : options.longs === Number + ? long.toNumber() + : long; + } else object.modelVersion = options.longs === String ? "0" : 0; + object.docString = ""; object.graph = null; } - if (message.irVersion != null && message.hasOwnProperty('irVersion')) - if (typeof message.irVersion === 'number') - object.irVersion = options.longs === String ? String(message.irVersion) : message.irVersion; + if (message.irVersion != null && message.hasOwnProperty("irVersion")) + if (typeof message.irVersion === "number") + object.irVersion = + options.longs === String + ? String(message.irVersion) + : message.irVersion; else object.irVersion = options.longs === String ? $util.Long.prototype.toString.call(message.irVersion) : options.longs === Number - ? new $util.LongBits(message.irVersion.low >>> 0, message.irVersion.high >>> 0).toNumber() + ? new $util.LongBits( + message.irVersion.low >>> 0, + message.irVersion.high >>> 0, + ).toNumber() : message.irVersion; - if (message.producerName != null && message.hasOwnProperty('producerName')) + if ( + message.producerName != null && + message.hasOwnProperty("producerName") + ) object.producerName = message.producerName; - if (message.producerVersion != null && message.hasOwnProperty('producerVersion')) + if ( + message.producerVersion != null && + message.hasOwnProperty("producerVersion") + ) object.producerVersion = message.producerVersion; - if (message.domain != null && message.hasOwnProperty('domain')) object.domain = message.domain; - if (message.modelVersion != null && message.hasOwnProperty('modelVersion')) - if (typeof message.modelVersion === 'number') - object.modelVersion = options.longs === String ? String(message.modelVersion) : message.modelVersion; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if ( + message.modelVersion != null && + message.hasOwnProperty("modelVersion") + ) + if (typeof message.modelVersion === "number") + object.modelVersion = + options.longs === String + ? String(message.modelVersion) + : message.modelVersion; else object.modelVersion = options.longs === String ? $util.Long.prototype.toString.call(message.modelVersion) : options.longs === Number - ? new $util.LongBits(message.modelVersion.low >>> 0, message.modelVersion.high >>> 0).toNumber() + ? new $util.LongBits( + message.modelVersion.low >>> 0, + message.modelVersion.high >>> 0, + ).toNumber() : message.modelVersion; - if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; - if (message.graph != null && message.hasOwnProperty('graph')) + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; + if (message.graph != null && message.hasOwnProperty("graph")) object.graph = $root.onnx.GraphProto.toObject(message.graph, options); if (message.opsetImport && message.opsetImport.length) { object.opsetImport = []; for (var j = 0; j < message.opsetImport.length; ++j) - object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject( + message.opsetImport[j], + options, + ); } if (message.metadataProps && message.metadataProps.length) { object.metadataProps = []; for (var j = 0; j < message.metadataProps.length; ++j) - object.metadataProps[j] = $root.onnx.StringStringEntryProto.toObject(message.metadataProps[j], options); + object.metadataProps[j] = $root.onnx.StringStringEntryProto.toObject( + message.metadataProps[j], + options, + ); } if (message.trainingInfo && message.trainingInfo.length) { object.trainingInfo = []; for (var j = 0; j < message.trainingInfo.length; ++j) - object.trainingInfo[j] = $root.onnx.TrainingInfoProto.toObject(message.trainingInfo[j], options); + object.trainingInfo[j] = $root.onnx.TrainingInfoProto.toObject( + message.trainingInfo[j], + options, + ); } if (message.functions && message.functions.length) { object.functions = []; for (var j = 0; j < message.functions.length; ++j) - object.functions[j] = $root.onnx.FunctionProto.toObject(message.functions[j], options); + object.functions[j] = $root.onnx.FunctionProto.toObject( + message.functions[j], + options, + ); } return object; }; @@ -2423,9 +2865,9 @@ $root.onnx = (function () { */ ModelProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.ModelProto'; + return typeUrlPrefix + "/onnx.ModelProto"; }; return ModelProto; @@ -2460,7 +2902,7 @@ $root.onnx = (function () { * @memberof onnx.StringStringEntryProto * @instance */ - StringStringEntryProto.prototype.key = ''; + StringStringEntryProto.prototype.key = ""; /** * StringStringEntryProto value. @@ -2468,7 +2910,7 @@ $root.onnx = (function () { * @memberof onnx.StringStringEntryProto * @instance */ - StringStringEntryProto.prototype.value = ''; + StringStringEntryProto.prototype.value = ""; /** * Creates a new StringStringEntryProto instance using the specified properties. @@ -2493,9 +2935,9 @@ $root.onnx = (function () { */ StringStringEntryProto.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.key != null && Object.hasOwnProperty.call(message, 'key')) + if (message.key != null && Object.hasOwnProperty.call(message, "key")) writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.key); - if (message.value != null && Object.hasOwnProperty.call(message, 'value')) + if (message.value != null && Object.hasOwnProperty.call(message, "value")) writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.value); return writer; }; @@ -2509,7 +2951,10 @@ $root.onnx = (function () { * @param {$protobuf.Writer} [writer] Writer to encode to * @returns {$protobuf.Writer} Writer */ - StringStringEntryProto.encodeDelimited = function encodeDelimited(message, writer) { + StringStringEntryProto.encodeDelimited = function encodeDelimited( + message, + writer, + ) { return this.encode(message, writer).ldelim(); }; @@ -2571,11 +3016,12 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ StringStringEntryProto.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.key != null && message.hasOwnProperty('key')) - if (!$util.isString(message.key)) return 'key: string expected'; - if (message.value != null && message.hasOwnProperty('value')) - if (!$util.isString(message.value)) return 'value: string expected'; + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.key != null && message.hasOwnProperty("key")) + if (!$util.isString(message.key)) return "key: string expected"; + if (message.value != null && message.hasOwnProperty("value")) + if (!$util.isString(message.value)) return "value: string expected"; return null; }; @@ -2608,11 +3054,13 @@ $root.onnx = (function () { if (!options) options = {}; var object = {}; if (options.defaults) { - object.key = ''; - object.value = ''; + object.key = ""; + object.value = ""; } - if (message.key != null && message.hasOwnProperty('key')) object.key = message.key; - if (message.value != null && message.hasOwnProperty('value')) object.value = message.value; + if (message.key != null && message.hasOwnProperty("key")) + object.key = message.key; + if (message.value != null && message.hasOwnProperty("value")) + object.value = message.value; return object; }; @@ -2637,9 +3085,9 @@ $root.onnx = (function () { */ StringStringEntryProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.StringStringEntryProto'; + return typeUrlPrefix + "/onnx.StringStringEntryProto"; }; return StringStringEntryProto; @@ -2675,7 +3123,7 @@ $root.onnx = (function () { * @memberof onnx.TensorAnnotation * @instance */ - TensorAnnotation.prototype.tensorName = ''; + TensorAnnotation.prototype.tensorName = ""; /** * TensorAnnotation quantParameterTensorNames. @@ -2708,9 +3156,15 @@ $root.onnx = (function () { */ TensorAnnotation.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.tensorName != null && Object.hasOwnProperty.call(message, 'tensorName')) + if ( + message.tensorName != null && + Object.hasOwnProperty.call(message, "tensorName") + ) writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.tensorName); - if (message.quantParameterTensorNames != null && message.quantParameterTensorNames.length) + if ( + message.quantParameterTensorNames != null && + message.quantParameterTensorNames.length + ) for (var i = 0; i < message.quantParameterTensorNames.length; ++i) $root.onnx.StringStringEntryProto.encode( message.quantParameterTensorNames[i], @@ -2728,7 +3182,10 @@ $root.onnx = (function () { * @param {$protobuf.Writer} [writer] Writer to encode to * @returns {$protobuf.Writer} Writer */ - TensorAnnotation.encodeDelimited = function encodeDelimited(message, writer) { + TensorAnnotation.encodeDelimited = function encodeDelimited( + message, + writer, + ) { return this.encode(message, writer).ldelim(); }; @@ -2755,9 +3212,16 @@ $root.onnx = (function () { break; } case 2: { - if (!(message.quantParameterTensorNames && message.quantParameterTensorNames.length)) + if ( + !( + message.quantParameterTensorNames && + message.quantParameterTensorNames.length + ) + ) message.quantParameterTensorNames = []; - message.quantParameterTensorNames.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + message.quantParameterTensorNames.push( + $root.onnx.StringStringEntryProto.decode(reader, reader.uint32()), + ); break; } default: @@ -2792,14 +3256,22 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ TensorAnnotation.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.tensorName != null && message.hasOwnProperty('tensorName')) - if (!$util.isString(message.tensorName)) return 'tensorName: string expected'; - if (message.quantParameterTensorNames != null && message.hasOwnProperty('quantParameterTensorNames')) { - if (!Array.isArray(message.quantParameterTensorNames)) return 'quantParameterTensorNames: array expected'; + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + if (!$util.isString(message.tensorName)) + return "tensorName: string expected"; + if ( + message.quantParameterTensorNames != null && + message.hasOwnProperty("quantParameterTensorNames") + ) { + if (!Array.isArray(message.quantParameterTensorNames)) + return "quantParameterTensorNames: array expected"; for (var i = 0; i < message.quantParameterTensorNames.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.quantParameterTensorNames[i]); - if (error) return 'quantParameterTensorNames.' + error; + var error = $root.onnx.StringStringEntryProto.verify( + message.quantParameterTensorNames[i], + ); + if (error) return "quantParameterTensorNames." + error; } } return null; @@ -2816,17 +3288,23 @@ $root.onnx = (function () { TensorAnnotation.fromObject = function fromObject(object) { if (object instanceof $root.onnx.TensorAnnotation) return object; var message = new $root.onnx.TensorAnnotation(); - if (object.tensorName != null) message.tensorName = String(object.tensorName); + if (object.tensorName != null) + message.tensorName = String(object.tensorName); if (object.quantParameterTensorNames) { if (!Array.isArray(object.quantParameterTensorNames)) - throw TypeError('.onnx.TensorAnnotation.quantParameterTensorNames: array expected'); + throw TypeError( + ".onnx.TensorAnnotation.quantParameterTensorNames: array expected", + ); message.quantParameterTensorNames = []; for (var i = 0; i < object.quantParameterTensorNames.length; ++i) { - if (typeof object.quantParameterTensorNames[i] !== 'object') - throw TypeError('.onnx.TensorAnnotation.quantParameterTensorNames: object expected'); - message.quantParameterTensorNames[i] = $root.onnx.StringStringEntryProto.fromObject( - object.quantParameterTensorNames[i], - ); + if (typeof object.quantParameterTensorNames[i] !== "object") + throw TypeError( + ".onnx.TensorAnnotation.quantParameterTensorNames: object expected", + ); + message.quantParameterTensorNames[i] = + $root.onnx.StringStringEntryProto.fromObject( + object.quantParameterTensorNames[i], + ); } } return message; @@ -2844,16 +3322,22 @@ $root.onnx = (function () { TensorAnnotation.toObject = function toObject(message, options) { if (!options) options = {}; var object = {}; - if (options.arrays || options.defaults) object.quantParameterTensorNames = []; - if (options.defaults) object.tensorName = ''; - if (message.tensorName != null && message.hasOwnProperty('tensorName')) object.tensorName = message.tensorName; - if (message.quantParameterTensorNames && message.quantParameterTensorNames.length) { + if (options.arrays || options.defaults) + object.quantParameterTensorNames = []; + if (options.defaults) object.tensorName = ""; + if (message.tensorName != null && message.hasOwnProperty("tensorName")) + object.tensorName = message.tensorName; + if ( + message.quantParameterTensorNames && + message.quantParameterTensorNames.length + ) { object.quantParameterTensorNames = []; for (var j = 0; j < message.quantParameterTensorNames.length; ++j) - object.quantParameterTensorNames[j] = $root.onnx.StringStringEntryProto.toObject( - message.quantParameterTensorNames[j], - options, - ); + object.quantParameterTensorNames[j] = + $root.onnx.StringStringEntryProto.toObject( + message.quantParameterTensorNames[j], + options, + ); } return object; }; @@ -2879,9 +3363,9 @@ $root.onnx = (function () { */ TensorAnnotation.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.TensorAnnotation'; + return typeUrlPrefix + "/onnx.TensorAnnotation"; }; return TensorAnnotation; @@ -2938,7 +3422,7 @@ $root.onnx = (function () { * @memberof onnx.GraphProto * @instance */ - GraphProto.prototype.name = ''; + GraphProto.prototype.name = ""; /** * GraphProto initializer. @@ -2962,7 +3446,7 @@ $root.onnx = (function () { * @memberof onnx.GraphProto * @instance */ - GraphProto.prototype.docString = ''; + GraphProto.prototype.docString = ""; /** * GraphProto input. @@ -3021,8 +3505,11 @@ $root.onnx = (function () { if (!writer) writer = $Writer.create(); if (message.node != null && message.node.length) for (var i = 0; i < message.node.length; ++i) - $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); - if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + $root.onnx.NodeProto.encode( + message.node[i], + writer.uint32(/* id 1, wireType 2 =*/ 10).fork(), + ).ldelim(); + if (message.name != null && Object.hasOwnProperty.call(message, "name")) writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.name); if (message.initializer != null && message.initializer.length) for (var i = 0; i < message.initializer.length; ++i) @@ -3030,7 +3517,10 @@ $root.onnx = (function () { message.initializer[i], writer.uint32(/* id 5, wireType 2 =*/ 42).fork(), ).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + if ( + message.docString != null && + Object.hasOwnProperty.call(message, "docString") + ) writer.uint32(/* id 10, wireType 2 =*/ 82).string(message.docString); if (message.input != null && message.input.length) for (var i = 0; i < message.input.length; ++i) @@ -3050,7 +3540,10 @@ $root.onnx = (function () { message.valueInfo[i], writer.uint32(/* id 13, wireType 2 =*/ 106).fork(), ).ldelim(); - if (message.quantizationAnnotation != null && message.quantizationAnnotation.length) + if ( + message.quantizationAnnotation != null && + message.quantizationAnnotation.length + ) for (var i = 0; i < message.quantizationAnnotation.length; ++i) $root.onnx.TensorAnnotation.encode( message.quantizationAnnotation[i], @@ -3098,7 +3591,9 @@ $root.onnx = (function () { switch (tag >>> 3) { case 1: { if (!(message.node && message.node.length)) message.node = []; - message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + message.node.push( + $root.onnx.NodeProto.decode(reader, reader.uint32()), + ); break; } case 2: { @@ -3106,13 +3601,21 @@ $root.onnx = (function () { break; } case 5: { - if (!(message.initializer && message.initializer.length)) message.initializer = []; - message.initializer.push($root.onnx.TensorProto.decode(reader, reader.uint32())); + if (!(message.initializer && message.initializer.length)) + message.initializer = []; + message.initializer.push( + $root.onnx.TensorProto.decode(reader, reader.uint32()), + ); break; } case 15: { - if (!(message.sparseInitializer && message.sparseInitializer.length)) message.sparseInitializer = []; - message.sparseInitializer.push($root.onnx.SparseTensorProto.decode(reader, reader.uint32())); + if ( + !(message.sparseInitializer && message.sparseInitializer.length) + ) + message.sparseInitializer = []; + message.sparseInitializer.push( + $root.onnx.SparseTensorProto.decode(reader, reader.uint32()), + ); break; } case 10: { @@ -3121,23 +3624,37 @@ $root.onnx = (function () { } case 11: { if (!(message.input && message.input.length)) message.input = []; - message.input.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + message.input.push( + $root.onnx.ValueInfoProto.decode(reader, reader.uint32()), + ); break; } case 12: { if (!(message.output && message.output.length)) message.output = []; - message.output.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + message.output.push( + $root.onnx.ValueInfoProto.decode(reader, reader.uint32()), + ); break; } case 13: { - if (!(message.valueInfo && message.valueInfo.length)) message.valueInfo = []; - message.valueInfo.push($root.onnx.ValueInfoProto.decode(reader, reader.uint32())); + if (!(message.valueInfo && message.valueInfo.length)) + message.valueInfo = []; + message.valueInfo.push( + $root.onnx.ValueInfoProto.decode(reader, reader.uint32()), + ); break; } case 14: { - if (!(message.quantizationAnnotation && message.quantizationAnnotation.length)) + if ( + !( + message.quantizationAnnotation && + message.quantizationAnnotation.length + ) + ) message.quantizationAnnotation = []; - message.quantizationAnnotation.push($root.onnx.TensorAnnotation.decode(reader, reader.uint32())); + message.quantizationAnnotation.push( + $root.onnx.TensorAnnotation.decode(reader, reader.uint32()), + ); break; } default: @@ -3172,58 +3689,77 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ GraphProto.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.node != null && message.hasOwnProperty('node')) { - if (!Array.isArray(message.node)) return 'node: array expected'; + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) return "node: array expected"; for (var i = 0; i < message.node.length; ++i) { var error = $root.onnx.NodeProto.verify(message.node[i]); - if (error) return 'node.' + error; + if (error) return "node." + error; } } - if (message.name != null && message.hasOwnProperty('name')) - if (!$util.isString(message.name)) return 'name: string expected'; - if (message.initializer != null && message.hasOwnProperty('initializer')) { - if (!Array.isArray(message.initializer)) return 'initializer: array expected'; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) return "name: string expected"; + if ( + message.initializer != null && + message.hasOwnProperty("initializer") + ) { + if (!Array.isArray(message.initializer)) + return "initializer: array expected"; for (var i = 0; i < message.initializer.length; ++i) { var error = $root.onnx.TensorProto.verify(message.initializer[i]); - if (error) return 'initializer.' + error; + if (error) return "initializer." + error; } } - if (message.sparseInitializer != null && message.hasOwnProperty('sparseInitializer')) { - if (!Array.isArray(message.sparseInitializer)) return 'sparseInitializer: array expected'; + if ( + message.sparseInitializer != null && + message.hasOwnProperty("sparseInitializer") + ) { + if (!Array.isArray(message.sparseInitializer)) + return "sparseInitializer: array expected"; for (var i = 0; i < message.sparseInitializer.length; ++i) { - var error = $root.onnx.SparseTensorProto.verify(message.sparseInitializer[i]); - if (error) return 'sparseInitializer.' + error; + var error = $root.onnx.SparseTensorProto.verify( + message.sparseInitializer[i], + ); + if (error) return "sparseInitializer." + error; } } - if (message.docString != null && message.hasOwnProperty('docString')) - if (!$util.isString(message.docString)) return 'docString: string expected'; - if (message.input != null && message.hasOwnProperty('input')) { - if (!Array.isArray(message.input)) return 'input: array expected'; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) return "input: array expected"; for (var i = 0; i < message.input.length; ++i) { var error = $root.onnx.ValueInfoProto.verify(message.input[i]); - if (error) return 'input.' + error; + if (error) return "input." + error; } } - if (message.output != null && message.hasOwnProperty('output')) { - if (!Array.isArray(message.output)) return 'output: array expected'; + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) return "output: array expected"; for (var i = 0; i < message.output.length; ++i) { var error = $root.onnx.ValueInfoProto.verify(message.output[i]); - if (error) return 'output.' + error; + if (error) return "output." + error; } } - if (message.valueInfo != null && message.hasOwnProperty('valueInfo')) { - if (!Array.isArray(message.valueInfo)) return 'valueInfo: array expected'; + if (message.valueInfo != null && message.hasOwnProperty("valueInfo")) { + if (!Array.isArray(message.valueInfo)) + return "valueInfo: array expected"; for (var i = 0; i < message.valueInfo.length; ++i) { var error = $root.onnx.ValueInfoProto.verify(message.valueInfo[i]); - if (error) return 'valueInfo.' + error; + if (error) return "valueInfo." + error; } } - if (message.quantizationAnnotation != null && message.hasOwnProperty('quantizationAnnotation')) { - if (!Array.isArray(message.quantizationAnnotation)) return 'quantizationAnnotation: array expected'; + if ( + message.quantizationAnnotation != null && + message.hasOwnProperty("quantizationAnnotation") + ) { + if (!Array.isArray(message.quantizationAnnotation)) + return "quantizationAnnotation: array expected"; for (var i = 0; i < message.quantizationAnnotation.length; ++i) { - var error = $root.onnx.TensorAnnotation.verify(message.quantizationAnnotation[i]); - if (error) return 'quantizationAnnotation.' + error; + var error = $root.onnx.TensorAnnotation.verify( + message.quantizationAnnotation[i], + ); + if (error) return "quantizationAnnotation." + error; } } return null; @@ -3241,66 +3777,96 @@ $root.onnx = (function () { if (object instanceof $root.onnx.GraphProto) return object; var message = new $root.onnx.GraphProto(); if (object.node) { - if (!Array.isArray(object.node)) throw TypeError('.onnx.GraphProto.node: array expected'); + if (!Array.isArray(object.node)) + throw TypeError(".onnx.GraphProto.node: array expected"); message.node = []; for (var i = 0; i < object.node.length; ++i) { - if (typeof object.node[i] !== 'object') throw TypeError('.onnx.GraphProto.node: object expected'); + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.GraphProto.node: object expected"); message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); } } if (object.name != null) message.name = String(object.name); if (object.initializer) { - if (!Array.isArray(object.initializer)) throw TypeError('.onnx.GraphProto.initializer: array expected'); + if (!Array.isArray(object.initializer)) + throw TypeError(".onnx.GraphProto.initializer: array expected"); message.initializer = []; for (var i = 0; i < object.initializer.length; ++i) { - if (typeof object.initializer[i] !== 'object') - throw TypeError('.onnx.GraphProto.initializer: object expected'); - message.initializer[i] = $root.onnx.TensorProto.fromObject(object.initializer[i]); + if (typeof object.initializer[i] !== "object") + throw TypeError(".onnx.GraphProto.initializer: object expected"); + message.initializer[i] = $root.onnx.TensorProto.fromObject( + object.initializer[i], + ); } } if (object.sparseInitializer) { if (!Array.isArray(object.sparseInitializer)) - throw TypeError('.onnx.GraphProto.sparseInitializer: array expected'); + throw TypeError(".onnx.GraphProto.sparseInitializer: array expected"); message.sparseInitializer = []; for (var i = 0; i < object.sparseInitializer.length; ++i) { - if (typeof object.sparseInitializer[i] !== 'object') - throw TypeError('.onnx.GraphProto.sparseInitializer: object expected'); - message.sparseInitializer[i] = $root.onnx.SparseTensorProto.fromObject(object.sparseInitializer[i]); + if (typeof object.sparseInitializer[i] !== "object") + throw TypeError( + ".onnx.GraphProto.sparseInitializer: object expected", + ); + message.sparseInitializer[i] = + $root.onnx.SparseTensorProto.fromObject( + object.sparseInitializer[i], + ); } } - if (object.docString != null) message.docString = String(object.docString); + if (object.docString != null) + message.docString = String(object.docString); if (object.input) { - if (!Array.isArray(object.input)) throw TypeError('.onnx.GraphProto.input: array expected'); + if (!Array.isArray(object.input)) + throw TypeError(".onnx.GraphProto.input: array expected"); message.input = []; for (var i = 0; i < object.input.length; ++i) { - if (typeof object.input[i] !== 'object') throw TypeError('.onnx.GraphProto.input: object expected'); - message.input[i] = $root.onnx.ValueInfoProto.fromObject(object.input[i]); + if (typeof object.input[i] !== "object") + throw TypeError(".onnx.GraphProto.input: object expected"); + message.input[i] = $root.onnx.ValueInfoProto.fromObject( + object.input[i], + ); } } if (object.output) { - if (!Array.isArray(object.output)) throw TypeError('.onnx.GraphProto.output: array expected'); + if (!Array.isArray(object.output)) + throw TypeError(".onnx.GraphProto.output: array expected"); message.output = []; for (var i = 0; i < object.output.length; ++i) { - if (typeof object.output[i] !== 'object') throw TypeError('.onnx.GraphProto.output: object expected'); - message.output[i] = $root.onnx.ValueInfoProto.fromObject(object.output[i]); + if (typeof object.output[i] !== "object") + throw TypeError(".onnx.GraphProto.output: object expected"); + message.output[i] = $root.onnx.ValueInfoProto.fromObject( + object.output[i], + ); } } if (object.valueInfo) { - if (!Array.isArray(object.valueInfo)) throw TypeError('.onnx.GraphProto.valueInfo: array expected'); + if (!Array.isArray(object.valueInfo)) + throw TypeError(".onnx.GraphProto.valueInfo: array expected"); message.valueInfo = []; for (var i = 0; i < object.valueInfo.length; ++i) { - if (typeof object.valueInfo[i] !== 'object') throw TypeError('.onnx.GraphProto.valueInfo: object expected'); - message.valueInfo[i] = $root.onnx.ValueInfoProto.fromObject(object.valueInfo[i]); + if (typeof object.valueInfo[i] !== "object") + throw TypeError(".onnx.GraphProto.valueInfo: object expected"); + message.valueInfo[i] = $root.onnx.ValueInfoProto.fromObject( + object.valueInfo[i], + ); } } if (object.quantizationAnnotation) { if (!Array.isArray(object.quantizationAnnotation)) - throw TypeError('.onnx.GraphProto.quantizationAnnotation: array expected'); + throw TypeError( + ".onnx.GraphProto.quantizationAnnotation: array expected", + ); message.quantizationAnnotation = []; for (var i = 0; i < object.quantizationAnnotation.length; ++i) { - if (typeof object.quantizationAnnotation[i] !== 'object') - throw TypeError('.onnx.GraphProto.quantizationAnnotation: object expected'); - message.quantizationAnnotation[i] = $root.onnx.TensorAnnotation.fromObject(object.quantizationAnnotation[i]); + if (typeof object.quantizationAnnotation[i] !== "object") + throw TypeError( + ".onnx.GraphProto.quantizationAnnotation: object expected", + ); + message.quantizationAnnotation[i] = + $root.onnx.TensorAnnotation.fromObject( + object.quantizationAnnotation[i], + ); } } return message; @@ -3328,48 +3894,72 @@ $root.onnx = (function () { object.sparseInitializer = []; } if (options.defaults) { - object.name = ''; - object.docString = ''; + object.name = ""; + object.docString = ""; } if (message.node && message.node.length) { object.node = []; for (var j = 0; j < message.node.length; ++j) - object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + object.node[j] = $root.onnx.NodeProto.toObject( + message.node[j], + options, + ); } - if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; if (message.initializer && message.initializer.length) { object.initializer = []; for (var j = 0; j < message.initializer.length; ++j) - object.initializer[j] = $root.onnx.TensorProto.toObject(message.initializer[j], options); + object.initializer[j] = $root.onnx.TensorProto.toObject( + message.initializer[j], + options, + ); } - if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; if (message.input && message.input.length) { object.input = []; for (var j = 0; j < message.input.length; ++j) - object.input[j] = $root.onnx.ValueInfoProto.toObject(message.input[j], options); + object.input[j] = $root.onnx.ValueInfoProto.toObject( + message.input[j], + options, + ); } if (message.output && message.output.length) { object.output = []; for (var j = 0; j < message.output.length; ++j) - object.output[j] = $root.onnx.ValueInfoProto.toObject(message.output[j], options); + object.output[j] = $root.onnx.ValueInfoProto.toObject( + message.output[j], + options, + ); } if (message.valueInfo && message.valueInfo.length) { object.valueInfo = []; for (var j = 0; j < message.valueInfo.length; ++j) - object.valueInfo[j] = $root.onnx.ValueInfoProto.toObject(message.valueInfo[j], options); + object.valueInfo[j] = $root.onnx.ValueInfoProto.toObject( + message.valueInfo[j], + options, + ); } - if (message.quantizationAnnotation && message.quantizationAnnotation.length) { + if ( + message.quantizationAnnotation && + message.quantizationAnnotation.length + ) { object.quantizationAnnotation = []; for (var j = 0; j < message.quantizationAnnotation.length; ++j) - object.quantizationAnnotation[j] = $root.onnx.TensorAnnotation.toObject( - message.quantizationAnnotation[j], - options, - ); + object.quantizationAnnotation[j] = + $root.onnx.TensorAnnotation.toObject( + message.quantizationAnnotation[j], + options, + ); } if (message.sparseInitializer && message.sparseInitializer.length) { object.sparseInitializer = []; for (var j = 0; j < message.sparseInitializer.length; ++j) - object.sparseInitializer[j] = $root.onnx.SparseTensorProto.toObject(message.sparseInitializer[j], options); + object.sparseInitializer[j] = $root.onnx.SparseTensorProto.toObject( + message.sparseInitializer[j], + options, + ); } return object; }; @@ -3395,9 +3985,9 @@ $root.onnx = (function () { */ GraphProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.GraphProto'; + return typeUrlPrefix + "/onnx.GraphProto"; }; return GraphProto; @@ -3508,7 +4098,7 @@ $root.onnx = (function () { * @memberof onnx.TensorProto * @instance */ - TensorProto.prototype.name = ''; + TensorProto.prototype.name = ""; /** * TensorProto docString. @@ -3516,7 +4106,7 @@ $root.onnx = (function () { * @memberof onnx.TensorProto * @instance */ - TensorProto.prototype.docString = ''; + TensorProto.prototype.docString = ""; /** * TensorProto rawData. @@ -3583,49 +4173,69 @@ $root.onnx = (function () { if (!writer) writer = $Writer.create(); if (message.dims != null && message.dims.length) { writer.uint32(/* id 1, wireType 2 =*/ 10).fork(); - for (var i = 0; i < message.dims.length; ++i) writer.int64(message.dims[i]); + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); writer.ldelim(); } - if (message.dataType != null && Object.hasOwnProperty.call(message, 'dataType')) + if ( + message.dataType != null && + Object.hasOwnProperty.call(message, "dataType") + ) writer.uint32(/* id 2, wireType 0 =*/ 16).int32(message.dataType); - if (message.segment != null && Object.hasOwnProperty.call(message, 'segment')) + if ( + message.segment != null && + Object.hasOwnProperty.call(message, "segment") + ) $root.onnx.TensorProto.Segment.encode( message.segment, writer.uint32(/* id 3, wireType 2 =*/ 26).fork(), ).ldelim(); if (message.floatData != null && message.floatData.length) { writer.uint32(/* id 4, wireType 2 =*/ 34).fork(); - for (var i = 0; i < message.floatData.length; ++i) writer.float(message.floatData[i]); + for (var i = 0; i < message.floatData.length; ++i) + writer.float(message.floatData[i]); writer.ldelim(); } if (message.int32Data != null && message.int32Data.length) { writer.uint32(/* id 5, wireType 2 =*/ 42).fork(); - for (var i = 0; i < message.int32Data.length; ++i) writer.int32(message.int32Data[i]); + for (var i = 0; i < message.int32Data.length; ++i) + writer.int32(message.int32Data[i]); writer.ldelim(); } if (message.stringData != null && message.stringData.length) for (var i = 0; i < message.stringData.length; ++i) - writer.uint32(/* id 6, wireType 2 =*/ 50).bytes(message.stringData[i]); + writer + .uint32(/* id 6, wireType 2 =*/ 50) + .bytes(message.stringData[i]); if (message.int64Data != null && message.int64Data.length) { writer.uint32(/* id 7, wireType 2 =*/ 58).fork(); - for (var i = 0; i < message.int64Data.length; ++i) writer.int64(message.int64Data[i]); + for (var i = 0; i < message.int64Data.length; ++i) + writer.int64(message.int64Data[i]); writer.ldelim(); } - if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + if (message.name != null && Object.hasOwnProperty.call(message, "name")) writer.uint32(/* id 8, wireType 2 =*/ 66).string(message.name); - if (message.rawData != null && Object.hasOwnProperty.call(message, 'rawData')) + if ( + message.rawData != null && + Object.hasOwnProperty.call(message, "rawData") + ) writer.uint32(/* id 9, wireType 2 =*/ 74).bytes(message.rawData); if (message.doubleData != null && message.doubleData.length) { writer.uint32(/* id 10, wireType 2 =*/ 82).fork(); - for (var i = 0; i < message.doubleData.length; ++i) writer.double(message.doubleData[i]); + for (var i = 0; i < message.doubleData.length; ++i) + writer.double(message.doubleData[i]); writer.ldelim(); } if (message.uint64Data != null && message.uint64Data.length) { writer.uint32(/* id 11, wireType 2 =*/ 90).fork(); - for (var i = 0; i < message.uint64Data.length; ++i) writer.uint64(message.uint64Data[i]); + for (var i = 0; i < message.uint64Data.length; ++i) + writer.uint64(message.uint64Data[i]); writer.ldelim(); } - if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + if ( + message.docString != null && + Object.hasOwnProperty.call(message, "docString") + ) writer.uint32(/* id 12, wireType 2 =*/ 98).string(message.docString); if (message.externalData != null && message.externalData.length) for (var i = 0; i < message.externalData.length; ++i) @@ -3633,7 +4243,10 @@ $root.onnx = (function () { message.externalData[i], writer.uint32(/* id 13, wireType 2 =*/ 106).fork(), ).ldelim(); - if (message.dataLocation != null && Object.hasOwnProperty.call(message, 'dataLocation')) + if ( + message.dataLocation != null && + Object.hasOwnProperty.call(message, "dataLocation") + ) writer.uint32(/* id 14, wireType 0 =*/ 112).int32(message.dataLocation); return writer; }; @@ -3682,11 +4295,15 @@ $root.onnx = (function () { break; } case 3: { - message.segment = $root.onnx.TensorProto.Segment.decode(reader, reader.uint32()); + message.segment = $root.onnx.TensorProto.Segment.decode( + reader, + reader.uint32(), + ); break; } case 4: { - if (!(message.floatData && message.floatData.length)) message.floatData = []; + if (!(message.floatData && message.floatData.length)) + message.floatData = []; if ((tag & 7) === 2) { var end2 = reader.uint32() + reader.pos; while (reader.pos < end2) message.floatData.push(reader.float()); @@ -3694,7 +4311,8 @@ $root.onnx = (function () { break; } case 5: { - if (!(message.int32Data && message.int32Data.length)) message.int32Data = []; + if (!(message.int32Data && message.int32Data.length)) + message.int32Data = []; if ((tag & 7) === 2) { var end2 = reader.uint32() + reader.pos; while (reader.pos < end2) message.int32Data.push(reader.int32()); @@ -3702,12 +4320,14 @@ $root.onnx = (function () { break; } case 6: { - if (!(message.stringData && message.stringData.length)) message.stringData = []; + if (!(message.stringData && message.stringData.length)) + message.stringData = []; message.stringData.push(reader.bytes()); break; } case 7: { - if (!(message.int64Data && message.int64Data.length)) message.int64Data = []; + if (!(message.int64Data && message.int64Data.length)) + message.int64Data = []; if ((tag & 7) === 2) { var end2 = reader.uint32() + reader.pos; while (reader.pos < end2) message.int64Data.push(reader.int64()); @@ -3727,8 +4347,11 @@ $root.onnx = (function () { break; } case 13: { - if (!(message.externalData && message.externalData.length)) message.externalData = []; - message.externalData.push($root.onnx.StringStringEntryProto.decode(reader, reader.uint32())); + if (!(message.externalData && message.externalData.length)) + message.externalData = []; + message.externalData.push( + $root.onnx.StringStringEntryProto.decode(reader, reader.uint32()), + ); break; } case 14: { @@ -3736,18 +4359,22 @@ $root.onnx = (function () { break; } case 10: { - if (!(message.doubleData && message.doubleData.length)) message.doubleData = []; + if (!(message.doubleData && message.doubleData.length)) + message.doubleData = []; if ((tag & 7) === 2) { var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) message.doubleData.push(reader.double()); + while (reader.pos < end2) + message.doubleData.push(reader.double()); } else message.doubleData.push(reader.double()); break; } case 11: { - if (!(message.uint64Data && message.uint64Data.length)) message.uint64Data = []; + if (!(message.uint64Data && message.uint64Data.length)) + message.uint64Data = []; if ((tag & 7) === 2) { var end2 = reader.uint32() + reader.pos; - while (reader.pos < end2) message.uint64Data.push(reader.uint64()); + while (reader.pos < end2) + message.uint64Data.push(reader.uint64()); } else message.uint64Data.push(reader.uint64()); break; } @@ -3783,45 +4410,58 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ TensorProto.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.dims != null && message.hasOwnProperty('dims')) { - if (!Array.isArray(message.dims)) return 'dims: array expected'; + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) return "dims: array expected"; for (var i = 0; i < message.dims.length; ++i) if ( !$util.isInteger(message.dims[i]) && - !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high)) + !( + message.dims[i] && + $util.isInteger(message.dims[i].low) && + $util.isInteger(message.dims[i].high) + ) ) - return 'dims: integer|Long[] expected'; + return "dims: integer|Long[] expected"; } - if (message.dataType != null && message.hasOwnProperty('dataType')) - if (!$util.isInteger(message.dataType)) return 'dataType: integer expected'; - if (message.segment != null && message.hasOwnProperty('segment')) { + if (message.dataType != null && message.hasOwnProperty("dataType")) + if (!$util.isInteger(message.dataType)) + return "dataType: integer expected"; + if (message.segment != null && message.hasOwnProperty("segment")) { var error = $root.onnx.TensorProto.Segment.verify(message.segment); - if (error) return 'segment.' + error; + if (error) return "segment." + error; } - if (message.floatData != null && message.hasOwnProperty('floatData')) { - if (!Array.isArray(message.floatData)) return 'floatData: array expected'; + if (message.floatData != null && message.hasOwnProperty("floatData")) { + if (!Array.isArray(message.floatData)) + return "floatData: array expected"; for (var i = 0; i < message.floatData.length; ++i) - if (typeof message.floatData[i] !== 'number') return 'floatData: number[] expected'; + if (typeof message.floatData[i] !== "number") + return "floatData: number[] expected"; } - if (message.int32Data != null && message.hasOwnProperty('int32Data')) { - if (!Array.isArray(message.int32Data)) return 'int32Data: array expected'; + if (message.int32Data != null && message.hasOwnProperty("int32Data")) { + if (!Array.isArray(message.int32Data)) + return "int32Data: array expected"; for (var i = 0; i < message.int32Data.length; ++i) - if (!$util.isInteger(message.int32Data[i])) return 'int32Data: integer[] expected'; + if (!$util.isInteger(message.int32Data[i])) + return "int32Data: integer[] expected"; } - if (message.stringData != null && message.hasOwnProperty('stringData')) { - if (!Array.isArray(message.stringData)) return 'stringData: array expected'; + if (message.stringData != null && message.hasOwnProperty("stringData")) { + if (!Array.isArray(message.stringData)) + return "stringData: array expected"; for (var i = 0; i < message.stringData.length; ++i) if ( !( - (message.stringData[i] && typeof message.stringData[i].length === 'number') || + (message.stringData[i] && + typeof message.stringData[i].length === "number") || $util.isString(message.stringData[i]) ) ) - return 'stringData: buffer[] expected'; + return "stringData: buffer[] expected"; } - if (message.int64Data != null && message.hasOwnProperty('int64Data')) { - if (!Array.isArray(message.int64Data)) return 'int64Data: array expected'; + if (message.int64Data != null && message.hasOwnProperty("int64Data")) { + if (!Array.isArray(message.int64Data)) + return "int64Data: array expected"; for (var i = 0; i < message.int64Data.length; ++i) if ( !$util.isInteger(message.int64Data[i]) && @@ -3831,37 +4471,55 @@ $root.onnx = (function () { $util.isInteger(message.int64Data[i].high) ) ) - return 'int64Data: integer|Long[] expected'; - } - if (message.name != null && message.hasOwnProperty('name')) - if (!$util.isString(message.name)) return 'name: string expected'; - if (message.docString != null && message.hasOwnProperty('docString')) - if (!$util.isString(message.docString)) return 'docString: string expected'; - if (message.rawData != null && message.hasOwnProperty('rawData')) - if (!((message.rawData && typeof message.rawData.length === 'number') || $util.isString(message.rawData))) - return 'rawData: buffer expected'; - if (message.externalData != null && message.hasOwnProperty('externalData')) { - if (!Array.isArray(message.externalData)) return 'externalData: array expected'; + return "int64Data: integer|Long[] expected"; + } + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) return "name: string expected"; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if (message.rawData != null && message.hasOwnProperty("rawData")) + if ( + !( + (message.rawData && typeof message.rawData.length === "number") || + $util.isString(message.rawData) + ) + ) + return "rawData: buffer expected"; + if ( + message.externalData != null && + message.hasOwnProperty("externalData") + ) { + if (!Array.isArray(message.externalData)) + return "externalData: array expected"; for (var i = 0; i < message.externalData.length; ++i) { - var error = $root.onnx.StringStringEntryProto.verify(message.externalData[i]); - if (error) return 'externalData.' + error; + var error = $root.onnx.StringStringEntryProto.verify( + message.externalData[i], + ); + if (error) return "externalData." + error; } } - if (message.dataLocation != null && message.hasOwnProperty('dataLocation')) + if ( + message.dataLocation != null && + message.hasOwnProperty("dataLocation") + ) switch (message.dataLocation) { default: - return 'dataLocation: enum value expected'; + return "dataLocation: enum value expected"; case 0: case 1: break; } - if (message.doubleData != null && message.hasOwnProperty('doubleData')) { - if (!Array.isArray(message.doubleData)) return 'doubleData: array expected'; + if (message.doubleData != null && message.hasOwnProperty("doubleData")) { + if (!Array.isArray(message.doubleData)) + return "doubleData: array expected"; for (var i = 0; i < message.doubleData.length; ++i) - if (typeof message.doubleData[i] !== 'number') return 'doubleData: number[] expected'; + if (typeof message.doubleData[i] !== "number") + return "doubleData: number[] expected"; } - if (message.uint64Data != null && message.hasOwnProperty('uint64Data')) { - if (!Array.isArray(message.uint64Data)) return 'uint64Data: array expected'; + if (message.uint64Data != null && message.hasOwnProperty("uint64Data")) { + if (!Array.isArray(message.uint64Data)) + return "uint64Data: array expected"; for (var i = 0; i < message.uint64Data.length; ++i) if ( !$util.isInteger(message.uint64Data[i]) && @@ -3871,7 +4529,7 @@ $root.onnx = (function () { $util.isInteger(message.uint64Data[i].high) ) ) - return 'uint64Data: integer|Long[] expected'; + return "uint64Data: integer|Long[] expected"; } return null; }; @@ -3888,103 +4546,143 @@ $root.onnx = (function () { if (object instanceof $root.onnx.TensorProto) return object; var message = new $root.onnx.TensorProto(); if (object.dims) { - if (!Array.isArray(object.dims)) throw TypeError('.onnx.TensorProto.dims: array expected'); + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.TensorProto.dims: array expected"); message.dims = []; for (var i = 0; i < object.dims.length; ++i) - if ($util.Long) (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; - else if (typeof object.dims[i] === 'string') message.dims[i] = parseInt(object.dims[i], 10); - else if (typeof object.dims[i] === 'number') message.dims[i] = object.dims[i]; - else if (typeof object.dims[i] === 'object') - message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = + false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits( + object.dims[i].low >>> 0, + object.dims[i].high >>> 0, + ).toNumber(); } if (object.dataType != null) message.dataType = object.dataType | 0; if (object.segment != null) { - if (typeof object.segment !== 'object') throw TypeError('.onnx.TensorProto.segment: object expected'); - message.segment = $root.onnx.TensorProto.Segment.fromObject(object.segment); + if (typeof object.segment !== "object") + throw TypeError(".onnx.TensorProto.segment: object expected"); + message.segment = $root.onnx.TensorProto.Segment.fromObject( + object.segment, + ); } if (object.floatData) { - if (!Array.isArray(object.floatData)) throw TypeError('.onnx.TensorProto.floatData: array expected'); + if (!Array.isArray(object.floatData)) + throw TypeError(".onnx.TensorProto.floatData: array expected"); message.floatData = []; - for (var i = 0; i < object.floatData.length; ++i) message.floatData[i] = Number(object.floatData[i]); + for (var i = 0; i < object.floatData.length; ++i) + message.floatData[i] = Number(object.floatData[i]); } if (object.int32Data) { - if (!Array.isArray(object.int32Data)) throw TypeError('.onnx.TensorProto.int32Data: array expected'); + if (!Array.isArray(object.int32Data)) + throw TypeError(".onnx.TensorProto.int32Data: array expected"); message.int32Data = []; - for (var i = 0; i < object.int32Data.length; ++i) message.int32Data[i] = object.int32Data[i] | 0; + for (var i = 0; i < object.int32Data.length; ++i) + message.int32Data[i] = object.int32Data[i] | 0; } if (object.stringData) { - if (!Array.isArray(object.stringData)) throw TypeError('.onnx.TensorProto.stringData: array expected'); + if (!Array.isArray(object.stringData)) + throw TypeError(".onnx.TensorProto.stringData: array expected"); message.stringData = []; for (var i = 0; i < object.stringData.length; ++i) - if (typeof object.stringData[i] === 'string') + if (typeof object.stringData[i] === "string") $util.base64.decode( object.stringData[i], - (message.stringData[i] = $util.newBuffer($util.base64.length(object.stringData[i]))), + (message.stringData[i] = $util.newBuffer( + $util.base64.length(object.stringData[i]), + )), 0, ); - else if (object.stringData[i].length >= 0) message.stringData[i] = object.stringData[i]; + else if (object.stringData[i].length >= 0) + message.stringData[i] = object.stringData[i]; } if (object.int64Data) { - if (!Array.isArray(object.int64Data)) throw TypeError('.onnx.TensorProto.int64Data: array expected'); + if (!Array.isArray(object.int64Data)) + throw TypeError(".onnx.TensorProto.int64Data: array expected"); message.int64Data = []; for (var i = 0; i < object.int64Data.length; ++i) - if ($util.Long) (message.int64Data[i] = $util.Long.fromValue(object.int64Data[i])).unsigned = false; - else if (typeof object.int64Data[i] === 'string') message.int64Data[i] = parseInt(object.int64Data[i], 10); - else if (typeof object.int64Data[i] === 'number') message.int64Data[i] = object.int64Data[i]; - else if (typeof object.int64Data[i] === 'object') + if ($util.Long) + (message.int64Data[i] = $util.Long.fromValue( + object.int64Data[i], + )).unsigned = false; + else if (typeof object.int64Data[i] === "string") + message.int64Data[i] = parseInt(object.int64Data[i], 10); + else if (typeof object.int64Data[i] === "number") + message.int64Data[i] = object.int64Data[i]; + else if (typeof object.int64Data[i] === "object") message.int64Data[i] = new $util.LongBits( object.int64Data[i].low >>> 0, object.int64Data[i].high >>> 0, ).toNumber(); } if (object.name != null) message.name = String(object.name); - if (object.docString != null) message.docString = String(object.docString); + if (object.docString != null) + message.docString = String(object.docString); if (object.rawData != null) - if (typeof object.rawData === 'string') + if (typeof object.rawData === "string") $util.base64.decode( object.rawData, - (message.rawData = $util.newBuffer($util.base64.length(object.rawData))), + (message.rawData = $util.newBuffer( + $util.base64.length(object.rawData), + )), 0, ); else if (object.rawData.length >= 0) message.rawData = object.rawData; if (object.externalData) { - if (!Array.isArray(object.externalData)) throw TypeError('.onnx.TensorProto.externalData: array expected'); + if (!Array.isArray(object.externalData)) + throw TypeError(".onnx.TensorProto.externalData: array expected"); message.externalData = []; for (var i = 0; i < object.externalData.length; ++i) { - if (typeof object.externalData[i] !== 'object') - throw TypeError('.onnx.TensorProto.externalData: object expected'); - message.externalData[i] = $root.onnx.StringStringEntryProto.fromObject(object.externalData[i]); + if (typeof object.externalData[i] !== "object") + throw TypeError(".onnx.TensorProto.externalData: object expected"); + message.externalData[i] = + $root.onnx.StringStringEntryProto.fromObject( + object.externalData[i], + ); } } switch (object.dataLocation) { default: - if (typeof object.dataLocation === 'number') { + if (typeof object.dataLocation === "number") { message.dataLocation = object.dataLocation; break; } break; - case 'DEFAULT': + case "DEFAULT": case 0: message.dataLocation = 0; break; - case 'EXTERNAL': + case "EXTERNAL": case 1: message.dataLocation = 1; break; } if (object.doubleData) { - if (!Array.isArray(object.doubleData)) throw TypeError('.onnx.TensorProto.doubleData: array expected'); + if (!Array.isArray(object.doubleData)) + throw TypeError(".onnx.TensorProto.doubleData: array expected"); message.doubleData = []; - for (var i = 0; i < object.doubleData.length; ++i) message.doubleData[i] = Number(object.doubleData[i]); + for (var i = 0; i < object.doubleData.length; ++i) + message.doubleData[i] = Number(object.doubleData[i]); } if (object.uint64Data) { - if (!Array.isArray(object.uint64Data)) throw TypeError('.onnx.TensorProto.uint64Data: array expected'); + if (!Array.isArray(object.uint64Data)) + throw TypeError(".onnx.TensorProto.uint64Data: array expected"); message.uint64Data = []; for (var i = 0; i < object.uint64Data.length; ++i) - if ($util.Long) (message.uint64Data[i] = $util.Long.fromValue(object.uint64Data[i])).unsigned = true; - else if (typeof object.uint64Data[i] === 'string') message.uint64Data[i] = parseInt(object.uint64Data[i], 10); - else if (typeof object.uint64Data[i] === 'number') message.uint64Data[i] = object.uint64Data[i]; - else if (typeof object.uint64Data[i] === 'object') + if ($util.Long) + (message.uint64Data[i] = $util.Long.fromValue( + object.uint64Data[i], + )).unsigned = true; + else if (typeof object.uint64Data[i] === "string") + message.uint64Data[i] = parseInt(object.uint64Data[i], 10); + else if (typeof object.uint64Data[i] === "number") + message.uint64Data[i] = object.uint64Data[i]; + else if (typeof object.uint64Data[i] === "object") message.uint64Data[i] = new $util.LongBits( object.uint64Data[i].low >>> 0, object.uint64Data[i].high >>> 0, @@ -4018,47 +4716,65 @@ $root.onnx = (function () { if (options.defaults) { object.dataType = 0; object.segment = null; - object.name = ''; - if (options.bytes === String) object.rawData = ''; + object.name = ""; + if (options.bytes === String) object.rawData = ""; else { object.rawData = []; - if (options.bytes !== Array) object.rawData = $util.newBuffer(object.rawData); + if (options.bytes !== Array) + object.rawData = $util.newBuffer(object.rawData); } - object.docString = ''; - object.dataLocation = options.enums === String ? 'DEFAULT' : 0; + object.docString = ""; + object.dataLocation = options.enums === String ? "DEFAULT" : 0; } if (message.dims && message.dims.length) { object.dims = []; for (var j = 0; j < message.dims.length; ++j) - if (typeof message.dims[j] === 'number') - object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + if (typeof message.dims[j] === "number") + object.dims[j] = + options.longs === String + ? String(message.dims[j]) + : message.dims[j]; else object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number - ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() + ? new $util.LongBits( + message.dims[j].low >>> 0, + message.dims[j].high >>> 0, + ).toNumber() : message.dims[j]; } - if (message.dataType != null && message.hasOwnProperty('dataType')) object.dataType = message.dataType; - if (message.segment != null && message.hasOwnProperty('segment')) - object.segment = $root.onnx.TensorProto.Segment.toObject(message.segment, options); + if (message.dataType != null && message.hasOwnProperty("dataType")) + object.dataType = message.dataType; + if (message.segment != null && message.hasOwnProperty("segment")) + object.segment = $root.onnx.TensorProto.Segment.toObject( + message.segment, + options, + ); if (message.floatData && message.floatData.length) { object.floatData = []; for (var j = 0; j < message.floatData.length; ++j) object.floatData[j] = - options.json && !isFinite(message.floatData[j]) ? String(message.floatData[j]) : message.floatData[j]; + options.json && !isFinite(message.floatData[j]) + ? String(message.floatData[j]) + : message.floatData[j]; } if (message.int32Data && message.int32Data.length) { object.int32Data = []; - for (var j = 0; j < message.int32Data.length; ++j) object.int32Data[j] = message.int32Data[j]; + for (var j = 0; j < message.int32Data.length; ++j) + object.int32Data[j] = message.int32Data[j]; } if (message.stringData && message.stringData.length) { object.stringData = []; for (var j = 0; j < message.stringData.length; ++j) object.stringData[j] = options.bytes === String - ? $util.base64.encode(message.stringData[j], 0, message.stringData[j].length) + ? $util.base64.encode( + message.stringData[j], + 0, + message.stringData[j].length, + ) : options.bytes === Array ? Array.prototype.slice.call(message.stringData[j]) : message.stringData[j]; @@ -4066,18 +4782,25 @@ $root.onnx = (function () { if (message.int64Data && message.int64Data.length) { object.int64Data = []; for (var j = 0; j < message.int64Data.length; ++j) - if (typeof message.int64Data[j] === 'number') - object.int64Data[j] = options.longs === String ? String(message.int64Data[j]) : message.int64Data[j]; + if (typeof message.int64Data[j] === "number") + object.int64Data[j] = + options.longs === String + ? String(message.int64Data[j]) + : message.int64Data[j]; else object.int64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.int64Data[j]) : options.longs === Number - ? new $util.LongBits(message.int64Data[j].low >>> 0, message.int64Data[j].high >>> 0).toNumber() + ? new $util.LongBits( + message.int64Data[j].low >>> 0, + message.int64Data[j].high >>> 0, + ).toNumber() : message.int64Data[j]; } - if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; - if (message.rawData != null && message.hasOwnProperty('rawData')) + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; + if (message.rawData != null && message.hasOwnProperty("rawData")) object.rawData = options.bytes === String ? $util.base64.encode(message.rawData, 0, message.rawData.length) @@ -4088,31 +4811,47 @@ $root.onnx = (function () { object.doubleData = []; for (var j = 0; j < message.doubleData.length; ++j) object.doubleData[j] = - options.json && !isFinite(message.doubleData[j]) ? String(message.doubleData[j]) : message.doubleData[j]; + options.json && !isFinite(message.doubleData[j]) + ? String(message.doubleData[j]) + : message.doubleData[j]; } if (message.uint64Data && message.uint64Data.length) { object.uint64Data = []; for (var j = 0; j < message.uint64Data.length; ++j) - if (typeof message.uint64Data[j] === 'number') - object.uint64Data[j] = options.longs === String ? String(message.uint64Data[j]) : message.uint64Data[j]; + if (typeof message.uint64Data[j] === "number") + object.uint64Data[j] = + options.longs === String + ? String(message.uint64Data[j]) + : message.uint64Data[j]; else object.uint64Data[j] = options.longs === String ? $util.Long.prototype.toString.call(message.uint64Data[j]) : options.longs === Number - ? new $util.LongBits(message.uint64Data[j].low >>> 0, message.uint64Data[j].high >>> 0).toNumber(true) + ? new $util.LongBits( + message.uint64Data[j].low >>> 0, + message.uint64Data[j].high >>> 0, + ).toNumber(true) : message.uint64Data[j]; } - if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; if (message.externalData && message.externalData.length) { object.externalData = []; for (var j = 0; j < message.externalData.length; ++j) - object.externalData[j] = $root.onnx.StringStringEntryProto.toObject(message.externalData[j], options); + object.externalData[j] = $root.onnx.StringStringEntryProto.toObject( + message.externalData[j], + options, + ); } - if (message.dataLocation != null && message.hasOwnProperty('dataLocation')) + if ( + message.dataLocation != null && + message.hasOwnProperty("dataLocation") + ) object.dataLocation = options.enums === String - ? $root.onnx.TensorProto.DataLocation[message.dataLocation] === undefined + ? $root.onnx.TensorProto.DataLocation[message.dataLocation] === + undefined ? message.dataLocation : $root.onnx.TensorProto.DataLocation[message.dataLocation] : message.dataLocation; @@ -4140,9 +4879,9 @@ $root.onnx = (function () { */ TensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.TensorProto'; + return typeUrlPrefix + "/onnx.TensorProto"; }; /** @@ -4174,27 +4913,27 @@ $root.onnx = (function () { TensorProto.DataType = (function () { var valuesById = {}, values = Object.create(valuesById); - values[(valuesById[0] = 'UNDEFINED')] = 0; - values[(valuesById[1] = 'FLOAT')] = 1; - values[(valuesById[2] = 'UINT8')] = 2; - values[(valuesById[3] = 'INT8')] = 3; - values[(valuesById[4] = 'UINT16')] = 4; - values[(valuesById[5] = 'INT16')] = 5; - values[(valuesById[6] = 'INT32')] = 6; - values[(valuesById[7] = 'INT64')] = 7; - values[(valuesById[8] = 'STRING')] = 8; - values[(valuesById[9] = 'BOOL')] = 9; - values[(valuesById[10] = 'FLOAT16')] = 10; - values[(valuesById[11] = 'DOUBLE')] = 11; - values[(valuesById[12] = 'UINT32')] = 12; - values[(valuesById[13] = 'UINT64')] = 13; - values[(valuesById[14] = 'COMPLEX64')] = 14; - values[(valuesById[15] = 'COMPLEX128')] = 15; - values[(valuesById[16] = 'BFLOAT16')] = 16; - values[(valuesById[17] = 'FLOAT8E4M3FN')] = 17; - values[(valuesById[18] = 'FLOAT8E4M3FNUZ')] = 18; - values[(valuesById[19] = 'FLOAT8E5M2')] = 19; - values[(valuesById[20] = 'FLOAT8E5M2FNUZ')] = 20; + values[(valuesById[0] = "UNDEFINED")] = 0; + values[(valuesById[1] = "FLOAT")] = 1; + values[(valuesById[2] = "UINT8")] = 2; + values[(valuesById[3] = "INT8")] = 3; + values[(valuesById[4] = "UINT16")] = 4; + values[(valuesById[5] = "INT16")] = 5; + values[(valuesById[6] = "INT32")] = 6; + values[(valuesById[7] = "INT64")] = 7; + values[(valuesById[8] = "STRING")] = 8; + values[(valuesById[9] = "BOOL")] = 9; + values[(valuesById[10] = "FLOAT16")] = 10; + values[(valuesById[11] = "DOUBLE")] = 11; + values[(valuesById[12] = "UINT32")] = 12; + values[(valuesById[13] = "UINT64")] = 13; + values[(valuesById[14] = "COMPLEX64")] = 14; + values[(valuesById[15] = "COMPLEX128")] = 15; + values[(valuesById[16] = "BFLOAT16")] = 16; + values[(valuesById[17] = "FLOAT8E4M3FN")] = 17; + values[(valuesById[18] = "FLOAT8E4M3FNUZ")] = 18; + values[(valuesById[19] = "FLOAT8E5M2")] = 19; + values[(valuesById[20] = "FLOAT8E5M2FNUZ")] = 20; return values; })(); @@ -4218,7 +4957,8 @@ $root.onnx = (function () { function Segment(properties) { if (properties) for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; } /** @@ -4227,7 +4967,9 @@ $root.onnx = (function () { * @memberof onnx.TensorProto.Segment * @instance */ - Segment.prototype.begin = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + Segment.prototype.begin = $util.Long + ? $util.Long.fromBits(0, 0, false) + : 0; /** * Segment end. @@ -4260,9 +5002,12 @@ $root.onnx = (function () { */ Segment.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.begin != null && Object.hasOwnProperty.call(message, 'begin')) + if ( + message.begin != null && + Object.hasOwnProperty.call(message, "begin") + ) writer.uint32(/* id 1, wireType 0 =*/ 8).int64(message.begin); - if (message.end != null && Object.hasOwnProperty.call(message, 'end')) + if (message.end != null && Object.hasOwnProperty.call(message, "end")) writer.uint32(/* id 2, wireType 0 =*/ 16).int64(message.end); return writer; }; @@ -4338,19 +5083,28 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ Segment.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.begin != null && message.hasOwnProperty('begin')) + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.begin != null && message.hasOwnProperty("begin")) if ( !$util.isInteger(message.begin) && - !(message.begin && $util.isInteger(message.begin.low) && $util.isInteger(message.begin.high)) + !( + message.begin && + $util.isInteger(message.begin.low) && + $util.isInteger(message.begin.high) + ) ) - return 'begin: integer|Long expected'; - if (message.end != null && message.hasOwnProperty('end')) + return "begin: integer|Long expected"; + if (message.end != null && message.hasOwnProperty("end")) if ( !$util.isInteger(message.end) && - !(message.end && $util.isInteger(message.end.low) && $util.isInteger(message.end.high)) + !( + message.end && + $util.isInteger(message.end.low) && + $util.isInteger(message.end.high) + ) ) - return 'end: integer|Long expected'; + return "end: integer|Long expected"; return null; }; @@ -4366,17 +5120,29 @@ $root.onnx = (function () { if (object instanceof $root.onnx.TensorProto.Segment) return object; var message = new $root.onnx.TensorProto.Segment(); if (object.begin != null) - if ($util.Long) (message.begin = $util.Long.fromValue(object.begin)).unsigned = false; - else if (typeof object.begin === 'string') message.begin = parseInt(object.begin, 10); - else if (typeof object.begin === 'number') message.begin = object.begin; - else if (typeof object.begin === 'object') - message.begin = new $util.LongBits(object.begin.low >>> 0, object.begin.high >>> 0).toNumber(); + if ($util.Long) + (message.begin = $util.Long.fromValue(object.begin)).unsigned = + false; + else if (typeof object.begin === "string") + message.begin = parseInt(object.begin, 10); + else if (typeof object.begin === "number") + message.begin = object.begin; + else if (typeof object.begin === "object") + message.begin = new $util.LongBits( + object.begin.low >>> 0, + object.begin.high >>> 0, + ).toNumber(); if (object.end != null) - if ($util.Long) (message.end = $util.Long.fromValue(object.end)).unsigned = false; - else if (typeof object.end === 'string') message.end = parseInt(object.end, 10); - else if (typeof object.end === 'number') message.end = object.end; - else if (typeof object.end === 'object') - message.end = new $util.LongBits(object.end.low >>> 0, object.end.high >>> 0).toNumber(); + if ($util.Long) + (message.end = $util.Long.fromValue(object.end)).unsigned = false; + else if (typeof object.end === "string") + message.end = parseInt(object.end, 10); + else if (typeof object.end === "number") message.end = object.end; + else if (typeof object.end === "object") + message.end = new $util.LongBits( + object.end.low >>> 0, + object.end.high >>> 0, + ).toNumber(); return message; }; @@ -4396,32 +5162,49 @@ $root.onnx = (function () { if ($util.Long) { var long = new $util.Long(0, 0, false); object.begin = - options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else object.begin = options.longs === String ? '0' : 0; + options.longs === String + ? long.toString() + : options.longs === Number + ? long.toNumber() + : long; + } else object.begin = options.longs === String ? "0" : 0; if ($util.Long) { var long = new $util.Long(0, 0, false); - object.end = options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else object.end = options.longs === String ? '0' : 0; + object.end = + options.longs === String + ? long.toString() + : options.longs === Number + ? long.toNumber() + : long; + } else object.end = options.longs === String ? "0" : 0; } - if (message.begin != null && message.hasOwnProperty('begin')) - if (typeof message.begin === 'number') - object.begin = options.longs === String ? String(message.begin) : message.begin; + if (message.begin != null && message.hasOwnProperty("begin")) + if (typeof message.begin === "number") + object.begin = + options.longs === String ? String(message.begin) : message.begin; else object.begin = options.longs === String ? $util.Long.prototype.toString.call(message.begin) : options.longs === Number - ? new $util.LongBits(message.begin.low >>> 0, message.begin.high >>> 0).toNumber() + ? new $util.LongBits( + message.begin.low >>> 0, + message.begin.high >>> 0, + ).toNumber() : message.begin; - if (message.end != null && message.hasOwnProperty('end')) - if (typeof message.end === 'number') - object.end = options.longs === String ? String(message.end) : message.end; + if (message.end != null && message.hasOwnProperty("end")) + if (typeof message.end === "number") + object.end = + options.longs === String ? String(message.end) : message.end; else object.end = options.longs === String ? $util.Long.prototype.toString.call(message.end) : options.longs === Number - ? new $util.LongBits(message.end.low >>> 0, message.end.high >>> 0).toNumber() + ? new $util.LongBits( + message.end.low >>> 0, + message.end.high >>> 0, + ).toNumber() : message.end; return object; }; @@ -4447,9 +5230,9 @@ $root.onnx = (function () { */ Segment.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.TensorProto.Segment'; + return typeUrlPrefix + "/onnx.TensorProto.Segment"; }; return Segment; @@ -4465,8 +5248,8 @@ $root.onnx = (function () { TensorProto.DataLocation = (function () { var valuesById = {}, values = Object.create(valuesById); - values[(valuesById[0] = 'DEFAULT')] = 0; - values[(valuesById[1] = 'EXTERNAL')] = 1; + values[(valuesById[0] = "DEFAULT")] = 0; + values[(valuesById[1] = "EXTERNAL")] = 1; return values; })(); @@ -4545,13 +5328,26 @@ $root.onnx = (function () { */ SparseTensorProto.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.values != null && Object.hasOwnProperty.call(message, 'values')) - $root.onnx.TensorProto.encode(message.values, writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); - if (message.indices != null && Object.hasOwnProperty.call(message, 'indices')) - $root.onnx.TensorProto.encode(message.indices, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + if ( + message.values != null && + Object.hasOwnProperty.call(message, "values") + ) + $root.onnx.TensorProto.encode( + message.values, + writer.uint32(/* id 1, wireType 2 =*/ 10).fork(), + ).ldelim(); + if ( + message.indices != null && + Object.hasOwnProperty.call(message, "indices") + ) + $root.onnx.TensorProto.encode( + message.indices, + writer.uint32(/* id 2, wireType 2 =*/ 18).fork(), + ).ldelim(); if (message.dims != null && message.dims.length) { writer.uint32(/* id 3, wireType 2 =*/ 26).fork(); - for (var i = 0; i < message.dims.length; ++i) writer.int64(message.dims[i]); + for (var i = 0; i < message.dims.length; ++i) + writer.int64(message.dims[i]); writer.ldelim(); } return writer; @@ -4566,7 +5362,10 @@ $root.onnx = (function () { * @param {$protobuf.Writer} [writer] Writer to encode to * @returns {$protobuf.Writer} Writer */ - SparseTensorProto.encodeDelimited = function encodeDelimited(message, writer) { + SparseTensorProto.encodeDelimited = function encodeDelimited( + message, + writer, + ) { return this.encode(message, writer).ldelim(); }; @@ -4589,11 +5388,17 @@ $root.onnx = (function () { var tag = reader.uint32(); switch (tag >>> 3) { case 1: { - message.values = $root.onnx.TensorProto.decode(reader, reader.uint32()); + message.values = $root.onnx.TensorProto.decode( + reader, + reader.uint32(), + ); break; } case 2: { - message.indices = $root.onnx.TensorProto.decode(reader, reader.uint32()); + message.indices = $root.onnx.TensorProto.decode( + reader, + reader.uint32(), + ); break; } case 3: { @@ -4636,23 +5441,28 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ SparseTensorProto.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.values != null && message.hasOwnProperty('values')) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.values != null && message.hasOwnProperty("values")) { var error = $root.onnx.TensorProto.verify(message.values); - if (error) return 'values.' + error; + if (error) return "values." + error; } - if (message.indices != null && message.hasOwnProperty('indices')) { + if (message.indices != null && message.hasOwnProperty("indices")) { var error = $root.onnx.TensorProto.verify(message.indices); - if (error) return 'indices.' + error; + if (error) return "indices." + error; } - if (message.dims != null && message.hasOwnProperty('dims')) { - if (!Array.isArray(message.dims)) return 'dims: array expected'; + if (message.dims != null && message.hasOwnProperty("dims")) { + if (!Array.isArray(message.dims)) return "dims: array expected"; for (var i = 0; i < message.dims.length; ++i) if ( !$util.isInteger(message.dims[i]) && - !(message.dims[i] && $util.isInteger(message.dims[i].low) && $util.isInteger(message.dims[i].high)) + !( + message.dims[i] && + $util.isInteger(message.dims[i].low) && + $util.isInteger(message.dims[i].high) + ) ) - return 'dims: integer|Long[] expected'; + return "dims: integer|Long[] expected"; } return null; }; @@ -4669,22 +5479,32 @@ $root.onnx = (function () { if (object instanceof $root.onnx.SparseTensorProto) return object; var message = new $root.onnx.SparseTensorProto(); if (object.values != null) { - if (typeof object.values !== 'object') throw TypeError('.onnx.SparseTensorProto.values: object expected'); + if (typeof object.values !== "object") + throw TypeError(".onnx.SparseTensorProto.values: object expected"); message.values = $root.onnx.TensorProto.fromObject(object.values); } if (object.indices != null) { - if (typeof object.indices !== 'object') throw TypeError('.onnx.SparseTensorProto.indices: object expected'); + if (typeof object.indices !== "object") + throw TypeError(".onnx.SparseTensorProto.indices: object expected"); message.indices = $root.onnx.TensorProto.fromObject(object.indices); } if (object.dims) { - if (!Array.isArray(object.dims)) throw TypeError('.onnx.SparseTensorProto.dims: array expected'); + if (!Array.isArray(object.dims)) + throw TypeError(".onnx.SparseTensorProto.dims: array expected"); message.dims = []; for (var i = 0; i < object.dims.length; ++i) - if ($util.Long) (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = false; - else if (typeof object.dims[i] === 'string') message.dims[i] = parseInt(object.dims[i], 10); - else if (typeof object.dims[i] === 'number') message.dims[i] = object.dims[i]; - else if (typeof object.dims[i] === 'object') - message.dims[i] = new $util.LongBits(object.dims[i].low >>> 0, object.dims[i].high >>> 0).toNumber(); + if ($util.Long) + (message.dims[i] = $util.Long.fromValue(object.dims[i])).unsigned = + false; + else if (typeof object.dims[i] === "string") + message.dims[i] = parseInt(object.dims[i], 10); + else if (typeof object.dims[i] === "number") + message.dims[i] = object.dims[i]; + else if (typeof object.dims[i] === "object") + message.dims[i] = new $util.LongBits( + object.dims[i].low >>> 0, + object.dims[i].high >>> 0, + ).toNumber(); } return message; }; @@ -4706,21 +5526,33 @@ $root.onnx = (function () { object.values = null; object.indices = null; } - if (message.values != null && message.hasOwnProperty('values')) - object.values = $root.onnx.TensorProto.toObject(message.values, options); - if (message.indices != null && message.hasOwnProperty('indices')) - object.indices = $root.onnx.TensorProto.toObject(message.indices, options); + if (message.values != null && message.hasOwnProperty("values")) + object.values = $root.onnx.TensorProto.toObject( + message.values, + options, + ); + if (message.indices != null && message.hasOwnProperty("indices")) + object.indices = $root.onnx.TensorProto.toObject( + message.indices, + options, + ); if (message.dims && message.dims.length) { object.dims = []; for (var j = 0; j < message.dims.length; ++j) - if (typeof message.dims[j] === 'number') - object.dims[j] = options.longs === String ? String(message.dims[j]) : message.dims[j]; + if (typeof message.dims[j] === "number") + object.dims[j] = + options.longs === String + ? String(message.dims[j]) + : message.dims[j]; else object.dims[j] = options.longs === String ? $util.Long.prototype.toString.call(message.dims[j]) : options.longs === Number - ? new $util.LongBits(message.dims[j].low >>> 0, message.dims[j].high >>> 0).toNumber() + ? new $util.LongBits( + message.dims[j].low >>> 0, + message.dims[j].high >>> 0, + ).toNumber() : message.dims[j]; } return object; @@ -4747,9 +5579,9 @@ $root.onnx = (function () { */ SparseTensorProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.SparseTensorProto'; + return typeUrlPrefix + "/onnx.SparseTensorProto"; }; return SparseTensorProto; @@ -4827,7 +5659,10 @@ $root.onnx = (function () { * @param {$protobuf.Writer} [writer] Writer to encode to * @returns {$protobuf.Writer} Writer */ - TensorShapeProto.encodeDelimited = function encodeDelimited(message, writer) { + TensorShapeProto.encodeDelimited = function encodeDelimited( + message, + writer, + ) { return this.encode(message, writer).ldelim(); }; @@ -4851,7 +5686,12 @@ $root.onnx = (function () { switch (tag >>> 3) { case 1: { if (!(message.dim && message.dim.length)) message.dim = []; - message.dim.push($root.onnx.TensorShapeProto.Dimension.decode(reader, reader.uint32())); + message.dim.push( + $root.onnx.TensorShapeProto.Dimension.decode( + reader, + reader.uint32(), + ), + ); break; } default: @@ -4886,12 +5726,15 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ TensorShapeProto.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.dim != null && message.hasOwnProperty('dim')) { - if (!Array.isArray(message.dim)) return 'dim: array expected'; + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.dim != null && message.hasOwnProperty("dim")) { + if (!Array.isArray(message.dim)) return "dim: array expected"; for (var i = 0; i < message.dim.length; ++i) { - var error = $root.onnx.TensorShapeProto.Dimension.verify(message.dim[i]); - if (error) return 'dim.' + error; + var error = $root.onnx.TensorShapeProto.Dimension.verify( + message.dim[i], + ); + if (error) return "dim." + error; } } return null; @@ -4909,11 +5752,15 @@ $root.onnx = (function () { if (object instanceof $root.onnx.TensorShapeProto) return object; var message = new $root.onnx.TensorShapeProto(); if (object.dim) { - if (!Array.isArray(object.dim)) throw TypeError('.onnx.TensorShapeProto.dim: array expected'); + if (!Array.isArray(object.dim)) + throw TypeError(".onnx.TensorShapeProto.dim: array expected"); message.dim = []; for (var i = 0; i < object.dim.length; ++i) { - if (typeof object.dim[i] !== 'object') throw TypeError('.onnx.TensorShapeProto.dim: object expected'); - message.dim[i] = $root.onnx.TensorShapeProto.Dimension.fromObject(object.dim[i]); + if (typeof object.dim[i] !== "object") + throw TypeError(".onnx.TensorShapeProto.dim: object expected"); + message.dim[i] = $root.onnx.TensorShapeProto.Dimension.fromObject( + object.dim[i], + ); } } return message; @@ -4935,7 +5782,10 @@ $root.onnx = (function () { if (message.dim && message.dim.length) { object.dim = []; for (var j = 0; j < message.dim.length; ++j) - object.dim[j] = $root.onnx.TensorShapeProto.Dimension.toObject(message.dim[j], options); + object.dim[j] = $root.onnx.TensorShapeProto.Dimension.toObject( + message.dim[j], + options, + ); } return object; }; @@ -4961,9 +5811,9 @@ $root.onnx = (function () { */ TensorShapeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.TensorShapeProto'; + return typeUrlPrefix + "/onnx.TensorShapeProto"; }; TensorShapeProto.Dimension = (function () { @@ -4987,7 +5837,8 @@ $root.onnx = (function () { function Dimension(properties) { if (properties) for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; } /** @@ -5012,7 +5863,7 @@ $root.onnx = (function () { * @memberof onnx.TensorShapeProto.Dimension * @instance */ - Dimension.prototype.denotation = ''; + Dimension.prototype.denotation = ""; // OneOf field names bound to virtual getters and setters var $oneOfFields; @@ -5023,8 +5874,8 @@ $root.onnx = (function () { * @memberof onnx.TensorShapeProto.Dimension * @instance */ - Object.defineProperty(Dimension.prototype, 'value', { - get: $util.oneOfGetter(($oneOfFields = ['dimValue', 'dimParam'])), + Object.defineProperty(Dimension.prototype, "value", { + get: $util.oneOfGetter(($oneOfFields = ["dimValue", "dimParam"])), set: $util.oneOfSetter($oneOfFields), }); @@ -5051,11 +5902,20 @@ $root.onnx = (function () { */ Dimension.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.dimValue != null && Object.hasOwnProperty.call(message, 'dimValue')) + if ( + message.dimValue != null && + Object.hasOwnProperty.call(message, "dimValue") + ) writer.uint32(/* id 1, wireType 0 =*/ 8).int64(message.dimValue); - if (message.dimParam != null && Object.hasOwnProperty.call(message, 'dimParam')) + if ( + message.dimParam != null && + Object.hasOwnProperty.call(message, "dimParam") + ) writer.uint32(/* id 2, wireType 2 =*/ 18).string(message.dimParam); - if (message.denotation != null && Object.hasOwnProperty.call(message, 'denotation')) + if ( + message.denotation != null && + Object.hasOwnProperty.call(message, "denotation") + ) writer.uint32(/* id 3, wireType 2 =*/ 26).string(message.denotation); return writer; }; @@ -5135,23 +5995,30 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ Dimension.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; + if (typeof message !== "object" || message === null) + return "object expected"; var properties = {}; - if (message.dimValue != null && message.hasOwnProperty('dimValue')) { + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { properties.value = 1; if ( !$util.isInteger(message.dimValue) && - !(message.dimValue && $util.isInteger(message.dimValue.low) && $util.isInteger(message.dimValue.high)) + !( + message.dimValue && + $util.isInteger(message.dimValue.low) && + $util.isInteger(message.dimValue.high) + ) ) - return 'dimValue: integer|Long expected'; + return "dimValue: integer|Long expected"; } - if (message.dimParam != null && message.hasOwnProperty('dimParam')) { - if (properties.value === 1) return 'value: multiple values'; + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { + if (properties.value === 1) return "value: multiple values"; properties.value = 1; - if (!$util.isString(message.dimParam)) return 'dimParam: string expected'; + if (!$util.isString(message.dimParam)) + return "dimParam: string expected"; } - if (message.denotation != null && message.hasOwnProperty('denotation')) - if (!$util.isString(message.denotation)) return 'denotation: string expected'; + if (message.denotation != null && message.hasOwnProperty("denotation")) + if (!$util.isString(message.denotation)) + return "denotation: string expected"; return null; }; @@ -5164,16 +6031,26 @@ $root.onnx = (function () { * @returns {onnx.TensorShapeProto.Dimension} Dimension */ Dimension.fromObject = function fromObject(object) { - if (object instanceof $root.onnx.TensorShapeProto.Dimension) return object; + if (object instanceof $root.onnx.TensorShapeProto.Dimension) + return object; var message = new $root.onnx.TensorShapeProto.Dimension(); if (object.dimValue != null) - if ($util.Long) (message.dimValue = $util.Long.fromValue(object.dimValue)).unsigned = false; - else if (typeof object.dimValue === 'string') message.dimValue = parseInt(object.dimValue, 10); - else if (typeof object.dimValue === 'number') message.dimValue = object.dimValue; - else if (typeof object.dimValue === 'object') - message.dimValue = new $util.LongBits(object.dimValue.low >>> 0, object.dimValue.high >>> 0).toNumber(); + if ($util.Long) + (message.dimValue = $util.Long.fromValue( + object.dimValue, + )).unsigned = false; + else if (typeof object.dimValue === "string") + message.dimValue = parseInt(object.dimValue, 10); + else if (typeof object.dimValue === "number") + message.dimValue = object.dimValue; + else if (typeof object.dimValue === "object") + message.dimValue = new $util.LongBits( + object.dimValue.low >>> 0, + object.dimValue.high >>> 0, + ).toNumber(); if (object.dimParam != null) message.dimParam = String(object.dimParam); - if (object.denotation != null) message.denotation = String(object.denotation); + if (object.denotation != null) + message.denotation = String(object.denotation); return message; }; @@ -5189,24 +6066,31 @@ $root.onnx = (function () { Dimension.toObject = function toObject(message, options) { if (!options) options = {}; var object = {}; - if (options.defaults) object.denotation = ''; - if (message.dimValue != null && message.hasOwnProperty('dimValue')) { - if (typeof message.dimValue === 'number') - object.dimValue = options.longs === String ? String(message.dimValue) : message.dimValue; + if (options.defaults) object.denotation = ""; + if (message.dimValue != null && message.hasOwnProperty("dimValue")) { + if (typeof message.dimValue === "number") + object.dimValue = + options.longs === String + ? String(message.dimValue) + : message.dimValue; else object.dimValue = options.longs === String ? $util.Long.prototype.toString.call(message.dimValue) : options.longs === Number - ? new $util.LongBits(message.dimValue.low >>> 0, message.dimValue.high >>> 0).toNumber() + ? new $util.LongBits( + message.dimValue.low >>> 0, + message.dimValue.high >>> 0, + ).toNumber() : message.dimValue; - if (options.oneofs) object.value = 'dimValue'; + if (options.oneofs) object.value = "dimValue"; } - if (message.dimParam != null && message.hasOwnProperty('dimParam')) { + if (message.dimParam != null && message.hasOwnProperty("dimParam")) { object.dimParam = message.dimParam; - if (options.oneofs) object.value = 'dimParam'; + if (options.oneofs) object.value = "dimParam"; } - if (message.denotation != null && message.hasOwnProperty('denotation')) object.denotation = message.denotation; + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; return object; }; @@ -5231,9 +6115,9 @@ $root.onnx = (function () { */ Dimension.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.TensorShapeProto.Dimension'; + return typeUrlPrefix + "/onnx.TensorShapeProto.Dimension"; }; return Dimension; @@ -5315,7 +6199,7 @@ $root.onnx = (function () { * @memberof onnx.TypeProto * @instance */ - TypeProto.prototype.denotation = ''; + TypeProto.prototype.denotation = ""; // OneOf field names bound to virtual getters and setters var $oneOfFields; @@ -5326,9 +6210,15 @@ $root.onnx = (function () { * @memberof onnx.TypeProto * @instance */ - Object.defineProperty(TypeProto.prototype, 'value', { + Object.defineProperty(TypeProto.prototype, "value", { get: $util.oneOfGetter( - ($oneOfFields = ['tensorType', 'sequenceType', 'mapType', 'optionalType', 'sparseTensorType']), + ($oneOfFields = [ + "tensorType", + "sequenceType", + "mapType", + "optionalType", + "sparseTensorType", + ]), ), set: $util.oneOfSetter($oneOfFields), }); @@ -5356,26 +6246,47 @@ $root.onnx = (function () { */ TypeProto.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.tensorType != null && Object.hasOwnProperty.call(message, 'tensorType')) + if ( + message.tensorType != null && + Object.hasOwnProperty.call(message, "tensorType") + ) $root.onnx.TypeProto.Tensor.encode( message.tensorType, writer.uint32(/* id 1, wireType 2 =*/ 10).fork(), ).ldelim(); - if (message.sequenceType != null && Object.hasOwnProperty.call(message, 'sequenceType')) + if ( + message.sequenceType != null && + Object.hasOwnProperty.call(message, "sequenceType") + ) $root.onnx.TypeProto.Sequence.encode( message.sequenceType, writer.uint32(/* id 4, wireType 2 =*/ 34).fork(), ).ldelim(); - if (message.mapType != null && Object.hasOwnProperty.call(message, 'mapType')) - $root.onnx.TypeProto.Map.encode(message.mapType, writer.uint32(/* id 5, wireType 2 =*/ 42).fork()).ldelim(); - if (message.denotation != null && Object.hasOwnProperty.call(message, 'denotation')) + if ( + message.mapType != null && + Object.hasOwnProperty.call(message, "mapType") + ) + $root.onnx.TypeProto.Map.encode( + message.mapType, + writer.uint32(/* id 5, wireType 2 =*/ 42).fork(), + ).ldelim(); + if ( + message.denotation != null && + Object.hasOwnProperty.call(message, "denotation") + ) writer.uint32(/* id 6, wireType 2 =*/ 50).string(message.denotation); - if (message.sparseTensorType != null && Object.hasOwnProperty.call(message, 'sparseTensorType')) + if ( + message.sparseTensorType != null && + Object.hasOwnProperty.call(message, "sparseTensorType") + ) $root.onnx.TypeProto.SparseTensor.encode( message.sparseTensorType, writer.uint32(/* id 8, wireType 2 =*/ 66).fork(), ).ldelim(); - if (message.optionalType != null && Object.hasOwnProperty.call(message, 'optionalType')) + if ( + message.optionalType != null && + Object.hasOwnProperty.call(message, "optionalType") + ) $root.onnx.TypeProto.Optional.encode( message.optionalType, writer.uint32(/* id 9, wireType 2 =*/ 74).fork(), @@ -5415,23 +6326,38 @@ $root.onnx = (function () { var tag = reader.uint32(); switch (tag >>> 3) { case 1: { - message.tensorType = $root.onnx.TypeProto.Tensor.decode(reader, reader.uint32()); + message.tensorType = $root.onnx.TypeProto.Tensor.decode( + reader, + reader.uint32(), + ); break; } case 4: { - message.sequenceType = $root.onnx.TypeProto.Sequence.decode(reader, reader.uint32()); + message.sequenceType = $root.onnx.TypeProto.Sequence.decode( + reader, + reader.uint32(), + ); break; } case 5: { - message.mapType = $root.onnx.TypeProto.Map.decode(reader, reader.uint32()); + message.mapType = $root.onnx.TypeProto.Map.decode( + reader, + reader.uint32(), + ); break; } case 9: { - message.optionalType = $root.onnx.TypeProto.Optional.decode(reader, reader.uint32()); + message.optionalType = $root.onnx.TypeProto.Optional.decode( + reader, + reader.uint32(), + ); break; } case 8: { - message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.decode(reader, reader.uint32()); + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.decode( + reader, + reader.uint32(), + ); break; } case 6: { @@ -5470,49 +6396,66 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ TypeProto.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; + if (typeof message !== "object" || message === null) + return "object expected"; var properties = {}; - if (message.tensorType != null && message.hasOwnProperty('tensorType')) { + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { properties.value = 1; { var error = $root.onnx.TypeProto.Tensor.verify(message.tensorType); - if (error) return 'tensorType.' + error; + if (error) return "tensorType." + error; } } - if (message.sequenceType != null && message.hasOwnProperty('sequenceType')) { - if (properties.value === 1) return 'value: multiple values'; + if ( + message.sequenceType != null && + message.hasOwnProperty("sequenceType") + ) { + if (properties.value === 1) return "value: multiple values"; properties.value = 1; { - var error = $root.onnx.TypeProto.Sequence.verify(message.sequenceType); - if (error) return 'sequenceType.' + error; + var error = $root.onnx.TypeProto.Sequence.verify( + message.sequenceType, + ); + if (error) return "sequenceType." + error; } } - if (message.mapType != null && message.hasOwnProperty('mapType')) { - if (properties.value === 1) return 'value: multiple values'; + if (message.mapType != null && message.hasOwnProperty("mapType")) { + if (properties.value === 1) return "value: multiple values"; properties.value = 1; { var error = $root.onnx.TypeProto.Map.verify(message.mapType); - if (error) return 'mapType.' + error; + if (error) return "mapType." + error; } } - if (message.optionalType != null && message.hasOwnProperty('optionalType')) { - if (properties.value === 1) return 'value: multiple values'; + if ( + message.optionalType != null && + message.hasOwnProperty("optionalType") + ) { + if (properties.value === 1) return "value: multiple values"; properties.value = 1; { - var error = $root.onnx.TypeProto.Optional.verify(message.optionalType); - if (error) return 'optionalType.' + error; + var error = $root.onnx.TypeProto.Optional.verify( + message.optionalType, + ); + if (error) return "optionalType." + error; } } - if (message.sparseTensorType != null && message.hasOwnProperty('sparseTensorType')) { - if (properties.value === 1) return 'value: multiple values'; + if ( + message.sparseTensorType != null && + message.hasOwnProperty("sparseTensorType") + ) { + if (properties.value === 1) return "value: multiple values"; properties.value = 1; { - var error = $root.onnx.TypeProto.SparseTensor.verify(message.sparseTensorType); - if (error) return 'sparseTensorType.' + error; + var error = $root.onnx.TypeProto.SparseTensor.verify( + message.sparseTensorType, + ); + if (error) return "sparseTensorType." + error; } } - if (message.denotation != null && message.hasOwnProperty('denotation')) - if (!$util.isString(message.denotation)) return 'denotation: string expected'; + if (message.denotation != null && message.hasOwnProperty("denotation")) + if (!$util.isString(message.denotation)) + return "denotation: string expected"; return null; }; @@ -5528,27 +6471,40 @@ $root.onnx = (function () { if (object instanceof $root.onnx.TypeProto) return object; var message = new $root.onnx.TypeProto(); if (object.tensorType != null) { - if (typeof object.tensorType !== 'object') throw TypeError('.onnx.TypeProto.tensorType: object expected'); - message.tensorType = $root.onnx.TypeProto.Tensor.fromObject(object.tensorType); + if (typeof object.tensorType !== "object") + throw TypeError(".onnx.TypeProto.tensorType: object expected"); + message.tensorType = $root.onnx.TypeProto.Tensor.fromObject( + object.tensorType, + ); } if (object.sequenceType != null) { - if (typeof object.sequenceType !== 'object') throw TypeError('.onnx.TypeProto.sequenceType: object expected'); - message.sequenceType = $root.onnx.TypeProto.Sequence.fromObject(object.sequenceType); + if (typeof object.sequenceType !== "object") + throw TypeError(".onnx.TypeProto.sequenceType: object expected"); + message.sequenceType = $root.onnx.TypeProto.Sequence.fromObject( + object.sequenceType, + ); } if (object.mapType != null) { - if (typeof object.mapType !== 'object') throw TypeError('.onnx.TypeProto.mapType: object expected'); + if (typeof object.mapType !== "object") + throw TypeError(".onnx.TypeProto.mapType: object expected"); message.mapType = $root.onnx.TypeProto.Map.fromObject(object.mapType); } if (object.optionalType != null) { - if (typeof object.optionalType !== 'object') throw TypeError('.onnx.TypeProto.optionalType: object expected'); - message.optionalType = $root.onnx.TypeProto.Optional.fromObject(object.optionalType); + if (typeof object.optionalType !== "object") + throw TypeError(".onnx.TypeProto.optionalType: object expected"); + message.optionalType = $root.onnx.TypeProto.Optional.fromObject( + object.optionalType, + ); } if (object.sparseTensorType != null) { - if (typeof object.sparseTensorType !== 'object') - throw TypeError('.onnx.TypeProto.sparseTensorType: object expected'); - message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.fromObject(object.sparseTensorType); - } - if (object.denotation != null) message.denotation = String(object.denotation); + if (typeof object.sparseTensorType !== "object") + throw TypeError(".onnx.TypeProto.sparseTensorType: object expected"); + message.sparseTensorType = $root.onnx.TypeProto.SparseTensor.fromObject( + object.sparseTensorType, + ); + } + if (object.denotation != null) + message.denotation = String(object.denotation); return message; }; @@ -5564,27 +6520,52 @@ $root.onnx = (function () { TypeProto.toObject = function toObject(message, options) { if (!options) options = {}; var object = {}; - if (options.defaults) object.denotation = ''; - if (message.tensorType != null && message.hasOwnProperty('tensorType')) { - object.tensorType = $root.onnx.TypeProto.Tensor.toObject(message.tensorType, options); - if (options.oneofs) object.value = 'tensorType'; - } - if (message.sequenceType != null && message.hasOwnProperty('sequenceType')) { - object.sequenceType = $root.onnx.TypeProto.Sequence.toObject(message.sequenceType, options); - if (options.oneofs) object.value = 'sequenceType'; - } - if (message.mapType != null && message.hasOwnProperty('mapType')) { - object.mapType = $root.onnx.TypeProto.Map.toObject(message.mapType, options); - if (options.oneofs) object.value = 'mapType'; - } - if (message.denotation != null && message.hasOwnProperty('denotation')) object.denotation = message.denotation; - if (message.sparseTensorType != null && message.hasOwnProperty('sparseTensorType')) { - object.sparseTensorType = $root.onnx.TypeProto.SparseTensor.toObject(message.sparseTensorType, options); - if (options.oneofs) object.value = 'sparseTensorType'; - } - if (message.optionalType != null && message.hasOwnProperty('optionalType')) { - object.optionalType = $root.onnx.TypeProto.Optional.toObject(message.optionalType, options); - if (options.oneofs) object.value = 'optionalType'; + if (options.defaults) object.denotation = ""; + if (message.tensorType != null && message.hasOwnProperty("tensorType")) { + object.tensorType = $root.onnx.TypeProto.Tensor.toObject( + message.tensorType, + options, + ); + if (options.oneofs) object.value = "tensorType"; + } + if ( + message.sequenceType != null && + message.hasOwnProperty("sequenceType") + ) { + object.sequenceType = $root.onnx.TypeProto.Sequence.toObject( + message.sequenceType, + options, + ); + if (options.oneofs) object.value = "sequenceType"; + } + if (message.mapType != null && message.hasOwnProperty("mapType")) { + object.mapType = $root.onnx.TypeProto.Map.toObject( + message.mapType, + options, + ); + if (options.oneofs) object.value = "mapType"; + } + if (message.denotation != null && message.hasOwnProperty("denotation")) + object.denotation = message.denotation; + if ( + message.sparseTensorType != null && + message.hasOwnProperty("sparseTensorType") + ) { + object.sparseTensorType = $root.onnx.TypeProto.SparseTensor.toObject( + message.sparseTensorType, + options, + ); + if (options.oneofs) object.value = "sparseTensorType"; + } + if ( + message.optionalType != null && + message.hasOwnProperty("optionalType") + ) { + object.optionalType = $root.onnx.TypeProto.Optional.toObject( + message.optionalType, + options, + ); + if (options.oneofs) object.value = "optionalType"; } return object; }; @@ -5610,9 +6591,9 @@ $root.onnx = (function () { */ TypeProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.TypeProto'; + return typeUrlPrefix + "/onnx.TypeProto"; }; TypeProto.Tensor = (function () { @@ -5635,7 +6616,8 @@ $root.onnx = (function () { function Tensor(properties) { if (properties) for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; } /** @@ -5677,10 +6659,19 @@ $root.onnx = (function () { */ Tensor.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.elemType != null && Object.hasOwnProperty.call(message, 'elemType')) + if ( + message.elemType != null && + Object.hasOwnProperty.call(message, "elemType") + ) writer.uint32(/* id 1, wireType 0 =*/ 8).int32(message.elemType); - if (message.shape != null && Object.hasOwnProperty.call(message, 'shape')) - $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + if ( + message.shape != null && + Object.hasOwnProperty.call(message, "shape") + ) + $root.onnx.TensorShapeProto.encode( + message.shape, + writer.uint32(/* id 2, wireType 2 =*/ 18).fork(), + ).ldelim(); return writer; }; @@ -5720,7 +6711,10 @@ $root.onnx = (function () { break; } case 2: { - message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + message.shape = $root.onnx.TensorShapeProto.decode( + reader, + reader.uint32(), + ); break; } default: @@ -5755,12 +6749,14 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ Tensor.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.elemType != null && message.hasOwnProperty('elemType')) - if (!$util.isInteger(message.elemType)) return 'elemType: integer expected'; - if (message.shape != null && message.hasOwnProperty('shape')) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { var error = $root.onnx.TensorShapeProto.verify(message.shape); - if (error) return 'shape.' + error; + if (error) return "shape." + error; } return null; }; @@ -5778,7 +6774,8 @@ $root.onnx = (function () { var message = new $root.onnx.TypeProto.Tensor(); if (object.elemType != null) message.elemType = object.elemType | 0; if (object.shape != null) { - if (typeof object.shape !== 'object') throw TypeError('.onnx.TypeProto.Tensor.shape: object expected'); + if (typeof object.shape !== "object") + throw TypeError(".onnx.TypeProto.Tensor.shape: object expected"); message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); } return message; @@ -5800,9 +6797,13 @@ $root.onnx = (function () { object.elemType = 0; object.shape = null; } - if (message.elemType != null && message.hasOwnProperty('elemType')) object.elemType = message.elemType; - if (message.shape != null && message.hasOwnProperty('shape')) - object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject( + message.shape, + options, + ); return object; }; @@ -5827,9 +6828,9 @@ $root.onnx = (function () { */ Tensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.TypeProto.Tensor'; + return typeUrlPrefix + "/onnx.TypeProto.Tensor"; }; return Tensor; @@ -5854,7 +6855,8 @@ $root.onnx = (function () { function Sequence(properties) { if (properties) for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; } /** @@ -5888,8 +6890,14 @@ $root.onnx = (function () { */ Sequence.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.elemType != null && Object.hasOwnProperty.call(message, 'elemType')) - $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); + if ( + message.elemType != null && + Object.hasOwnProperty.call(message, "elemType") + ) + $root.onnx.TypeProto.encode( + message.elemType, + writer.uint32(/* id 1, wireType 2 =*/ 10).fork(), + ).ldelim(); return writer; }; @@ -5925,7 +6933,10 @@ $root.onnx = (function () { var tag = reader.uint32(); switch (tag >>> 3) { case 1: { - message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + message.elemType = $root.onnx.TypeProto.decode( + reader, + reader.uint32(), + ); break; } default: @@ -5960,10 +6971,11 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ Sequence.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.elemType != null && message.hasOwnProperty('elemType')) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { var error = $root.onnx.TypeProto.verify(message.elemType); - if (error) return 'elemType.' + error; + if (error) return "elemType." + error; } return null; }; @@ -5980,8 +6992,10 @@ $root.onnx = (function () { if (object instanceof $root.onnx.TypeProto.Sequence) return object; var message = new $root.onnx.TypeProto.Sequence(); if (object.elemType != null) { - if (typeof object.elemType !== 'object') - throw TypeError('.onnx.TypeProto.Sequence.elemType: object expected'); + if (typeof object.elemType !== "object") + throw TypeError( + ".onnx.TypeProto.Sequence.elemType: object expected", + ); message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); } return message; @@ -6000,8 +7014,11 @@ $root.onnx = (function () { if (!options) options = {}; var object = {}; if (options.defaults) object.elemType = null; - if (message.elemType != null && message.hasOwnProperty('elemType')) - object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject( + message.elemType, + options, + ); return object; }; @@ -6026,9 +7043,9 @@ $root.onnx = (function () { */ Sequence.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.TypeProto.Sequence'; + return typeUrlPrefix + "/onnx.TypeProto.Sequence"; }; return Sequence; @@ -6054,7 +7071,8 @@ $root.onnx = (function () { function Map(properties) { if (properties) for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; } /** @@ -6096,10 +7114,19 @@ $root.onnx = (function () { */ Map.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.keyType != null && Object.hasOwnProperty.call(message, 'keyType')) + if ( + message.keyType != null && + Object.hasOwnProperty.call(message, "keyType") + ) writer.uint32(/* id 1, wireType 0 =*/ 8).int32(message.keyType); - if (message.valueType != null && Object.hasOwnProperty.call(message, 'valueType')) - $root.onnx.TypeProto.encode(message.valueType, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + if ( + message.valueType != null && + Object.hasOwnProperty.call(message, "valueType") + ) + $root.onnx.TypeProto.encode( + message.valueType, + writer.uint32(/* id 2, wireType 2 =*/ 18).fork(), + ).ldelim(); return writer; }; @@ -6139,7 +7166,10 @@ $root.onnx = (function () { break; } case 2: { - message.valueType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + message.valueType = $root.onnx.TypeProto.decode( + reader, + reader.uint32(), + ); break; } default: @@ -6174,12 +7204,14 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ Map.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.keyType != null && message.hasOwnProperty('keyType')) - if (!$util.isInteger(message.keyType)) return 'keyType: integer expected'; - if (message.valueType != null && message.hasOwnProperty('valueType')) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.keyType != null && message.hasOwnProperty("keyType")) + if (!$util.isInteger(message.keyType)) + return "keyType: integer expected"; + if (message.valueType != null && message.hasOwnProperty("valueType")) { var error = $root.onnx.TypeProto.verify(message.valueType); - if (error) return 'valueType.' + error; + if (error) return "valueType." + error; } return null; }; @@ -6197,7 +7229,8 @@ $root.onnx = (function () { var message = new $root.onnx.TypeProto.Map(); if (object.keyType != null) message.keyType = object.keyType | 0; if (object.valueType != null) { - if (typeof object.valueType !== 'object') throw TypeError('.onnx.TypeProto.Map.valueType: object expected'); + if (typeof object.valueType !== "object") + throw TypeError(".onnx.TypeProto.Map.valueType: object expected"); message.valueType = $root.onnx.TypeProto.fromObject(object.valueType); } return message; @@ -6219,9 +7252,13 @@ $root.onnx = (function () { object.keyType = 0; object.valueType = null; } - if (message.keyType != null && message.hasOwnProperty('keyType')) object.keyType = message.keyType; - if (message.valueType != null && message.hasOwnProperty('valueType')) - object.valueType = $root.onnx.TypeProto.toObject(message.valueType, options); + if (message.keyType != null && message.hasOwnProperty("keyType")) + object.keyType = message.keyType; + if (message.valueType != null && message.hasOwnProperty("valueType")) + object.valueType = $root.onnx.TypeProto.toObject( + message.valueType, + options, + ); return object; }; @@ -6246,9 +7283,9 @@ $root.onnx = (function () { */ Map.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.TypeProto.Map'; + return typeUrlPrefix + "/onnx.TypeProto.Map"; }; return Map; @@ -6273,7 +7310,8 @@ $root.onnx = (function () { function Optional(properties) { if (properties) for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; } /** @@ -6307,8 +7345,14 @@ $root.onnx = (function () { */ Optional.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.elemType != null && Object.hasOwnProperty.call(message, 'elemType')) - $root.onnx.TypeProto.encode(message.elemType, writer.uint32(/* id 1, wireType 2 =*/ 10).fork()).ldelim(); + if ( + message.elemType != null && + Object.hasOwnProperty.call(message, "elemType") + ) + $root.onnx.TypeProto.encode( + message.elemType, + writer.uint32(/* id 1, wireType 2 =*/ 10).fork(), + ).ldelim(); return writer; }; @@ -6344,7 +7388,10 @@ $root.onnx = (function () { var tag = reader.uint32(); switch (tag >>> 3) { case 1: { - message.elemType = $root.onnx.TypeProto.decode(reader, reader.uint32()); + message.elemType = $root.onnx.TypeProto.decode( + reader, + reader.uint32(), + ); break; } default: @@ -6379,10 +7426,11 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ Optional.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.elemType != null && message.hasOwnProperty('elemType')) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) { var error = $root.onnx.TypeProto.verify(message.elemType); - if (error) return 'elemType.' + error; + if (error) return "elemType." + error; } return null; }; @@ -6399,8 +7447,10 @@ $root.onnx = (function () { if (object instanceof $root.onnx.TypeProto.Optional) return object; var message = new $root.onnx.TypeProto.Optional(); if (object.elemType != null) { - if (typeof object.elemType !== 'object') - throw TypeError('.onnx.TypeProto.Optional.elemType: object expected'); + if (typeof object.elemType !== "object") + throw TypeError( + ".onnx.TypeProto.Optional.elemType: object expected", + ); message.elemType = $root.onnx.TypeProto.fromObject(object.elemType); } return message; @@ -6419,8 +7469,11 @@ $root.onnx = (function () { if (!options) options = {}; var object = {}; if (options.defaults) object.elemType = null; - if (message.elemType != null && message.hasOwnProperty('elemType')) - object.elemType = $root.onnx.TypeProto.toObject(message.elemType, options); + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = $root.onnx.TypeProto.toObject( + message.elemType, + options, + ); return object; }; @@ -6445,9 +7498,9 @@ $root.onnx = (function () { */ Optional.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.TypeProto.Optional'; + return typeUrlPrefix + "/onnx.TypeProto.Optional"; }; return Optional; @@ -6473,7 +7526,8 @@ $root.onnx = (function () { function SparseTensor(properties) { if (properties) for (var keys = Object.keys(properties), i = 0; i < keys.length; ++i) - if (properties[keys[i]] != null) this[keys[i]] = properties[keys[i]]; + if (properties[keys[i]] != null) + this[keys[i]] = properties[keys[i]]; } /** @@ -6515,10 +7569,19 @@ $root.onnx = (function () { */ SparseTensor.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.elemType != null && Object.hasOwnProperty.call(message, 'elemType')) + if ( + message.elemType != null && + Object.hasOwnProperty.call(message, "elemType") + ) writer.uint32(/* id 1, wireType 0 =*/ 8).int32(message.elemType); - if (message.shape != null && Object.hasOwnProperty.call(message, 'shape')) - $root.onnx.TensorShapeProto.encode(message.shape, writer.uint32(/* id 2, wireType 2 =*/ 18).fork()).ldelim(); + if ( + message.shape != null && + Object.hasOwnProperty.call(message, "shape") + ) + $root.onnx.TensorShapeProto.encode( + message.shape, + writer.uint32(/* id 2, wireType 2 =*/ 18).fork(), + ).ldelim(); return writer; }; @@ -6558,7 +7621,10 @@ $root.onnx = (function () { break; } case 2: { - message.shape = $root.onnx.TensorShapeProto.decode(reader, reader.uint32()); + message.shape = $root.onnx.TensorShapeProto.decode( + reader, + reader.uint32(), + ); break; } default: @@ -6593,12 +7659,14 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ SparseTensor.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.elemType != null && message.hasOwnProperty('elemType')) - if (!$util.isInteger(message.elemType)) return 'elemType: integer expected'; - if (message.shape != null && message.hasOwnProperty('shape')) { + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.elemType != null && message.hasOwnProperty("elemType")) + if (!$util.isInteger(message.elemType)) + return "elemType: integer expected"; + if (message.shape != null && message.hasOwnProperty("shape")) { var error = $root.onnx.TensorShapeProto.verify(message.shape); - if (error) return 'shape.' + error; + if (error) return "shape." + error; } return null; }; @@ -6616,7 +7684,10 @@ $root.onnx = (function () { var message = new $root.onnx.TypeProto.SparseTensor(); if (object.elemType != null) message.elemType = object.elemType | 0; if (object.shape != null) { - if (typeof object.shape !== 'object') throw TypeError('.onnx.TypeProto.SparseTensor.shape: object expected'); + if (typeof object.shape !== "object") + throw TypeError( + ".onnx.TypeProto.SparseTensor.shape: object expected", + ); message.shape = $root.onnx.TensorShapeProto.fromObject(object.shape); } return message; @@ -6638,9 +7709,13 @@ $root.onnx = (function () { object.elemType = 0; object.shape = null; } - if (message.elemType != null && message.hasOwnProperty('elemType')) object.elemType = message.elemType; - if (message.shape != null && message.hasOwnProperty('shape')) - object.shape = $root.onnx.TensorShapeProto.toObject(message.shape, options); + if (message.elemType != null && message.hasOwnProperty("elemType")) + object.elemType = message.elemType; + if (message.shape != null && message.hasOwnProperty("shape")) + object.shape = $root.onnx.TensorShapeProto.toObject( + message.shape, + options, + ); return object; }; @@ -6665,9 +7740,9 @@ $root.onnx = (function () { */ SparseTensor.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.TypeProto.SparseTensor'; + return typeUrlPrefix + "/onnx.TypeProto.SparseTensor"; }; return SparseTensor; @@ -6705,7 +7780,7 @@ $root.onnx = (function () { * @memberof onnx.OperatorSetIdProto * @instance */ - OperatorSetIdProto.prototype.domain = ''; + OperatorSetIdProto.prototype.domain = ""; /** * OperatorSetIdProto version. @@ -6713,7 +7788,9 @@ $root.onnx = (function () { * @memberof onnx.OperatorSetIdProto * @instance */ - OperatorSetIdProto.prototype.version = $util.Long ? $util.Long.fromBits(0, 0, false) : 0; + OperatorSetIdProto.prototype.version = $util.Long + ? $util.Long.fromBits(0, 0, false) + : 0; /** * Creates a new OperatorSetIdProto instance using the specified properties. @@ -6738,9 +7815,15 @@ $root.onnx = (function () { */ OperatorSetIdProto.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.domain != null && Object.hasOwnProperty.call(message, 'domain')) + if ( + message.domain != null && + Object.hasOwnProperty.call(message, "domain") + ) writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.domain); - if (message.version != null && Object.hasOwnProperty.call(message, 'version')) + if ( + message.version != null && + Object.hasOwnProperty.call(message, "version") + ) writer.uint32(/* id 2, wireType 0 =*/ 16).int64(message.version); return writer; }; @@ -6754,7 +7837,10 @@ $root.onnx = (function () { * @param {$protobuf.Writer} [writer] Writer to encode to * @returns {$protobuf.Writer} Writer */ - OperatorSetIdProto.encodeDelimited = function encodeDelimited(message, writer) { + OperatorSetIdProto.encodeDelimited = function encodeDelimited( + message, + writer, + ) { return this.encode(message, writer).ldelim(); }; @@ -6816,15 +7902,20 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ OperatorSetIdProto.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.domain != null && message.hasOwnProperty('domain')) - if (!$util.isString(message.domain)) return 'domain: string expected'; - if (message.version != null && message.hasOwnProperty('version')) + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) return "domain: string expected"; + if (message.version != null && message.hasOwnProperty("version")) if ( !$util.isInteger(message.version) && - !(message.version && $util.isInteger(message.version.low) && $util.isInteger(message.version.high)) + !( + message.version && + $util.isInteger(message.version.low) && + $util.isInteger(message.version.high) + ) ) - return 'version: integer|Long expected'; + return "version: integer|Long expected"; return null; }; @@ -6841,11 +7932,18 @@ $root.onnx = (function () { var message = new $root.onnx.OperatorSetIdProto(); if (object.domain != null) message.domain = String(object.domain); if (object.version != null) - if ($util.Long) (message.version = $util.Long.fromValue(object.version)).unsigned = false; - else if (typeof object.version === 'string') message.version = parseInt(object.version, 10); - else if (typeof object.version === 'number') message.version = object.version; - else if (typeof object.version === 'object') - message.version = new $util.LongBits(object.version.low >>> 0, object.version.high >>> 0).toNumber(); + if ($util.Long) + (message.version = $util.Long.fromValue(object.version)).unsigned = + false; + else if (typeof object.version === "string") + message.version = parseInt(object.version, 10); + else if (typeof object.version === "number") + message.version = object.version; + else if (typeof object.version === "object") + message.version = new $util.LongBits( + object.version.low >>> 0, + object.version.high >>> 0, + ).toNumber(); return message; }; @@ -6862,23 +7960,34 @@ $root.onnx = (function () { if (!options) options = {}; var object = {}; if (options.defaults) { - object.domain = ''; + object.domain = ""; if ($util.Long) { var long = new $util.Long(0, 0, false); object.version = - options.longs === String ? long.toString() : options.longs === Number ? long.toNumber() : long; - } else object.version = options.longs === String ? '0' : 0; - } - if (message.domain != null && message.hasOwnProperty('domain')) object.domain = message.domain; - if (message.version != null && message.hasOwnProperty('version')) - if (typeof message.version === 'number') - object.version = options.longs === String ? String(message.version) : message.version; + options.longs === String + ? long.toString() + : options.longs === Number + ? long.toNumber() + : long; + } else object.version = options.longs === String ? "0" : 0; + } + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; + if (message.version != null && message.hasOwnProperty("version")) + if (typeof message.version === "number") + object.version = + options.longs === String + ? String(message.version) + : message.version; else object.version = options.longs === String ? $util.Long.prototype.toString.call(message.version) : options.longs === Number - ? new $util.LongBits(message.version.low >>> 0, message.version.high >>> 0).toNumber() + ? new $util.LongBits( + message.version.low >>> 0, + message.version.high >>> 0, + ).toNumber() : message.version; return object; }; @@ -6904,9 +8013,9 @@ $root.onnx = (function () { */ OperatorSetIdProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.OperatorSetIdProto'; + return typeUrlPrefix + "/onnx.OperatorSetIdProto"; }; return OperatorSetIdProto; @@ -6922,8 +8031,8 @@ $root.onnx = (function () { onnx.OperatorStatus = (function () { var valuesById = {}, values = Object.create(valuesById); - values[(valuesById[0] = 'EXPERIMENTAL')] = 0; - values[(valuesById[1] = 'STABLE')] = 1; + values[(valuesById[0] = "EXPERIMENTAL")] = 0; + values[(valuesById[1] = "STABLE")] = 1; return values; })(); @@ -6969,7 +8078,7 @@ $root.onnx = (function () { * @memberof onnx.FunctionProto * @instance */ - FunctionProto.prototype.name = ''; + FunctionProto.prototype.name = ""; /** * FunctionProto input. @@ -7017,7 +8126,7 @@ $root.onnx = (function () { * @memberof onnx.FunctionProto * @instance */ - FunctionProto.prototype.docString = ''; + FunctionProto.prototype.docString = ""; /** * FunctionProto opsetImport. @@ -7033,7 +8142,7 @@ $root.onnx = (function () { * @memberof onnx.FunctionProto * @instance */ - FunctionProto.prototype.domain = ''; + FunctionProto.prototype.domain = ""; /** * Creates a new FunctionProto instance using the specified properties. @@ -7058,7 +8167,7 @@ $root.onnx = (function () { */ FunctionProto.encode = function encode(message, writer) { if (!writer) writer = $Writer.create(); - if (message.name != null && Object.hasOwnProperty.call(message, 'name')) + if (message.name != null && Object.hasOwnProperty.call(message, "name")) writer.uint32(/* id 1, wireType 2 =*/ 10).string(message.name); if (message.input != null && message.input.length) for (var i = 0; i < message.input.length; ++i) @@ -7068,11 +8177,19 @@ $root.onnx = (function () { writer.uint32(/* id 5, wireType 2 =*/ 42).string(message.output[i]); if (message.attribute != null && message.attribute.length) for (var i = 0; i < message.attribute.length; ++i) - writer.uint32(/* id 6, wireType 2 =*/ 50).string(message.attribute[i]); + writer + .uint32(/* id 6, wireType 2 =*/ 50) + .string(message.attribute[i]); if (message.node != null && message.node.length) for (var i = 0; i < message.node.length; ++i) - $root.onnx.NodeProto.encode(message.node[i], writer.uint32(/* id 7, wireType 2 =*/ 58).fork()).ldelim(); - if (message.docString != null && Object.hasOwnProperty.call(message, 'docString')) + $root.onnx.NodeProto.encode( + message.node[i], + writer.uint32(/* id 7, wireType 2 =*/ 58).fork(), + ).ldelim(); + if ( + message.docString != null && + Object.hasOwnProperty.call(message, "docString") + ) writer.uint32(/* id 8, wireType 2 =*/ 66).string(message.docString); if (message.opsetImport != null && message.opsetImport.length) for (var i = 0; i < message.opsetImport.length; ++i) @@ -7080,7 +8197,10 @@ $root.onnx = (function () { message.opsetImport[i], writer.uint32(/* id 9, wireType 2 =*/ 74).fork(), ).ldelim(); - if (message.domain != null && Object.hasOwnProperty.call(message, 'domain')) + if ( + message.domain != null && + Object.hasOwnProperty.call(message, "domain") + ) writer.uint32(/* id 10, wireType 2 =*/ 82).string(message.domain); if (message.attributeProto != null && message.attributeProto.length) for (var i = 0; i < message.attributeProto.length; ++i) @@ -7137,18 +8257,24 @@ $root.onnx = (function () { break; } case 6: { - if (!(message.attribute && message.attribute.length)) message.attribute = []; + if (!(message.attribute && message.attribute.length)) + message.attribute = []; message.attribute.push(reader.string()); break; } case 11: { - if (!(message.attributeProto && message.attributeProto.length)) message.attributeProto = []; - message.attributeProto.push($root.onnx.AttributeProto.decode(reader, reader.uint32())); + if (!(message.attributeProto && message.attributeProto.length)) + message.attributeProto = []; + message.attributeProto.push( + $root.onnx.AttributeProto.decode(reader, reader.uint32()), + ); break; } case 7: { if (!(message.node && message.node.length)) message.node = []; - message.node.push($root.onnx.NodeProto.decode(reader, reader.uint32())); + message.node.push( + $root.onnx.NodeProto.decode(reader, reader.uint32()), + ); break; } case 8: { @@ -7156,8 +8282,11 @@ $root.onnx = (function () { break; } case 9: { - if (!(message.opsetImport && message.opsetImport.length)) message.opsetImport = []; - message.opsetImport.push($root.onnx.OperatorSetIdProto.decode(reader, reader.uint32())); + if (!(message.opsetImport && message.opsetImport.length)) + message.opsetImport = []; + message.opsetImport.push( + $root.onnx.OperatorSetIdProto.decode(reader, reader.uint32()), + ); break; } case 10: { @@ -7196,49 +8325,67 @@ $root.onnx = (function () { * @returns {string|null} `null` if valid, otherwise the reason why it is not */ FunctionProto.verify = function verify(message) { - if (typeof message !== 'object' || message === null) return 'object expected'; - if (message.name != null && message.hasOwnProperty('name')) - if (!$util.isString(message.name)) return 'name: string expected'; - if (message.input != null && message.hasOwnProperty('input')) { - if (!Array.isArray(message.input)) return 'input: array expected'; + if (typeof message !== "object" || message === null) + return "object expected"; + if (message.name != null && message.hasOwnProperty("name")) + if (!$util.isString(message.name)) return "name: string expected"; + if (message.input != null && message.hasOwnProperty("input")) { + if (!Array.isArray(message.input)) return "input: array expected"; for (var i = 0; i < message.input.length; ++i) - if (!$util.isString(message.input[i])) return 'input: string[] expected'; + if (!$util.isString(message.input[i])) + return "input: string[] expected"; } - if (message.output != null && message.hasOwnProperty('output')) { - if (!Array.isArray(message.output)) return 'output: array expected'; + if (message.output != null && message.hasOwnProperty("output")) { + if (!Array.isArray(message.output)) return "output: array expected"; for (var i = 0; i < message.output.length; ++i) - if (!$util.isString(message.output[i])) return 'output: string[] expected'; + if (!$util.isString(message.output[i])) + return "output: string[] expected"; } - if (message.attribute != null && message.hasOwnProperty('attribute')) { - if (!Array.isArray(message.attribute)) return 'attribute: array expected'; + if (message.attribute != null && message.hasOwnProperty("attribute")) { + if (!Array.isArray(message.attribute)) + return "attribute: array expected"; for (var i = 0; i < message.attribute.length; ++i) - if (!$util.isString(message.attribute[i])) return 'attribute: string[] expected'; - } - if (message.attributeProto != null && message.hasOwnProperty('attributeProto')) { - if (!Array.isArray(message.attributeProto)) return 'attributeProto: array expected'; + if (!$util.isString(message.attribute[i])) + return "attribute: string[] expected"; + } + if ( + message.attributeProto != null && + message.hasOwnProperty("attributeProto") + ) { + if (!Array.isArray(message.attributeProto)) + return "attributeProto: array expected"; for (var i = 0; i < message.attributeProto.length; ++i) { - var error = $root.onnx.AttributeProto.verify(message.attributeProto[i]); - if (error) return 'attributeProto.' + error; + var error = $root.onnx.AttributeProto.verify( + message.attributeProto[i], + ); + if (error) return "attributeProto." + error; } } - if (message.node != null && message.hasOwnProperty('node')) { - if (!Array.isArray(message.node)) return 'node: array expected'; + if (message.node != null && message.hasOwnProperty("node")) { + if (!Array.isArray(message.node)) return "node: array expected"; for (var i = 0; i < message.node.length; ++i) { var error = $root.onnx.NodeProto.verify(message.node[i]); - if (error) return 'node.' + error; + if (error) return "node." + error; } } - if (message.docString != null && message.hasOwnProperty('docString')) - if (!$util.isString(message.docString)) return 'docString: string expected'; - if (message.opsetImport != null && message.hasOwnProperty('opsetImport')) { - if (!Array.isArray(message.opsetImport)) return 'opsetImport: array expected'; + if (message.docString != null && message.hasOwnProperty("docString")) + if (!$util.isString(message.docString)) + return "docString: string expected"; + if ( + message.opsetImport != null && + message.hasOwnProperty("opsetImport") + ) { + if (!Array.isArray(message.opsetImport)) + return "opsetImport: array expected"; for (var i = 0; i < message.opsetImport.length; ++i) { - var error = $root.onnx.OperatorSetIdProto.verify(message.opsetImport[i]); - if (error) return 'opsetImport.' + error; + var error = $root.onnx.OperatorSetIdProto.verify( + message.opsetImport[i], + ); + if (error) return "opsetImport." + error; } } - if (message.domain != null && message.hasOwnProperty('domain')) - if (!$util.isString(message.domain)) return 'domain: string expected'; + if (message.domain != null && message.hasOwnProperty("domain")) + if (!$util.isString(message.domain)) return "domain: string expected"; return null; }; @@ -7255,46 +8402,62 @@ $root.onnx = (function () { var message = new $root.onnx.FunctionProto(); if (object.name != null) message.name = String(object.name); if (object.input) { - if (!Array.isArray(object.input)) throw TypeError('.onnx.FunctionProto.input: array expected'); + if (!Array.isArray(object.input)) + throw TypeError(".onnx.FunctionProto.input: array expected"); message.input = []; - for (var i = 0; i < object.input.length; ++i) message.input[i] = String(object.input[i]); + for (var i = 0; i < object.input.length; ++i) + message.input[i] = String(object.input[i]); } if (object.output) { - if (!Array.isArray(object.output)) throw TypeError('.onnx.FunctionProto.output: array expected'); + if (!Array.isArray(object.output)) + throw TypeError(".onnx.FunctionProto.output: array expected"); message.output = []; - for (var i = 0; i < object.output.length; ++i) message.output[i] = String(object.output[i]); + for (var i = 0; i < object.output.length; ++i) + message.output[i] = String(object.output[i]); } if (object.attribute) { - if (!Array.isArray(object.attribute)) throw TypeError('.onnx.FunctionProto.attribute: array expected'); + if (!Array.isArray(object.attribute)) + throw TypeError(".onnx.FunctionProto.attribute: array expected"); message.attribute = []; - for (var i = 0; i < object.attribute.length; ++i) message.attribute[i] = String(object.attribute[i]); + for (var i = 0; i < object.attribute.length; ++i) + message.attribute[i] = String(object.attribute[i]); } if (object.attributeProto) { if (!Array.isArray(object.attributeProto)) - throw TypeError('.onnx.FunctionProto.attributeProto: array expected'); + throw TypeError(".onnx.FunctionProto.attributeProto: array expected"); message.attributeProto = []; for (var i = 0; i < object.attributeProto.length; ++i) { - if (typeof object.attributeProto[i] !== 'object') - throw TypeError('.onnx.FunctionProto.attributeProto: object expected'); - message.attributeProto[i] = $root.onnx.AttributeProto.fromObject(object.attributeProto[i]); + if (typeof object.attributeProto[i] !== "object") + throw TypeError( + ".onnx.FunctionProto.attributeProto: object expected", + ); + message.attributeProto[i] = $root.onnx.AttributeProto.fromObject( + object.attributeProto[i], + ); } } if (object.node) { - if (!Array.isArray(object.node)) throw TypeError('.onnx.FunctionProto.node: array expected'); + if (!Array.isArray(object.node)) + throw TypeError(".onnx.FunctionProto.node: array expected"); message.node = []; for (var i = 0; i < object.node.length; ++i) { - if (typeof object.node[i] !== 'object') throw TypeError('.onnx.FunctionProto.node: object expected'); + if (typeof object.node[i] !== "object") + throw TypeError(".onnx.FunctionProto.node: object expected"); message.node[i] = $root.onnx.NodeProto.fromObject(object.node[i]); } } - if (object.docString != null) message.docString = String(object.docString); + if (object.docString != null) + message.docString = String(object.docString); if (object.opsetImport) { - if (!Array.isArray(object.opsetImport)) throw TypeError('.onnx.FunctionProto.opsetImport: array expected'); + if (!Array.isArray(object.opsetImport)) + throw TypeError(".onnx.FunctionProto.opsetImport: array expected"); message.opsetImport = []; for (var i = 0; i < object.opsetImport.length; ++i) { - if (typeof object.opsetImport[i] !== 'object') - throw TypeError('.onnx.FunctionProto.opsetImport: object expected'); - message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject(object.opsetImport[i]); + if (typeof object.opsetImport[i] !== "object") + throw TypeError(".onnx.FunctionProto.opsetImport: object expected"); + message.opsetImport[i] = $root.onnx.OperatorSetIdProto.fromObject( + object.opsetImport[i], + ); } } if (object.domain != null) message.domain = String(object.domain); @@ -7322,39 +8485,54 @@ $root.onnx = (function () { object.attributeProto = []; } if (options.defaults) { - object.name = ''; - object.docString = ''; - object.domain = ''; + object.name = ""; + object.docString = ""; + object.domain = ""; } - if (message.name != null && message.hasOwnProperty('name')) object.name = message.name; + if (message.name != null && message.hasOwnProperty("name")) + object.name = message.name; if (message.input && message.input.length) { object.input = []; - for (var j = 0; j < message.input.length; ++j) object.input[j] = message.input[j]; + for (var j = 0; j < message.input.length; ++j) + object.input[j] = message.input[j]; } if (message.output && message.output.length) { object.output = []; - for (var j = 0; j < message.output.length; ++j) object.output[j] = message.output[j]; + for (var j = 0; j < message.output.length; ++j) + object.output[j] = message.output[j]; } if (message.attribute && message.attribute.length) { object.attribute = []; - for (var j = 0; j < message.attribute.length; ++j) object.attribute[j] = message.attribute[j]; + for (var j = 0; j < message.attribute.length; ++j) + object.attribute[j] = message.attribute[j]; } if (message.node && message.node.length) { object.node = []; for (var j = 0; j < message.node.length; ++j) - object.node[j] = $root.onnx.NodeProto.toObject(message.node[j], options); + object.node[j] = $root.onnx.NodeProto.toObject( + message.node[j], + options, + ); } - if (message.docString != null && message.hasOwnProperty('docString')) object.docString = message.docString; + if (message.docString != null && message.hasOwnProperty("docString")) + object.docString = message.docString; if (message.opsetImport && message.opsetImport.length) { object.opsetImport = []; for (var j = 0; j < message.opsetImport.length; ++j) - object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject(message.opsetImport[j], options); + object.opsetImport[j] = $root.onnx.OperatorSetIdProto.toObject( + message.opsetImport[j], + options, + ); } - if (message.domain != null && message.hasOwnProperty('domain')) object.domain = message.domain; + if (message.domain != null && message.hasOwnProperty("domain")) + object.domain = message.domain; if (message.attributeProto && message.attributeProto.length) { object.attributeProto = []; for (var j = 0; j < message.attributeProto.length; ++j) - object.attributeProto[j] = $root.onnx.AttributeProto.toObject(message.attributeProto[j], options); + object.attributeProto[j] = $root.onnx.AttributeProto.toObject( + message.attributeProto[j], + options, + ); } return object; }; @@ -7380,9 +8558,9 @@ $root.onnx = (function () { */ FunctionProto.getTypeUrl = function getTypeUrl(typeUrlPrefix) { if (typeUrlPrefix === undefined) { - typeUrlPrefix = 'type.googleapis.com'; + typeUrlPrefix = "type.googleapis.com"; } - return typeUrlPrefix + '/onnx.FunctionProto'; + return typeUrlPrefix + "/onnx.FunctionProto"; }; return FunctionProto; diff --git a/onnx-converter/src/protobuf/onnx.d.ts b/onnx-converter/src/protobuf/onnx.d.ts index cee6010d3..94aa80b35 100644 --- a/onnx-converter/src/protobuf/onnx.d.ts +++ b/onnx-converter/src/protobuf/onnx.d.ts @@ -1,12 +1,11 @@ // SOURCE: https://github.com/microsoft/onnxruntime/tree/main/js/web/lib/onnxjs/ort-schema/protobuf // LICENSE: MIT -import Long from 'long'; -import * as $protobuf from 'protobufjs'; +import Long from "long"; +import * as $protobuf from "protobufjs"; /** Namespace onnx. */ export namespace onnx { - /** Version enum. */ enum Version { _START_VERSION = 0, @@ -18,64 +17,64 @@ export namespace onnx { IR_VERSION_2019_9_19 = 6, IR_VERSION_2020_5_8 = 7, IR_VERSION_2021_7_30 = 8, - IR_VERSION = 9 + IR_VERSION = 9, } /** Properties of an AttributeProto. */ interface IAttributeProto { /** AttributeProto name */ - name?: (string|null); + name?: string | null; /** AttributeProto refAttrName */ - refAttrName?: (string|null); + refAttrName?: string | null; /** AttributeProto docString */ - docString?: (string|null); + docString?: string | null; /** AttributeProto type */ - type?: (onnx.AttributeProto.AttributeType|null); + type?: onnx.AttributeProto.AttributeType | null; /** AttributeProto f */ - f?: (number|null); + f?: number | null; /** AttributeProto i */ - i?: (number|Long|null); + i?: number | Long | null; /** AttributeProto s */ - s?: (Uint8Array|null); + s?: Uint8Array | null; /** AttributeProto t */ - t?: (onnx.ITensorProto|null); + t?: onnx.ITensorProto | null; /** AttributeProto g */ - g?: (onnx.IGraphProto|null); + g?: onnx.IGraphProto | null; /** AttributeProto sparseTensor */ - sparseTensor?: (onnx.ISparseTensorProto|null); + sparseTensor?: onnx.ISparseTensorProto | null; /** AttributeProto tp */ - tp?: (onnx.ITypeProto|null); + tp?: onnx.ITypeProto | null; /** AttributeProto floats */ - floats?: (number[]|null); + floats?: number[] | null; /** AttributeProto ints */ - ints?: ((number | Long)[]|null); + ints?: (number | Long)[] | null; /** AttributeProto strings */ - strings?: (Uint8Array[]|null); + strings?: Uint8Array[] | null; /** AttributeProto tensors */ - tensors?: (onnx.ITensorProto[]|null); + tensors?: onnx.ITensorProto[] | null; /** AttributeProto graphs */ - graphs?: (onnx.IGraphProto[]|null); + graphs?: onnx.IGraphProto[] | null; /** AttributeProto sparseTensors */ - sparseTensors?: (onnx.ISparseTensorProto[]|null); + sparseTensors?: onnx.ISparseTensorProto[] | null; /** AttributeProto typeProtos */ - typeProtos?: (onnx.ITypeProto[]|null); + typeProtos?: onnx.ITypeProto[] | null; } /** Represents an AttributeProto. */ @@ -102,28 +101,28 @@ export namespace onnx { public f: number; /** AttributeProto i. */ - public i: (number|Long); + public i: number | Long; /** AttributeProto s. */ public s: Uint8Array; /** AttributeProto t. */ - public t?: (onnx.ITensorProto|null); + public t?: onnx.ITensorProto | null; /** AttributeProto g. */ - public g?: (onnx.IGraphProto|null); + public g?: onnx.IGraphProto | null; /** AttributeProto sparseTensor. */ - public sparseTensor?: (onnx.ISparseTensorProto|null); + public sparseTensor?: onnx.ISparseTensorProto | null; /** AttributeProto tp. */ - public tp?: (onnx.ITypeProto|null); + public tp?: onnx.ITypeProto | null; /** AttributeProto floats. */ public floats: number[]; /** AttributeProto ints. */ - public ints: (number|Long)[]; + public ints: (number | Long)[]; /** AttributeProto strings. */ public strings: Uint8Array[]; @@ -145,7 +144,9 @@ export namespace onnx { * @param [properties] Properties to set * @returns AttributeProto instance */ - public static create(properties?: onnx.IAttributeProto): onnx.AttributeProto; + public static create( + properties?: onnx.IAttributeProto, + ): onnx.AttributeProto; /** * Encodes the specified AttributeProto message. Does not implicitly {@link onnx.AttributeProto.verify|verify} @@ -154,7 +155,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.IAttributeProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified AttributeProto message, length delimited. Does not implicitly {@link @@ -163,7 +167,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.IAttributeProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.IAttributeProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes an AttributeProto message from the specified reader or buffer. @@ -173,7 +180,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.AttributeProto; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.AttributeProto; /** * Decodes an AttributeProto message from the specified reader or buffer, length delimited. @@ -182,21 +192,23 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.AttributeProto; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.AttributeProto; /** * Verifies an AttributeProto message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates an AttributeProto message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns AttributeProto */ - public static fromObject(object: {[k: string]: any}): onnx.AttributeProto; + public static fromObject(object: { [k: string]: any }): onnx.AttributeProto; /** * Creates a plain object from an AttributeProto message. Also converts values to other types if specified. @@ -204,13 +216,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.AttributeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + public static toObject( + message: onnx.AttributeProto, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this AttributeProto to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for AttributeProto @@ -221,7 +236,6 @@ export namespace onnx { } namespace AttributeProto { - /** AttributeType enum. */ enum AttributeType { UNDEFINED = 0, @@ -238,20 +252,20 @@ export namespace onnx { TENSORS = 9, GRAPHS = 10, SPARSE_TENSORS = 12, - TYPE_PROTOS = 14 + TYPE_PROTOS = 14, } } /** Properties of a ValueInfoProto. */ interface IValueInfoProto { /** ValueInfoProto name */ - name?: (string|null); + name?: string | null; /** ValueInfoProto type */ - type?: (onnx.ITypeProto|null); + type?: onnx.ITypeProto | null; /** ValueInfoProto docString */ - docString?: (string|null); + docString?: string | null; } /** Represents a ValueInfoProto. */ @@ -266,7 +280,7 @@ export namespace onnx { public name: string; /** ValueInfoProto type. */ - public type?: (onnx.ITypeProto|null); + public type?: onnx.ITypeProto | null; /** ValueInfoProto docString. */ public docString: string; @@ -276,7 +290,9 @@ export namespace onnx { * @param [properties] Properties to set * @returns ValueInfoProto instance */ - public static create(properties?: onnx.IValueInfoProto): onnx.ValueInfoProto; + public static create( + properties?: onnx.IValueInfoProto, + ): onnx.ValueInfoProto; /** * Encodes the specified ValueInfoProto message. Does not implicitly {@link onnx.ValueInfoProto.verify|verify} @@ -285,7 +301,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.IValueInfoProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified ValueInfoProto message, length delimited. Does not implicitly {@link @@ -294,7 +313,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.IValueInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.IValueInfoProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a ValueInfoProto message from the specified reader or buffer. @@ -304,7 +326,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ValueInfoProto; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.ValueInfoProto; /** * Decodes a ValueInfoProto message from the specified reader or buffer, length delimited. @@ -313,21 +338,23 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ValueInfoProto; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.ValueInfoProto; /** * Verifies a ValueInfoProto message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a ValueInfoProto message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns ValueInfoProto */ - public static fromObject(object: {[k: string]: any}): onnx.ValueInfoProto; + public static fromObject(object: { [k: string]: any }): onnx.ValueInfoProto; /** * Creates a plain object from a ValueInfoProto message. Also converts values to other types if specified. @@ -335,13 +362,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.ValueInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + public static toObject( + message: onnx.ValueInfoProto, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this ValueInfoProto to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for ValueInfoProto @@ -354,25 +384,25 @@ export namespace onnx { /** Properties of a NodeProto. */ interface INodeProto { /** NodeProto input */ - input?: (string[]|null); + input?: string[] | null; /** NodeProto output */ - output?: (string[]|null); + output?: string[] | null; /** NodeProto name */ - name?: (string|null); + name?: string | null; /** NodeProto opType */ - opType?: (string|null); + opType?: string | null; /** NodeProto domain */ - domain?: (string|null); + domain?: string | null; /** NodeProto attribute */ - attribute?: (onnx.IAttributeProto[]|null); + attribute?: onnx.IAttributeProto[] | null; /** NodeProto docString */ - docString?: (string|null); + docString?: string | null; } /** Represents a NodeProto. */ @@ -417,7 +447,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.INodeProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified NodeProto message, length delimited. Does not implicitly {@link @@ -426,7 +459,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.INodeProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.INodeProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a NodeProto message from the specified reader or buffer. @@ -436,7 +472,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.NodeProto; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.NodeProto; /** * Decodes a NodeProto message from the specified reader or buffer, length delimited. @@ -445,21 +484,23 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.NodeProto; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.NodeProto; /** * Verifies a NodeProto message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a NodeProto message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns NodeProto */ - public static fromObject(object: {[k: string]: any}): onnx.NodeProto; + public static fromObject(object: { [k: string]: any }): onnx.NodeProto; /** * Creates a plain object from a NodeProto message. Also converts values to other types if specified. @@ -467,13 +508,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.NodeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + public static toObject( + message: onnx.NodeProto, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this NodeProto to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for NodeProto @@ -486,16 +530,16 @@ export namespace onnx { /** Properties of a TrainingInfoProto. */ interface ITrainingInfoProto { /** TrainingInfoProto initialization */ - initialization?: (onnx.IGraphProto|null); + initialization?: onnx.IGraphProto | null; /** TrainingInfoProto algorithm */ - algorithm?: (onnx.IGraphProto|null); + algorithm?: onnx.IGraphProto | null; /** TrainingInfoProto initializationBinding */ - initializationBinding?: (onnx.IStringStringEntryProto[]|null); + initializationBinding?: onnx.IStringStringEntryProto[] | null; /** TrainingInfoProto updateBinding */ - updateBinding?: (onnx.IStringStringEntryProto[]|null); + updateBinding?: onnx.IStringStringEntryProto[] | null; } /** Represents a TrainingInfoProto. */ @@ -507,10 +551,10 @@ export namespace onnx { constructor(properties?: onnx.ITrainingInfoProto); /** TrainingInfoProto initialization. */ - public initialization?: (onnx.IGraphProto|null); + public initialization?: onnx.IGraphProto | null; /** TrainingInfoProto algorithm. */ - public algorithm?: (onnx.IGraphProto|null); + public algorithm?: onnx.IGraphProto | null; /** TrainingInfoProto initializationBinding. */ public initializationBinding: onnx.IStringStringEntryProto[]; @@ -523,7 +567,9 @@ export namespace onnx { * @param [properties] Properties to set * @returns TrainingInfoProto instance */ - public static create(properties?: onnx.ITrainingInfoProto): onnx.TrainingInfoProto; + public static create( + properties?: onnx.ITrainingInfoProto, + ): onnx.TrainingInfoProto; /** * Encodes the specified TrainingInfoProto message. Does not implicitly {@link onnx.TrainingInfoProto.verify|verify} @@ -532,7 +578,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.ITrainingInfoProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified TrainingInfoProto message, length delimited. Does not implicitly {@link @@ -541,7 +590,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.ITrainingInfoProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.ITrainingInfoProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a TrainingInfoProto message from the specified reader or buffer. @@ -551,7 +603,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TrainingInfoProto; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.TrainingInfoProto; /** * Decodes a TrainingInfoProto message from the specified reader or buffer, length delimited. @@ -560,21 +615,25 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TrainingInfoProto; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.TrainingInfoProto; /** * Verifies a TrainingInfoProto message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a TrainingInfoProto message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns TrainingInfoProto */ - public static fromObject(object: {[k: string]: any}): onnx.TrainingInfoProto; + public static fromObject(object: { + [k: string]: any; + }): onnx.TrainingInfoProto; /** * Creates a plain object from a TrainingInfoProto message. Also converts values to other types if specified. @@ -582,13 +641,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.TrainingInfoProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + public static toObject( + message: onnx.TrainingInfoProto, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this TrainingInfoProto to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for TrainingInfoProto @@ -601,37 +663,37 @@ export namespace onnx { /** Properties of a ModelProto. */ interface IModelProto { /** ModelProto irVersion */ - irVersion?: (number|Long|null); + irVersion?: number | Long | null; /** ModelProto opsetImport */ - opsetImport?: (onnx.IOperatorSetIdProto[]|null); + opsetImport?: onnx.IOperatorSetIdProto[] | null; /** ModelProto producerName */ - producerName?: (string|null); + producerName?: string | null; /** ModelProto producerVersion */ - producerVersion?: (string|null); + producerVersion?: string | null; /** ModelProto domain */ - domain?: (string|null); + domain?: string | null; /** ModelProto modelVersion */ - modelVersion?: (number|Long|null); + modelVersion?: number | Long | null; /** ModelProto docString */ - docString?: (string|null); + docString?: string | null; /** ModelProto graph */ - graph?: (onnx.IGraphProto|null); + graph?: onnx.IGraphProto | null; /** ModelProto metadataProps */ - metadataProps?: (onnx.IStringStringEntryProto[]|null); + metadataProps?: onnx.IStringStringEntryProto[] | null; /** ModelProto trainingInfo */ - trainingInfo?: (onnx.ITrainingInfoProto[]|null); + trainingInfo?: onnx.ITrainingInfoProto[] | null; /** ModelProto functions */ - functions?: (onnx.IFunctionProto[]|null); + functions?: onnx.IFunctionProto[] | null; } /** Represents a ModelProto. */ @@ -643,7 +705,7 @@ export namespace onnx { constructor(properties?: onnx.IModelProto); /** ModelProto irVersion. */ - public irVersion: (number|Long); + public irVersion: number | Long; /** ModelProto opsetImport. */ public opsetImport: onnx.IOperatorSetIdProto[]; @@ -658,13 +720,13 @@ export namespace onnx { public domain: string; /** ModelProto modelVersion. */ - public modelVersion: (number|Long); + public modelVersion: number | Long; /** ModelProto docString. */ public docString: string; /** ModelProto graph. */ - public graph?: (onnx.IGraphProto|null); + public graph?: onnx.IGraphProto | null; /** ModelProto metadataProps. */ public metadataProps: onnx.IStringStringEntryProto[]; @@ -688,7 +750,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.IModelProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified ModelProto message, length delimited. Does not implicitly {@link @@ -697,7 +762,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.IModelProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.IModelProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a ModelProto message from the specified reader or buffer. @@ -707,7 +775,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.ModelProto; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.ModelProto; /** * Decodes a ModelProto message from the specified reader or buffer, length delimited. @@ -716,21 +787,23 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.ModelProto; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.ModelProto; /** * Verifies a ModelProto message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a ModelProto message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns ModelProto */ - public static fromObject(object: {[k: string]: any}): onnx.ModelProto; + public static fromObject(object: { [k: string]: any }): onnx.ModelProto; /** * Creates a plain object from a ModelProto message. Also converts values to other types if specified. @@ -738,13 +811,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.ModelProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + public static toObject( + message: onnx.ModelProto, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this ModelProto to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for ModelProto @@ -757,10 +833,10 @@ export namespace onnx { /** Properties of a StringStringEntryProto. */ interface IStringStringEntryProto { /** StringStringEntryProto key */ - key?: (string|null); + key?: string | null; /** StringStringEntryProto value */ - value?: (string|null); + value?: string | null; } /** Represents a StringStringEntryProto. */ @@ -782,7 +858,9 @@ export namespace onnx { * @param [properties] Properties to set * @returns StringStringEntryProto instance */ - public static create(properties?: onnx.IStringStringEntryProto): onnx.StringStringEntryProto; + public static create( + properties?: onnx.IStringStringEntryProto, + ): onnx.StringStringEntryProto; /** * Encodes the specified StringStringEntryProto message. Does not implicitly {@link @@ -791,7 +869,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.IStringStringEntryProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified StringStringEntryProto message, length delimited. Does not implicitly {@link @@ -800,7 +881,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.IStringStringEntryProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.IStringStringEntryProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a StringStringEntryProto message from the specified reader or buffer. @@ -810,7 +894,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.StringStringEntryProto; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.StringStringEntryProto; /** * Decodes a StringStringEntryProto message from the specified reader or buffer, length delimited. @@ -819,14 +906,16 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.StringStringEntryProto; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.StringStringEntryProto; /** * Verifies a StringStringEntryProto message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a StringStringEntryProto message from a plain object. Also converts values to their respective internal @@ -834,7 +923,9 @@ export namespace onnx { * @param object Plain object * @returns StringStringEntryProto */ - public static fromObject(object: {[k: string]: any}): onnx.StringStringEntryProto; + public static fromObject(object: { + [k: string]: any; + }): onnx.StringStringEntryProto; /** * Creates a plain object from a StringStringEntryProto message. Also converts values to other types if specified. @@ -842,14 +933,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.StringStringEntryProto, options?: $protobuf.IConversionOptions): - {[k: string]: any}; + public static toObject( + message: onnx.StringStringEntryProto, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this StringStringEntryProto to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for StringStringEntryProto @@ -862,10 +955,10 @@ export namespace onnx { /** Properties of a TensorAnnotation. */ interface ITensorAnnotation { /** TensorAnnotation tensorName */ - tensorName?: (string|null); + tensorName?: string | null; /** TensorAnnotation quantParameterTensorNames */ - quantParameterTensorNames?: (onnx.IStringStringEntryProto[]|null); + quantParameterTensorNames?: onnx.IStringStringEntryProto[] | null; } /** Represents a TensorAnnotation. */ @@ -887,7 +980,9 @@ export namespace onnx { * @param [properties] Properties to set * @returns TensorAnnotation instance */ - public static create(properties?: onnx.ITensorAnnotation): onnx.TensorAnnotation; + public static create( + properties?: onnx.ITensorAnnotation, + ): onnx.TensorAnnotation; /** * Encodes the specified TensorAnnotation message. Does not implicitly {@link onnx.TensorAnnotation.verify|verify} @@ -896,7 +991,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.ITensorAnnotation, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified TensorAnnotation message, length delimited. Does not implicitly {@link @@ -905,7 +1003,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.ITensorAnnotation, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.ITensorAnnotation, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a TensorAnnotation message from the specified reader or buffer. @@ -915,7 +1016,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorAnnotation; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.TensorAnnotation; /** * Decodes a TensorAnnotation message from the specified reader or buffer, length delimited. @@ -924,21 +1028,25 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorAnnotation; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.TensorAnnotation; /** * Verifies a TensorAnnotation message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a TensorAnnotation message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns TensorAnnotation */ - public static fromObject(object: {[k: string]: any}): onnx.TensorAnnotation; + public static fromObject(object: { + [k: string]: any; + }): onnx.TensorAnnotation; /** * Creates a plain object from a TensorAnnotation message. Also converts values to other types if specified. @@ -946,13 +1054,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.TensorAnnotation, options?: $protobuf.IConversionOptions): {[k: string]: any}; + public static toObject( + message: onnx.TensorAnnotation, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this TensorAnnotation to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for TensorAnnotation @@ -965,31 +1076,31 @@ export namespace onnx { /** Properties of a GraphProto. */ interface IGraphProto { /** GraphProto node */ - node?: (onnx.INodeProto[]|null); + node?: onnx.INodeProto[] | null; /** GraphProto name */ - name?: (string|null); + name?: string | null; /** GraphProto initializer */ - initializer?: (onnx.ITensorProto[]|null); + initializer?: onnx.ITensorProto[] | null; /** GraphProto sparseInitializer */ - sparseInitializer?: (onnx.ISparseTensorProto[]|null); + sparseInitializer?: onnx.ISparseTensorProto[] | null; /** GraphProto docString */ - docString?: (string|null); + docString?: string | null; /** GraphProto input */ - input?: (onnx.IValueInfoProto[]|null); + input?: onnx.IValueInfoProto[] | null; /** GraphProto output */ - output?: (onnx.IValueInfoProto[]|null); + output?: onnx.IValueInfoProto[] | null; /** GraphProto valueInfo */ - valueInfo?: (onnx.IValueInfoProto[]|null); + valueInfo?: onnx.IValueInfoProto[] | null; /** GraphProto quantizationAnnotation */ - quantizationAnnotation?: (onnx.ITensorAnnotation[]|null); + quantizationAnnotation?: onnx.ITensorAnnotation[] | null; } /** Represents a GraphProto. */ @@ -1040,7 +1151,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.IGraphProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified GraphProto message, length delimited. Does not implicitly {@link @@ -1049,7 +1163,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.IGraphProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.IGraphProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a GraphProto message from the specified reader or buffer. @@ -1059,7 +1176,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.GraphProto; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.GraphProto; /** * Decodes a GraphProto message from the specified reader or buffer, length delimited. @@ -1068,21 +1188,23 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.GraphProto; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.GraphProto; /** * Verifies a GraphProto message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a GraphProto message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns GraphProto */ - public static fromObject(object: {[k: string]: any}): onnx.GraphProto; + public static fromObject(object: { [k: string]: any }): onnx.GraphProto; /** * Creates a plain object from a GraphProto message. Also converts values to other types if specified. @@ -1090,13 +1212,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.GraphProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + public static toObject( + message: onnx.GraphProto, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this GraphProto to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for GraphProto @@ -1109,46 +1234,46 @@ export namespace onnx { /** Properties of a TensorProto. */ interface ITensorProto { /** TensorProto dims */ - dims?: ((number | Long)[]|null); + dims?: (number | Long)[] | null; /** TensorProto dataType */ - dataType?: (number|null); + dataType?: number | null; /** TensorProto segment */ - segment?: (onnx.TensorProto.ISegment|null); + segment?: onnx.TensorProto.ISegment | null; /** TensorProto floatData */ - floatData?: (number[]|null); + floatData?: number[] | null; /** TensorProto int32Data */ - int32Data?: (number[]|null); + int32Data?: number[] | null; /** TensorProto stringData */ - stringData?: (Uint8Array[]|null); + stringData?: Uint8Array[] | null; /** TensorProto int64Data */ - int64Data?: ((number | Long)[]|null); + int64Data?: (number | Long)[] | null; /** TensorProto name */ - name?: (string|null); + name?: string | null; /** TensorProto docString */ - docString?: (string|null); + docString?: string | null; /** TensorProto rawData */ - rawData?: (Uint8Array|null); + rawData?: Uint8Array | null; /** TensorProto externalData */ - externalData?: (onnx.IStringStringEntryProto[]|null); + externalData?: onnx.IStringStringEntryProto[] | null; /** TensorProto dataLocation */ - dataLocation?: (onnx.TensorProto.DataLocation|null); + dataLocation?: onnx.TensorProto.DataLocation | null; /** TensorProto doubleData */ - doubleData?: (number[]|null); + doubleData?: number[] | null; /** TensorProto uint64Data */ - uint64Data?: ((number | Long)[]|null); + uint64Data?: (number | Long)[] | null; } /** Represents a TensorProto. */ @@ -1160,13 +1285,13 @@ export namespace onnx { constructor(properties?: onnx.ITensorProto); /** TensorProto dims. */ - public dims: (number|Long)[]; + public dims: (number | Long)[]; /** TensorProto dataType. */ public dataType: number; /** TensorProto segment. */ - public segment?: (onnx.TensorProto.ISegment|null); + public segment?: onnx.TensorProto.ISegment | null; /** TensorProto floatData. */ public floatData: number[]; @@ -1178,7 +1303,7 @@ export namespace onnx { public stringData: Uint8Array[]; /** TensorProto int64Data. */ - public int64Data: (number|Long)[]; + public int64Data: (number | Long)[]; /** TensorProto name. */ public name: string; @@ -1199,7 +1324,7 @@ export namespace onnx { public doubleData: number[]; /** TensorProto uint64Data. */ - public uint64Data: (number|Long)[]; + public uint64Data: (number | Long)[]; /** * Creates a new TensorProto instance using the specified properties. @@ -1214,7 +1339,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.ITensorProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified TensorProto message, length delimited. Does not implicitly {@link @@ -1223,7 +1351,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.ITensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.ITensorProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a TensorProto message from the specified reader or buffer. @@ -1233,7 +1364,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.TensorProto; /** * Decodes a TensorProto message from the specified reader or buffer, length delimited. @@ -1242,21 +1376,23 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.TensorProto; /** * Verifies a TensorProto message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a TensorProto message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns TensorProto */ - public static fromObject(object: {[k: string]: any}): onnx.TensorProto; + public static fromObject(object: { [k: string]: any }): onnx.TensorProto; /** * Creates a plain object from a TensorProto message. Also converts values to other types if specified. @@ -1264,13 +1400,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.TensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + public static toObject( + message: onnx.TensorProto, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this TensorProto to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for TensorProto @@ -1281,7 +1420,6 @@ export namespace onnx { } namespace TensorProto { - /** DataType enum. */ enum DataType { UNDEFINED = 0, @@ -1304,16 +1442,16 @@ export namespace onnx { FLOAT8E4M3FN = 17, FLOAT8E4M3FNUZ = 18, FLOAT8E5M2 = 19, - FLOAT8E5M2FNUZ = 20 + FLOAT8E5M2FNUZ = 20, } /** Properties of a Segment. */ interface ISegment { /** Segment begin */ - begin?: (number|Long|null); + begin?: number | Long | null; /** Segment end */ - end?: (number|Long|null); + end?: number | Long | null; } /** Represents a Segment. */ @@ -1325,17 +1463,19 @@ export namespace onnx { constructor(properties?: onnx.TensorProto.ISegment); /** Segment begin. */ - public begin: (number|Long); + public begin: number | Long; /** Segment end. */ - public end: (number|Long); + public end: number | Long; /** * Creates a new Segment instance using the specified properties. * @param [properties] Properties to set * @returns Segment instance */ - public static create(properties?: onnx.TensorProto.ISegment): onnx.TensorProto.Segment; + public static create( + properties?: onnx.TensorProto.ISegment, + ): onnx.TensorProto.Segment; /** * Encodes the specified Segment message. Does not implicitly {@link onnx.TensorProto.Segment.verify|verify} @@ -1344,7 +1484,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.TensorProto.ISegment, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified Segment message, length delimited. Does not implicitly {@link @@ -1353,7 +1496,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.TensorProto.ISegment, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.TensorProto.ISegment, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a Segment message from the specified reader or buffer. @@ -1363,7 +1509,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorProto.Segment; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.TensorProto.Segment; /** * Decodes a Segment message from the specified reader or buffer, length delimited. @@ -1372,21 +1521,25 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorProto.Segment; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.TensorProto.Segment; /** * Verifies a Segment message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a Segment message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns Segment */ - public static fromObject(object: {[k: string]: any}): onnx.TensorProto.Segment; + public static fromObject(object: { + [k: string]: any; + }): onnx.TensorProto.Segment; /** * Creates a plain object from a Segment message. Also converts values to other types if specified. @@ -1394,14 +1547,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.TensorProto.Segment, options?: $protobuf.IConversionOptions): - {[k: string]: any}; + public static toObject( + message: onnx.TensorProto.Segment, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this Segment to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for Segment @@ -1412,19 +1567,22 @@ export namespace onnx { } /** DataLocation enum. */ - enum DataLocation { DEFAULT = 0, EXTERNAL = 1 } + enum DataLocation { + DEFAULT = 0, + EXTERNAL = 1, + } } /** Properties of a SparseTensorProto. */ interface ISparseTensorProto { /** SparseTensorProto values */ - values?: (onnx.ITensorProto|null); + values?: onnx.ITensorProto | null; /** SparseTensorProto indices */ - indices?: (onnx.ITensorProto|null); + indices?: onnx.ITensorProto | null; /** SparseTensorProto dims */ - dims?: ((number | Long)[]|null); + dims?: (number | Long)[] | null; } /** Represents a SparseTensorProto. */ @@ -1436,20 +1594,22 @@ export namespace onnx { constructor(properties?: onnx.ISparseTensorProto); /** SparseTensorProto values. */ - public values?: (onnx.ITensorProto|null); + public values?: onnx.ITensorProto | null; /** SparseTensorProto indices. */ - public indices?: (onnx.ITensorProto|null); + public indices?: onnx.ITensorProto | null; /** SparseTensorProto dims. */ - public dims: (number|Long)[]; + public dims: (number | Long)[]; /** * Creates a new SparseTensorProto instance using the specified properties. * @param [properties] Properties to set * @returns SparseTensorProto instance */ - public static create(properties?: onnx.ISparseTensorProto): onnx.SparseTensorProto; + public static create( + properties?: onnx.ISparseTensorProto, + ): onnx.SparseTensorProto; /** * Encodes the specified SparseTensorProto message. Does not implicitly {@link onnx.SparseTensorProto.verify|verify} @@ -1458,7 +1618,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.ISparseTensorProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified SparseTensorProto message, length delimited. Does not implicitly {@link @@ -1467,7 +1630,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.ISparseTensorProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.ISparseTensorProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a SparseTensorProto message from the specified reader or buffer. @@ -1477,7 +1643,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.SparseTensorProto; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.SparseTensorProto; /** * Decodes a SparseTensorProto message from the specified reader or buffer, length delimited. @@ -1486,21 +1655,25 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.SparseTensorProto; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.SparseTensorProto; /** * Verifies a SparseTensorProto message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a SparseTensorProto message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns SparseTensorProto */ - public static fromObject(object: {[k: string]: any}): onnx.SparseTensorProto; + public static fromObject(object: { + [k: string]: any; + }): onnx.SparseTensorProto; /** * Creates a plain object from a SparseTensorProto message. Also converts values to other types if specified. @@ -1508,13 +1681,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.SparseTensorProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + public static toObject( + message: onnx.SparseTensorProto, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this SparseTensorProto to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for SparseTensorProto @@ -1527,7 +1703,7 @@ export namespace onnx { /** Properties of a TensorShapeProto. */ interface ITensorShapeProto { /** TensorShapeProto dim */ - dim?: (onnx.TensorShapeProto.IDimension[]|null); + dim?: onnx.TensorShapeProto.IDimension[] | null; } /** Represents a TensorShapeProto. */ @@ -1546,7 +1722,9 @@ export namespace onnx { * @param [properties] Properties to set * @returns TensorShapeProto instance */ - public static create(properties?: onnx.ITensorShapeProto): onnx.TensorShapeProto; + public static create( + properties?: onnx.ITensorShapeProto, + ): onnx.TensorShapeProto; /** * Encodes the specified TensorShapeProto message. Does not implicitly {@link onnx.TensorShapeProto.verify|verify} @@ -1555,7 +1733,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.ITensorShapeProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified TensorShapeProto message, length delimited. Does not implicitly {@link @@ -1564,7 +1745,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.ITensorShapeProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.ITensorShapeProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a TensorShapeProto message from the specified reader or buffer. @@ -1574,7 +1758,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.TensorShapeProto; /** * Decodes a TensorShapeProto message from the specified reader or buffer, length delimited. @@ -1583,21 +1770,25 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.TensorShapeProto; /** * Verifies a TensorShapeProto message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a TensorShapeProto message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns TensorShapeProto */ - public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto; + public static fromObject(object: { + [k: string]: any; + }): onnx.TensorShapeProto; /** * Creates a plain object from a TensorShapeProto message. Also converts values to other types if specified. @@ -1605,13 +1796,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.TensorShapeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + public static toObject( + message: onnx.TensorShapeProto, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this TensorShapeProto to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for TensorShapeProto @@ -1622,17 +1816,16 @@ export namespace onnx { } namespace TensorShapeProto { - /** Properties of a Dimension. */ interface IDimension { /** Dimension dimValue */ - dimValue?: (number|Long|null); + dimValue?: number | Long | null; /** Dimension dimParam */ - dimParam?: (string|null); + dimParam?: string | null; /** Dimension denotation */ - denotation?: (string|null); + denotation?: string | null; } /** Represents a Dimension. */ @@ -1644,23 +1837,25 @@ export namespace onnx { constructor(properties?: onnx.TensorShapeProto.IDimension); /** Dimension dimValue. */ - public dimValue?: (number|Long|null); + public dimValue?: number | Long | null; /** Dimension dimParam. */ - public dimParam?: (string|null); + public dimParam?: string | null; /** Dimension denotation. */ public denotation: string; /** Dimension value. */ - public value?: ('dimValue'|'dimParam'); + public value?: "dimValue" | "dimParam"; /** * Creates a new Dimension instance using the specified properties. * @param [properties] Properties to set * @returns Dimension instance */ - public static create(properties?: onnx.TensorShapeProto.IDimension): onnx.TensorShapeProto.Dimension; + public static create( + properties?: onnx.TensorShapeProto.IDimension, + ): onnx.TensorShapeProto.Dimension; /** * Encodes the specified Dimension message. Does not implicitly {@link @@ -1669,7 +1864,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.TensorShapeProto.IDimension, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified Dimension message, length delimited. Does not implicitly {@link @@ -1678,8 +1876,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.TensorShapeProto.IDimension, writer?: $protobuf.Writer): - $protobuf.Writer; + public static encodeDelimited( + message: onnx.TensorShapeProto.IDimension, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a Dimension message from the specified reader or buffer. @@ -1689,7 +1889,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TensorShapeProto.Dimension; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.TensorShapeProto.Dimension; /** * Decodes a Dimension message from the specified reader or buffer, length delimited. @@ -1698,21 +1901,25 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TensorShapeProto.Dimension; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.TensorShapeProto.Dimension; /** * Verifies a Dimension message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a Dimension message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns Dimension */ - public static fromObject(object: {[k: string]: any}): onnx.TensorShapeProto.Dimension; + public static fromObject(object: { + [k: string]: any; + }): onnx.TensorShapeProto.Dimension; /** * Creates a plain object from a Dimension message. Also converts values to other types if specified. @@ -1720,14 +1927,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.TensorShapeProto.Dimension, options?: $protobuf.IConversionOptions): - {[k: string]: any}; + public static toObject( + message: onnx.TensorShapeProto.Dimension, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this Dimension to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for Dimension @@ -1741,22 +1950,22 @@ export namespace onnx { /** Properties of a TypeProto. */ interface ITypeProto { /** TypeProto tensorType */ - tensorType?: (onnx.TypeProto.ITensor|null); + tensorType?: onnx.TypeProto.ITensor | null; /** TypeProto sequenceType */ - sequenceType?: (onnx.TypeProto.ISequence|null); + sequenceType?: onnx.TypeProto.ISequence | null; /** TypeProto mapType */ - mapType?: (onnx.TypeProto.IMap|null); + mapType?: onnx.TypeProto.IMap | null; /** TypeProto optionalType */ - optionalType?: (onnx.TypeProto.IOptional|null); + optionalType?: onnx.TypeProto.IOptional | null; /** TypeProto sparseTensorType */ - sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + sparseTensorType?: onnx.TypeProto.ISparseTensor | null; /** TypeProto denotation */ - denotation?: (string|null); + denotation?: string | null; } /** Represents a TypeProto. */ @@ -1768,25 +1977,30 @@ export namespace onnx { constructor(properties?: onnx.ITypeProto); /** TypeProto tensorType. */ - public tensorType?: (onnx.TypeProto.ITensor|null); + public tensorType?: onnx.TypeProto.ITensor | null; /** TypeProto sequenceType. */ - public sequenceType?: (onnx.TypeProto.ISequence|null); + public sequenceType?: onnx.TypeProto.ISequence | null; /** TypeProto mapType. */ - public mapType?: (onnx.TypeProto.IMap|null); + public mapType?: onnx.TypeProto.IMap | null; /** TypeProto optionalType. */ - public optionalType?: (onnx.TypeProto.IOptional|null); + public optionalType?: onnx.TypeProto.IOptional | null; /** TypeProto sparseTensorType. */ - public sparseTensorType?: (onnx.TypeProto.ISparseTensor|null); + public sparseTensorType?: onnx.TypeProto.ISparseTensor | null; /** TypeProto denotation. */ public denotation: string; /** TypeProto value. */ - public value?: ('tensorType'|'sequenceType'|'mapType'|'optionalType'|'sparseTensorType'); + public value?: + | "tensorType" + | "sequenceType" + | "mapType" + | "optionalType" + | "sparseTensorType"; /** * Creates a new TypeProto instance using the specified properties. @@ -1801,7 +2015,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.ITypeProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified TypeProto message, length delimited. Does not implicitly {@link @@ -1810,7 +2027,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.ITypeProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.ITypeProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a TypeProto message from the specified reader or buffer. @@ -1820,7 +2040,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.TypeProto; /** * Decodes a TypeProto message from the specified reader or buffer, length delimited. @@ -1829,21 +2052,23 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.TypeProto; /** * Verifies a TypeProto message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a TypeProto message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns TypeProto */ - public static fromObject(object: {[k: string]: any}): onnx.TypeProto; + public static fromObject(object: { [k: string]: any }): onnx.TypeProto; /** * Creates a plain object from a TypeProto message. Also converts values to other types if specified. @@ -1851,13 +2076,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.TypeProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + public static toObject( + message: onnx.TypeProto, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this TypeProto to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for TypeProto @@ -1868,14 +2096,13 @@ export namespace onnx { } namespace TypeProto { - /** Properties of a Tensor. */ interface ITensor { /** Tensor elemType */ - elemType?: (number|null); + elemType?: number | null; /** Tensor shape */ - shape?: (onnx.ITensorShapeProto|null); + shape?: onnx.ITensorShapeProto | null; } /** Represents a Tensor. */ @@ -1890,14 +2117,16 @@ export namespace onnx { public elemType: number; /** Tensor shape. */ - public shape?: (onnx.ITensorShapeProto|null); + public shape?: onnx.ITensorShapeProto | null; /** * Creates a new Tensor instance using the specified properties. * @param [properties] Properties to set * @returns Tensor instance */ - public static create(properties?: onnx.TypeProto.ITensor): onnx.TypeProto.Tensor; + public static create( + properties?: onnx.TypeProto.ITensor, + ): onnx.TypeProto.Tensor; /** * Encodes the specified Tensor message. Does not implicitly {@link onnx.TypeProto.Tensor.verify|verify} messages. @@ -1905,7 +2134,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.TypeProto.ITensor, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified Tensor message, length delimited. Does not implicitly {@link @@ -1914,7 +2146,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.TypeProto.ITensor, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.TypeProto.ITensor, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a Tensor message from the specified reader or buffer. @@ -1924,7 +2159,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Tensor; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.TypeProto.Tensor; /** * Decodes a Tensor message from the specified reader or buffer, length delimited. @@ -1933,21 +2171,25 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Tensor; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.TypeProto.Tensor; /** * Verifies a Tensor message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a Tensor message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns Tensor */ - public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Tensor; + public static fromObject(object: { + [k: string]: any; + }): onnx.TypeProto.Tensor; /** * Creates a plain object from a Tensor message. Also converts values to other types if specified. @@ -1955,14 +2197,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.TypeProto.Tensor, options?: $protobuf.IConversionOptions): - {[k: string]: any}; + public static toObject( + message: onnx.TypeProto.Tensor, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this Tensor to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for Tensor @@ -1975,7 +2219,7 @@ export namespace onnx { /** Properties of a Sequence. */ interface ISequence { /** Sequence elemType */ - elemType?: (onnx.ITypeProto|null); + elemType?: onnx.ITypeProto | null; } /** Represents a Sequence. */ @@ -1987,14 +2231,16 @@ export namespace onnx { constructor(properties?: onnx.TypeProto.ISequence); /** Sequence elemType. */ - public elemType?: (onnx.ITypeProto|null); + public elemType?: onnx.ITypeProto | null; /** * Creates a new Sequence instance using the specified properties. * @param [properties] Properties to set * @returns Sequence instance */ - public static create(properties?: onnx.TypeProto.ISequence): onnx.TypeProto.Sequence; + public static create( + properties?: onnx.TypeProto.ISequence, + ): onnx.TypeProto.Sequence; /** * Encodes the specified Sequence message. Does not implicitly {@link onnx.TypeProto.Sequence.verify|verify} @@ -2003,7 +2249,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.TypeProto.ISequence, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified Sequence message, length delimited. Does not implicitly {@link @@ -2012,7 +2261,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.TypeProto.ISequence, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.TypeProto.ISequence, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a Sequence message from the specified reader or buffer. @@ -2022,7 +2274,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Sequence; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.TypeProto.Sequence; /** * Decodes a Sequence message from the specified reader or buffer, length delimited. @@ -2031,21 +2286,25 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Sequence; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.TypeProto.Sequence; /** * Verifies a Sequence message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a Sequence message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns Sequence */ - public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Sequence; + public static fromObject(object: { + [k: string]: any; + }): onnx.TypeProto.Sequence; /** * Creates a plain object from a Sequence message. Also converts values to other types if specified. @@ -2053,14 +2312,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.TypeProto.Sequence, options?: $protobuf.IConversionOptions): - {[k: string]: any}; + public static toObject( + message: onnx.TypeProto.Sequence, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this Sequence to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for Sequence @@ -2073,10 +2334,10 @@ export namespace onnx { /** Properties of a Map. */ interface IMap { /** Map keyType */ - keyType?: (number|null); + keyType?: number | null; /** Map valueType */ - valueType?: (onnx.ITypeProto|null); + valueType?: onnx.ITypeProto | null; } /** Represents a Map. */ @@ -2091,14 +2352,16 @@ export namespace onnx { public keyType: number; /** Map valueType. */ - public valueType?: (onnx.ITypeProto|null); + public valueType?: onnx.ITypeProto | null; /** * Creates a new Map instance using the specified properties. * @param [properties] Properties to set * @returns Map instance */ - public static create(properties?: onnx.TypeProto.IMap): onnx.TypeProto.Map; + public static create( + properties?: onnx.TypeProto.IMap, + ): onnx.TypeProto.Map; /** * Encodes the specified Map message. Does not implicitly {@link onnx.TypeProto.Map.verify|verify} messages. @@ -2106,7 +2369,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.TypeProto.IMap, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified Map message, length delimited. Does not implicitly {@link @@ -2115,7 +2381,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.TypeProto.IMap, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.TypeProto.IMap, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a Map message from the specified reader or buffer. @@ -2125,7 +2394,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Map; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.TypeProto.Map; /** * Decodes a Map message from the specified reader or buffer, length delimited. @@ -2134,21 +2406,25 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Map; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.TypeProto.Map; /** * Verifies a Map message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a Map message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns Map */ - public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Map; + public static fromObject(object: { + [k: string]: any; + }): onnx.TypeProto.Map; /** * Creates a plain object from a Map message. Also converts values to other types if specified. @@ -2156,13 +2432,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.TypeProto.Map, options?: $protobuf.IConversionOptions): {[k: string]: any}; + public static toObject( + message: onnx.TypeProto.Map, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this Map to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for Map @@ -2175,7 +2454,7 @@ export namespace onnx { /** Properties of an Optional. */ interface IOptional { /** Optional elemType */ - elemType?: (onnx.ITypeProto|null); + elemType?: onnx.ITypeProto | null; } /** Represents an Optional. */ @@ -2187,14 +2466,16 @@ export namespace onnx { constructor(properties?: onnx.TypeProto.IOptional); /** Optional elemType. */ - public elemType?: (onnx.ITypeProto|null); + public elemType?: onnx.ITypeProto | null; /** * Creates a new Optional instance using the specified properties. * @param [properties] Properties to set * @returns Optional instance */ - public static create(properties?: onnx.TypeProto.IOptional): onnx.TypeProto.Optional; + public static create( + properties?: onnx.TypeProto.IOptional, + ): onnx.TypeProto.Optional; /** * Encodes the specified Optional message. Does not implicitly {@link onnx.TypeProto.Optional.verify|verify} @@ -2203,7 +2484,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.TypeProto.IOptional, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified Optional message, length delimited. Does not implicitly {@link @@ -2212,7 +2496,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.TypeProto.IOptional, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.TypeProto.IOptional, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes an Optional message from the specified reader or buffer. @@ -2222,7 +2509,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.Optional; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.TypeProto.Optional; /** * Decodes an Optional message from the specified reader or buffer, length delimited. @@ -2231,21 +2521,25 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.Optional; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.TypeProto.Optional; /** * Verifies an Optional message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates an Optional message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns Optional */ - public static fromObject(object: {[k: string]: any}): onnx.TypeProto.Optional; + public static fromObject(object: { + [k: string]: any; + }): onnx.TypeProto.Optional; /** * Creates a plain object from an Optional message. Also converts values to other types if specified. @@ -2253,14 +2547,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.TypeProto.Optional, options?: $protobuf.IConversionOptions): - {[k: string]: any}; + public static toObject( + message: onnx.TypeProto.Optional, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this Optional to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for Optional @@ -2273,10 +2569,10 @@ export namespace onnx { /** Properties of a SparseTensor. */ interface ISparseTensor { /** SparseTensor elemType */ - elemType?: (number|null); + elemType?: number | null; /** SparseTensor shape */ - shape?: (onnx.ITensorShapeProto|null); + shape?: onnx.ITensorShapeProto | null; } /** Represents a SparseTensor. */ @@ -2291,14 +2587,16 @@ export namespace onnx { public elemType: number; /** SparseTensor shape. */ - public shape?: (onnx.ITensorShapeProto|null); + public shape?: onnx.ITensorShapeProto | null; /** * Creates a new SparseTensor instance using the specified properties. * @param [properties] Properties to set * @returns SparseTensor instance */ - public static create(properties?: onnx.TypeProto.ISparseTensor): onnx.TypeProto.SparseTensor; + public static create( + properties?: onnx.TypeProto.ISparseTensor, + ): onnx.TypeProto.SparseTensor; /** * Encodes the specified SparseTensor message. Does not implicitly {@link @@ -2307,7 +2605,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.TypeProto.ISparseTensor, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified SparseTensor message, length delimited. Does not implicitly {@link @@ -2316,7 +2617,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.TypeProto.ISparseTensor, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.TypeProto.ISparseTensor, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a SparseTensor message from the specified reader or buffer. @@ -2326,7 +2630,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.TypeProto.SparseTensor; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.TypeProto.SparseTensor; /** * Decodes a SparseTensor message from the specified reader or buffer, length delimited. @@ -2335,21 +2642,25 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.TypeProto.SparseTensor; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.TypeProto.SparseTensor; /** * Verifies a SparseTensor message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a SparseTensor message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns SparseTensor */ - public static fromObject(object: {[k: string]: any}): onnx.TypeProto.SparseTensor; + public static fromObject(object: { + [k: string]: any; + }): onnx.TypeProto.SparseTensor; /** * Creates a plain object from a SparseTensor message. Also converts values to other types if specified. @@ -2357,14 +2668,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.TypeProto.SparseTensor, options?: $protobuf.IConversionOptions): - {[k: string]: any}; + public static toObject( + message: onnx.TypeProto.SparseTensor, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this SparseTensor to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for SparseTensor @@ -2378,10 +2691,10 @@ export namespace onnx { /** Properties of an OperatorSetIdProto. */ interface IOperatorSetIdProto { /** OperatorSetIdProto domain */ - domain?: (string|null); + domain?: string | null; /** OperatorSetIdProto version */ - version?: (number|Long|null); + version?: number | Long | null; } /** Represents an OperatorSetIdProto. */ @@ -2396,14 +2709,16 @@ export namespace onnx { public domain: string; /** OperatorSetIdProto version. */ - public version: (number|Long); + public version: number | Long; /** * Creates a new OperatorSetIdProto instance using the specified properties. * @param [properties] Properties to set * @returns OperatorSetIdProto instance */ - public static create(properties?: onnx.IOperatorSetIdProto): onnx.OperatorSetIdProto; + public static create( + properties?: onnx.IOperatorSetIdProto, + ): onnx.OperatorSetIdProto; /** * Encodes the specified OperatorSetIdProto message. Does not implicitly {@link @@ -2412,7 +2727,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.IOperatorSetIdProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified OperatorSetIdProto message, length delimited. Does not implicitly {@link @@ -2421,7 +2739,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.IOperatorSetIdProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.IOperatorSetIdProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes an OperatorSetIdProto message from the specified reader or buffer. @@ -2431,7 +2752,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.OperatorSetIdProto; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.OperatorSetIdProto; /** * Decodes an OperatorSetIdProto message from the specified reader or buffer, length delimited. @@ -2440,14 +2764,16 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.OperatorSetIdProto; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.OperatorSetIdProto; /** * Verifies an OperatorSetIdProto message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates an OperatorSetIdProto message from a plain object. Also converts values to their respective internal @@ -2455,7 +2781,9 @@ export namespace onnx { * @param object Plain object * @returns OperatorSetIdProto */ - public static fromObject(object: {[k: string]: any}): onnx.OperatorSetIdProto; + public static fromObject(object: { + [k: string]: any; + }): onnx.OperatorSetIdProto; /** * Creates a plain object from an OperatorSetIdProto message. Also converts values to other types if specified. @@ -2463,14 +2791,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.OperatorSetIdProto, options?: $protobuf.IConversionOptions): - {[k: string]: any}; + public static toObject( + message: onnx.OperatorSetIdProto, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this OperatorSetIdProto to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for OperatorSetIdProto @@ -2481,36 +2811,39 @@ export namespace onnx { } /** OperatorStatus enum. */ - enum OperatorStatus { EXPERIMENTAL = 0, STABLE = 1 } + enum OperatorStatus { + EXPERIMENTAL = 0, + STABLE = 1, + } /** Properties of a FunctionProto. */ interface IFunctionProto { /** FunctionProto name */ - name?: (string|null); + name?: string | null; /** FunctionProto input */ - input?: (string[]|null); + input?: string[] | null; /** FunctionProto output */ - output?: (string[]|null); + output?: string[] | null; /** FunctionProto attribute */ - attribute?: (string[]|null); + attribute?: string[] | null; /** FunctionProto attributeProto */ - attributeProto?: (onnx.IAttributeProto[]|null); + attributeProto?: onnx.IAttributeProto[] | null; /** FunctionProto node */ - node?: (onnx.INodeProto[]|null); + node?: onnx.INodeProto[] | null; /** FunctionProto docString */ - docString?: (string|null); + docString?: string | null; /** FunctionProto opsetImport */ - opsetImport?: (onnx.IOperatorSetIdProto[]|null); + opsetImport?: onnx.IOperatorSetIdProto[] | null; /** FunctionProto domain */ - domain?: (string|null); + domain?: string | null; } /** Represents a FunctionProto. */ @@ -2562,7 +2895,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encode(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encode( + message: onnx.IFunctionProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Encodes the specified FunctionProto message, length delimited. Does not implicitly {@link @@ -2571,7 +2907,10 @@ export namespace onnx { * @param [writer] Writer to encode to * @returns Writer */ - public static encodeDelimited(message: onnx.IFunctionProto, writer?: $protobuf.Writer): $protobuf.Writer; + public static encodeDelimited( + message: onnx.IFunctionProto, + writer?: $protobuf.Writer, + ): $protobuf.Writer; /** * Decodes a FunctionProto message from the specified reader or buffer. @@ -2581,7 +2920,10 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decode(reader: ($protobuf.Reader|Uint8Array), length?: number): onnx.FunctionProto; + public static decode( + reader: $protobuf.Reader | Uint8Array, + length?: number, + ): onnx.FunctionProto; /** * Decodes a FunctionProto message from the specified reader or buffer, length delimited. @@ -2590,21 +2932,23 @@ export namespace onnx { * @throws {Error} If the payload is not a reader or valid buffer * @throws {$protobuf.util.ProtocolError} If required fields are missing */ - public static decodeDelimited(reader: ($protobuf.Reader|Uint8Array)): onnx.FunctionProto; + public static decodeDelimited( + reader: $protobuf.Reader | Uint8Array, + ): onnx.FunctionProto; /** * Verifies a FunctionProto message. * @param message Plain object to verify * @returns `null` if valid, otherwise the reason why it is not */ - public static verify(message: {[k: string]: any}): (string|null); + public static verify(message: { [k: string]: any }): string | null; /** * Creates a FunctionProto message from a plain object. Also converts values to their respective internal types. * @param object Plain object * @returns FunctionProto */ - public static fromObject(object: {[k: string]: any}): onnx.FunctionProto; + public static fromObject(object: { [k: string]: any }): onnx.FunctionProto; /** * Creates a plain object from a FunctionProto message. Also converts values to other types if specified. @@ -2612,13 +2956,16 @@ export namespace onnx { * @param [options] Conversion options * @returns Plain object */ - public static toObject(message: onnx.FunctionProto, options?: $protobuf.IConversionOptions): {[k: string]: any}; + public static toObject( + message: onnx.FunctionProto, + options?: $protobuf.IConversionOptions, + ): { [k: string]: any }; /** * Converts this FunctionProto to JSON. * @returns JSON object */ - public toJSON(): {[k: string]: any}; + public toJSON(): { [k: string]: any }; /** * Gets the default type url for FunctionProto diff --git a/onnx-converter/tsconfig.json b/onnx-converter/tsconfig.json index 13c5e4595..abb5cd6c2 100644 --- a/onnx-converter/tsconfig.json +++ b/onnx-converter/tsconfig.json @@ -8,4 +8,4 @@ "path": "tsconfig.lib.json" } ] -} \ No newline at end of file +} diff --git a/onnx-converter/tsconfig.lib.json b/onnx-converter/tsconfig.lib.json index 806611d85..9bea26418 100644 --- a/onnx-converter/tsconfig.lib.json +++ b/onnx-converter/tsconfig.lib.json @@ -1,5 +1,5 @@ { - "extends": "../tsconfig.base.json", - "compilerOptions": { "outDir": "dist" }, - "include": ["src"] + "extends": "../tsconfig.base.json", + "compilerOptions": { "outDir": "dist" }, + "include": ["src"] } diff --git a/server/package.json b/server/package.json index c3b63ec1b..c0b37471c 100644 --- a/server/package.json +++ b/server/package.json @@ -1,40 +1,40 @@ { - "name": "server", - "private": true, - "type": "module", - "main": "dist/index.js", - "exports": "./dist/index.js", - "scripts": { - "watch": "nodemon --ext ts --ignore dist --watch ../discojs-node/dist --watch . --exec npm run", - "start": "npm run build && node dist/main.js", - "build": "tsc --build", - "test": "cd .. && vitest --run --project=server" - }, - "author": "", - "license": "ISC", - "dependencies": { - "@epfml/discojs-node": "*", - "@msgpack/msgpack": "3", - "@roamhq/wrtc": "0.10", - "@tensorflow/tfjs": "4", - "cors": "2", - "express": "5", - "express-ws": "5", - "uuid": "14" - }, - "devDependencies": { - "@types/cors": "2", - "@types/express-ws": "3", - "@types/node": "22", - "nodemon": "3", - "ts-node": "10" - }, - "repository": { - "type": "git", - "url": "git+https://github.com/epfml/disco.git" - }, - "bugs": { - "url": "https://github.com/epfml/disco/issues" - }, - "homepage": "https://github.com/epfml/disco#readme" + "name": "server", + "private": true, + "type": "module", + "main": "dist/index.js", + "exports": "./dist/index.js", + "scripts": { + "watch": "nodemon --ext ts --ignore dist --watch ../discojs-node/dist --watch . --exec npm run", + "start": "npm run build && node dist/main.js", + "build": "tsc --build", + "test": "cd .. && vitest --run --project=server" + }, + "author": "", + "license": "ISC", + "dependencies": { + "@epfml/discojs-node": "*", + "@msgpack/msgpack": "3", + "@roamhq/wrtc": "0.10", + "@tensorflow/tfjs": "4", + "cors": "2", + "express": "5", + "express-ws": "5", + "uuid": "14" + }, + "devDependencies": { + "@types/cors": "2", + "@types/express-ws": "3", + "@types/node": "22", + "nodemon": "3", + "ts-node": "10" + }, + "repository": { + "type": "git", + "url": "git+https://github.com/epfml/disco.git" + }, + "bugs": { + "url": "https://github.com/epfml/disco/issues" + }, + "homepage": "https://github.com/epfml/disco#readme" } diff --git a/server/src/controllers/decentralized_controller.ts b/server/src/controllers/decentralized_controller.ts index 37ea428f4..33bdc403a 100644 --- a/server/src/controllers/decentralized_controller.ts +++ b/server/src/controllers/decentralized_controller.ts @@ -1,76 +1,78 @@ import createDebug from "debug"; -import { v4 as randomUUID } from 'uuid' +import { v4 as randomUUID } from "uuid"; import * as msgpack from "@msgpack/msgpack"; -import type WebSocket from 'ws' -import { Map } from 'immutable' +import type WebSocket from "ws"; +import { Map } from "immutable"; import { client, DataType } from "@epfml/discojs"; -import { TrainingController } from './training_controller.js' +import { TrainingController } from "./training_controller.js"; -import messages = client.decentralized.messages -import MessageTypes = client.messages.type +import messages = client.decentralized.messages; +import MessageTypes = client.messages.type; -const debug = createDebug("server:controllers:decentralized") +const debug = createDebug("server:controllers:decentralized"); export class DecentralizedController< - D extends DataType, + D extends DataType, > extends TrainingController { // Map of nodes who want to join the round. // The boolean value indicates if the node is ready to exchange weight updates (i.e. // the node has already sent a PeerIsReady message) // We wait for all peers to be ready to exchange weight updates - #roundPeers = Map() - #aggregationRound = 0 + #roundPeers = Map(); + #aggregationRound = 0; - handle (ws: WebSocket): void { - const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants + handle(ws: WebSocket): void { + const minNbOfParticipants = + this.task.trainingInformation.minNbOfParticipants; // Peer id of the message sender - let peerId = randomUUID() + let peerId = randomUUID(); while (this.connections.has(peerId)) { - peerId = randomUUID() + peerId = randomUUID(); } - const shortId = peerId.slice(0, 4) + const shortId = peerId.slice(0, 4); // How the server responds to messages - ws.on('message', (data: Buffer) => { + ws.on("message", (data: Buffer) => { try { - const msg: unknown = msgpack.decode(data) + const msg: unknown = msgpack.decode(data); if (!messages.isMessageToServer(msg)) return debug("invalid message received: %o", msg); switch (msg.type) { // A new peer joins the network for a task case MessageTypes.ClientConnected: { - debug(`peer [%s] joined ${this.task.id}`, shortId) - this.connections = this.connections.set(peerId, ws) + debug(`peer [%s] joined ${this.task.id}`, shortId); + this.connections = this.connections.set(peerId, ws); // Answer with client id in an NewNodeInfo message const msg: messages.NewDecentralizedNodeInfo = { type: MessageTypes.NewDecentralizedNodeInfo, id: peerId, nbOfParticipants: this.connections.size, - waitForMoreParticipants: this.connections.size < minNbOfParticipants - } - ws.send(msgpack.encode(msg), { binary: true }) + waitForMoreParticipants: + this.connections.size < minNbOfParticipants, + }; + ws.send(msgpack.encode(msg), { binary: true }); // Send an update to participants if we can start/resume training - this.sendEnoughParticipantsMsgIfNeeded(peerId) - break + this.sendEnoughParticipantsMsgIfNeeded(peerId); + break; } - // Send by peers at the beginning of each training round to notify + // Send by peers at the beginning of each training round to notify // the server that they want to join the round case MessageTypes.JoinRound: { - this.#roundPeers = this.#roundPeers.set(peerId, false) - break + this.#roundPeers = this.#roundPeers.set(peerId, false); + break; } // Send by peers when they are ready to exchange weight updates to get the list // of active peers for this round. case MessageTypes.PeerIsReady: { - this.#roundPeers = this.#roundPeers.set(peerId, true) - debug("Received peer ready from: %o", shortId) - this.sendPeersForRoundIfNeeded() - break + this.#roundPeers = this.#roundPeers.set(peerId, true); + debug("Received peer ready from: %o", shortId); + this.sendPeersForRoundIfNeeded(); + break; } // Forwards a peer's message to another destination peer // Used to exchange peer's information and establish a direct @@ -79,39 +81,39 @@ export class DecentralizedController< const forward: messages.SignalForPeer = { type: MessageTypes.SignalForPeer, peer: peerId, - signal: msg.signal - } - this.connections.get(msg.peer)?.send(msgpack.encode(forward)) - break + signal: msg.signal, + }; + this.connections.get(msg.peer)?.send(msgpack.encode(forward)); + break; } default: { - const _: never = msg - throw new Error('should never happen') + const _: never = msg; + throw new Error("should never happen"); } } } catch (e) { debug("when processing WebSocket message: %o", e); } - }) + }); // Setup callback for client leaving the session - ws.on('close', () => { + ws.on("close", () => { // Remove the participant when the websocket is closed - this.connections = this.connections.delete(peerId) - this.#roundPeers = this.#roundPeers.delete(peerId) - debug("client [%s] left", shortId) + this.connections = this.connections.delete(peerId); + this.#roundPeers = this.#roundPeers.delete(peerId); + debug("client [%s] left", shortId); // Check if we are already waiting for new participants to join - if (this.waitingForMoreParticipants) return + if (this.waitingForMoreParticipants) return; // If no, check if we are still above the minimum number of participant required if (this.connections.size >= minNbOfParticipants) { // Check if remaining peers are all ready to exchange weight updates - this.sendPeersForRoundIfNeeded() - return + this.sendPeersForRoundIfNeeded(); + return; } // If we are below the minimum number of participants // tell remaining participants to wait until more participants join - this.sendWaitForMoreParticipantsMsg() - }) + this.sendWaitForMoreParticipantsMsg(); + }); } /** * Check if we have enough participants to start the training @@ -119,35 +121,44 @@ export class DecentralizedController< * If so, send the list of peers for this round to all participants */ private sendPeersForRoundIfNeeded(): void { - const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants - const nbOfPeersReady = this.#roundPeers.filter(ready => ready).size + const minNbOfParticipants = + this.task.trainingInformation.minNbOfParticipants; + const nbOfPeersReady = this.#roundPeers.filter((ready) => ready).size; // First check if there are enough participants to start the round // Then check if all peers that wanted to join this round are ready - if (nbOfPeersReady < minNbOfParticipants - || nbOfPeersReady != this.#roundPeers.size) return + if ( + nbOfPeersReady < minNbOfParticipants || + nbOfPeersReady != this.#roundPeers.size + ) + return; // Once every peer that joined the round is ready, we can start the round - this.#roundPeers.keySeq() - .map((id) => { - const readyPeerIDs: messages.PeersForRound = { - type: MessageTypes.PeersForRound, - peers: this.#roundPeers.delete(id).keySeq().toArray(), - aggregationRound: this.#aggregationRound - } - debug("Sending peer list to: %o", id.slice(0, 4)) - - const encoded = msgpack.encode(readyPeerIDs) - return [id, encoded] as [client.NodeID, Buffer] - }) - .map(([id, encoded]) => { - const conn = this.connections.get(id) - if (conn === undefined) { - throw new Error(`peer ${id} marked as ready but not connection to it`) - } - return [conn, encoded] as [WebSocket, Buffer] - }).forEach(([conn, encoded]) => { conn.send(encoded) }) + this.#roundPeers + .keySeq() + .map((id) => { + const readyPeerIDs: messages.PeersForRound = { + type: MessageTypes.PeersForRound, + peers: this.#roundPeers.delete(id).keySeq().toArray(), + aggregationRound: this.#aggregationRound, + }; + debug("Sending peer list to: %o", id.slice(0, 4)); + + const encoded = msgpack.encode(readyPeerIDs); + return [id, encoded] as [client.NodeID, Buffer]; + }) + .map(([id, encoded]) => { + const conn = this.connections.get(id); + if (conn === undefined) { + throw new Error( + `peer ${id} marked as ready but not connection to it`, + ); + } + return [conn, encoded] as [WebSocket, Buffer]; + }) + .forEach(([conn, encoded]) => { + conn.send(encoded); + }); // empty the list of peers for the next round - this.#roundPeers = Map() - this.#aggregationRound++ + this.#roundPeers = Map(); + this.#aggregationRound++; } } - diff --git a/server/src/controllers/federated_controller.ts b/server/src/controllers/federated_controller.ts index 1e52fe539..68d017bea 100644 --- a/server/src/controllers/federated_controller.ts +++ b/server/src/controllers/federated_controller.ts @@ -1,6 +1,6 @@ import createDebug from "debug"; -import WebSocket from 'ws' -import { v4 as randomUUID } from 'uuid' +import WebSocket from "ws"; +import { v4 as randomUUID } from "uuid"; import * as msgpack from "@msgpack/msgpack"; import type { DataType, Task } from "@epfml/discojs"; @@ -8,42 +8,43 @@ import { aggregator as aggregators, client, serialization, -} from '@epfml/discojs' +} from "@epfml/discojs"; import { TrainingController } from "./training_controller.js"; -import MessageTypes = client.messages.type -import FederatedMessages = client.federated.messages +import MessageTypes = client.messages.type; +import FederatedMessages = client.federated.messages; -const debug = createDebug("server:controllers:federated") +const debug = createDebug("server:controllers:federated"); export class FederatedController extends TrainingController< - D, - "federated" + D, + "federated" > { /** * Aggregators for each hosted task. By default the server waits for 100% of the nodes to send their contributions before aggregating the updates */ - #aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative') + #aggregator = new aggregators.MeanAggregator(undefined, 1, "relative"); /** - * The most up to date global weights. The model weights are already serialized and - * can be sent to participants, before starting training, or when joining mid-training + * The most up to date global weights. The model weights are already serialized and + * can be sent to participants, before starting training, or when joining mid-training * or staled participants */ #latestGlobalWeights: serialization.Encoded; - constructor( - task: Task, - private readonly initialWeights: serialization.Encoded, - ) { - super(task) - this.#latestGlobalWeights = this.initialWeights + constructor( + task: Task, + private readonly initialWeights: serialization.Encoded, + ) { + super(task); + this.#latestGlobalWeights = this.initialWeights; // Save the latest weight updates to be able to send it to new or outdated clients - this.#aggregator.on('aggregation', async (weightUpdate) => { - this.#latestGlobalWeights = await serialization.weights.encode(weightUpdate) - }) + this.#aggregator.on("aggregation", async (weightUpdate) => { + this.#latestGlobalWeights = + await serialization.weights.encode(weightUpdate); + }); } /** @@ -52,118 +53,136 @@ export class FederatedController extends TrainingController< * It registers what the server will do upon receiving messages from the participant. * Note that `this.handle` is only called once to setup the logic. It is `ws.on()` * that is called upon receiving messages (and not `this.handle`) - * + * * @param task the task associated with the current websocket (= participant) * @param ws the websocket connection through which the participant and the server communicate */ handle(ws: WebSocket): void { - const minNbOfParticipants = this.task.trainingInformation.minNbOfParticipants - this.#aggregator.minNbOfParticipants = minNbOfParticipants + const minNbOfParticipants = + this.task.trainingInformation.minNbOfParticipants; + this.#aggregator.minNbOfParticipants = minNbOfParticipants; // Try generating a new Client id until there no collision with existing ones - let clientId = randomUUID() + let clientId = randomUUID(); while (!this.#aggregator.registerNode(clientId)) { - clientId = randomUUID() + clientId = randomUUID(); } - const shortId = clientId.slice(0, 4) + const shortId = clientId.slice(0, 4); // Setup callbacks triggered upon receiving the different client messages - ws.on('message', (data: Buffer) => { - const msg: unknown = msgpack.decode(data) + ws.on("message", (data: Buffer) => { + const msg: unknown = msgpack.decode(data); if (!FederatedMessages.isMessageFederated(msg)) { debug("invalid federated message received on WebSocket: %o", msg); - return // TODO send back error + return; // TODO send back error } - // Currently expect two types of message: + // Currently expect two types of message: // - the client connects to the task // - the client sends a weight update switch (msg.type) { - /* - * A new participant joins the task - */ + /* + * A new participant joins the task + */ case MessageTypes.ClientConnected: { - debug(`client [%s] joined ${this.task.id}`, shortId) - this.connections = this.connections.set(clientId, ws) // add the new client + debug(`client [%s] joined ${this.task.id}`, shortId); + this.connections = this.connections.set(clientId, ws); // add the new client const msg: FederatedMessages.NewFederatedNodeInfo = { type: MessageTypes.NewFederatedNodeInfo, id: clientId, - waitForMoreParticipants: this.connections.size < minNbOfParticipants, + waitForMoreParticipants: + this.connections.size < minNbOfParticipants, payload: this.#latestGlobalWeights, round: this.#aggregator.round, - nbOfParticipants: this.connections.size - } - ws.send(msgpack.encode(msg)) + nbOfParticipants: this.connections.size, + }; + ws.send(msgpack.encode(msg)); // Send an update to participants if we can start/resume training - this.sendEnoughParticipantsMsgIfNeeded(clientId) - break + this.sendEnoughParticipantsMsgIfNeeded(clientId); + break; } - /* - * A client sends a weight update to the server - */ + /* + * A client sends a weight update to the server + */ case MessageTypes.SendPayload: { - const { payload, round } = msg + const { payload, round } = msg; if (this.#aggregator.isValidContribution(clientId, round)) { - const weights = serialization.weights.decode(payload) + const weights = serialization.weights.decode(payload); - // Create a callback to send the aggregated weight to the client + // Create a callback to send the aggregated weight to the client // when enough contributions are received - this.#aggregator.once('aggregation', async (weightUpdate) => { - debug("Sending global weights for round %o to client [%s]", this.#aggregator.round, shortId) + this.#aggregator.once("aggregation", async (weightUpdate) => { + debug( + "Sending global weights for round %o to client [%s]", + this.#aggregator.round, + shortId, + ); const msg: FederatedMessages.ReceiveServerPayload = { type: MessageTypes.ReceiveServerPayload, round: this.#aggregator.round, // send the current round number after aggregation payload: await serialization.weights.encode(weightUpdate), - nbOfParticipants: this.connections.size - } - ws.send(msgpack.encode(msg)) - }) + nbOfParticipants: this.connections.size, + }; + ws.send(msgpack.encode(msg)); + }); // Add the contribution - this.#aggregator.add(clientId, weights, round) - debug(`Successfully added contribution from client [%s] for round ${round}`, shortId) + this.#aggregator.add(clientId, weights, round); + debug( + `Successfully added contribution from client [%s] for round ${round}`, + shortId, + ); } else { // If the client sent an invalid or outdated contribution // the server answers with the current round and last global model update - debug(`Dropped contribution from client [%s] for round ${round} ` + - `Sending last global model from round ${this.#aggregator.round - 1}`, shortId) + debug( + `Dropped contribution from client [%s] for round ${round} ` + + `Sending last global model from round ${this.#aggregator.round - 1}`, + shortId, + ); // no latest model at the first round - if (this.#latestGlobalWeights === undefined) return - + if (this.#latestGlobalWeights === undefined) return; + const msg: FederatedMessages.ReceiveServerPayload = { type: MessageTypes.ReceiveServerPayload, round: this.#aggregator.round - 1, // send the model from the previous round payload: this.#latestGlobalWeights, - nbOfParticipants: this.connections.size - } - ws.send(msgpack.encode(msg)) + nbOfParticipants: this.connections.size, + }; + ws.send(msgpack.encode(msg)); } - break + break; } } - }) + }); // Setup callback for client leaving the session - ws.on('close', () => { + ws.on("close", () => { // Remove the participant when the websocket is closed - this.connections = this.connections.delete(clientId) - this.#aggregator.removeNode(clientId) - debug("client [%s] left", shortId) + this.connections = this.connections.delete(clientId); + this.#aggregator.removeNode(clientId); + debug("client [%s] left", shortId); // Reset the training session when all participants left if (this.connections.size === 0) { - debug("All participants left. Resetting the training session") - this.#aggregator = new aggregators.MeanAggregator(undefined, 1, 'relative') - this.#latestGlobalWeights = this.initialWeights + debug("All participants left. Resetting the training session"); + this.#aggregator = new aggregators.MeanAggregator( + undefined, + 1, + "relative", + ); + this.#latestGlobalWeights = this.initialWeights; } // Check if we dropped below the minimum number of participant required // or if we are already waiting for new participants to join - if (this.connections.size >= minNbOfParticipants || + if ( + this.connections.size >= minNbOfParticipants || this.waitingForMoreParticipants - ) return + ) + return; // tell remaining participants to wait until more participants join - this.sendWaitForMoreParticipantsMsg() - }) + this.sendWaitForMoreParticipantsMsg(); + }); } } diff --git a/server/src/controllers/index.ts b/server/src/controllers/index.ts index 3a9ae2264..e7585b809 100644 --- a/server/src/controllers/index.ts +++ b/server/src/controllers/index.ts @@ -1,3 +1,3 @@ export { TrainingController } from "./training_controller.js"; export { FederatedController } from "./federated_controller.js"; -export { DecentralizedController } from "./decentralized_controller.js"; \ No newline at end of file +export { DecentralizedController } from "./decentralized_controller.js"; diff --git a/server/src/controllers/training_controller.ts b/server/src/controllers/training_controller.ts index bbee678ec..ff3e52c6f 100644 --- a/server/src/controllers/training_controller.ts +++ b/server/src/controllers/training_controller.ts @@ -1,73 +1,76 @@ import createDebug from "debug"; -import type WebSocket from 'ws' -import { Map } from 'immutable' +import type WebSocket from "ws"; +import { Map } from "immutable"; import * as msgpack from "@msgpack/msgpack"; -import { client } from '@epfml/discojs' +import { client } from "@epfml/discojs"; import type { DataType, Network, Task } from "@epfml/discojs"; -const debug = createDebug("server:controllers") +const debug = createDebug("server:controllers"); /** * The Controller abstraction is commonly used in Express * and comes from the MVC pattern (model-view-controller) * In short, the controller is where the backend logic happens * when the server receives a client request - * + * * In this case, the controller handles the training logic: * what happens when a new task (DISCOllaborative) is created * and what happens when receiving messages from participants * of a training session. - * + * * More info on controllers: * https://developer.mozilla.org/en-US/docs/Learn/Server-side/Express_Nodejs/routes - * + * */ export abstract class TrainingController< - D extends DataType, - N extends Exclude, + D extends DataType, + N extends Exclude, > { /** - * Boolean used to know if we have enough participants to train or if + * Boolean used to know if we have enough participants to train or if * we should be waiting for more */ - protected waitingForMoreParticipants = true + protected waitingForMoreParticipants = true; /** * List of active participants along with their websockets - * the list allows updating participants about the training status + * the list allows updating participants about the training status * i.e. waiting for more participants or resuming training */ - protected connections = Map() + protected connections = Map(); constructor(protected readonly task: Task) {} - abstract handle( - ws: WebSocket - ): void + abstract handle(ws: WebSocket): void; /** * If enough participants joined, notifies them that the training can start/resume - * + * * @param currentId the id of the participant that just joined */ protected sendEnoughParticipantsMsgIfNeeded(currentId: client.NodeID) { // If we are currently waiting for more participants to join and we now have enough, // broadcast to previously waiting participants that the training can start - if (this.waitingForMoreParticipants && - this.connections.size >= this.task.trainingInformation.minNbOfParticipants) { + if ( + this.waitingForMoreParticipants && + this.connections.size >= this.task.trainingInformation.minNbOfParticipants + ) { this.connections - // filter out the client that just joined as + // filter out the client that just joined as // it already knows via the NewFederatedNodeInfo message .delete(currentId) .forEach((participantWs, participantId) => { - debug("Sending enough-participant message to client [%s]", participantId.slice(0, 4)) + debug( + "Sending enough-participant message to client [%s]", + participantId.slice(0, 4), + ); const msg: client.messages.EnoughParticipants = { type: client.messages.type.EnoughParticipants, - nbOfParticipants: this.connections.size - } - participantWs.send(msgpack.encode(msg)) - }) - this.waitingForMoreParticipants = false // update the attribute + nbOfParticipants: this.connections.size, + }; + participantWs.send(msgpack.encode(msg)); + }); + this.waitingForMoreParticipants = false; // update the attribute } } @@ -77,16 +80,17 @@ export abstract class TrainingController< protected sendWaitForMoreParticipantsMsg(): void { // If we are below the minimum number of participants // tell remaining participants to wait until more participants join - this.waitingForMoreParticipants = true - this.connections - .forEach((participantWs, participantId) => { - debug("Telling remaining client [%s] to wait for participants", participantId.slice(0, 4)) - const msg: client.messages.WaitingForMoreParticipants = { - type: client.messages.type.WaitingForMoreParticipants, - nbOfParticipants: this.connections.size - } - participantWs.send(msgpack.encode(msg)) - }) + this.waitingForMoreParticipants = true; + this.connections.forEach((participantWs, participantId) => { + debug( + "Telling remaining client [%s] to wait for participants", + participantId.slice(0, 4), + ); + const msg: client.messages.WaitingForMoreParticipants = { + type: client.messages.type.WaitingForMoreParticipants, + nbOfParticipants: this.connections.size, + }; + participantWs.send(msgpack.encode(msg)); + }); } - } diff --git a/server/src/index.ts b/server/src/index.ts index ebed19e23..bb138a687 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -1 +1 @@ -export { Server } from './server.js' +export { Server } from "./server.js"; diff --git a/server/src/main.ts b/server/src/main.ts index 42ecef5c4..f49996f8f 100644 --- a/server/src/main.ts +++ b/server/src/main.ts @@ -15,14 +15,14 @@ const providers = Object.values(defaultTasks); console.info("Server loaded the tasks below"); console.table( - (await Promise.all(providers.map((p) => p.getTask()))).map( - (task: Task) => ({ - ID: task.id, - Title: task.displayInformation.title, - "Data Type": task.dataType, - Scheme: task.trainingInformation.scheme, - }), - ), + (await Promise.all(providers.map((p) => p.getTask()))).map( + (task: Task) => ({ + ID: task.id, + Title: task.displayInformation.title, + "Data Type": task.dataType, + Scheme: task.trainingInformation.scheme, + }), + ), ); // Init the server with default tasks diff --git a/server/src/routes/index.ts b/server/src/routes/index.ts index 7c7e794fe..2a92dbca5 100644 --- a/server/src/routes/index.ts +++ b/server/src/routes/index.ts @@ -1,2 +1,2 @@ export { TaskRouter } from "./task_router.js"; -export { TrainingRouter } from "./training_router.js"; \ No newline at end of file +export { TrainingRouter } from "./training_router.js"; diff --git a/server/src/routes/task_router.ts b/server/src/routes/task_router.ts index 79c2c2bce..dea4b8c73 100644 --- a/server/src/routes/task_router.ts +++ b/server/src/routes/task_router.ts @@ -12,95 +12,95 @@ import { z } from "zod"; const debug = createDebug("server:router:task_router"); export class TaskRouter { - readonly #expressRouter: express.Router; - readonly #taskSet: TaskSet; + readonly #expressRouter: express.Router; + readonly #taskSet: TaskSet; - constructor(taskSet: TaskSet) { - this.#taskSet = taskSet; - this.#expressRouter = express.Router(); + constructor(taskSet: TaskSet) { + this.#taskSet = taskSet; + this.#expressRouter = express.Router(); - // Return available tasks upon GET requests - this.#expressRouter.get("/", (_, res) => { - res.status(200).send( - this.#taskSet.tasks - .valueSeq() - .map(([t, _]) => serialization.task.serializeToJSON(t)) - .toArray(), - ); - }); + // Return available tasks upon GET requests + this.#expressRouter.get("/", (_, res) => { + res.status(200).send( + this.#taskSet.tasks + .valueSeq() + .map(([t, _]) => serialization.task.serializeToJSON(t)) + .toArray(), + ); + }); - this.#expressRouter.use(express.json()); + this.#expressRouter.use(express.json()); - // POST request to add a new task - this.#expressRouter.post("/", async (req, res) => { - const parsed = await z - .object({ - model: z - .array(z.number()) - .transform((arr) => Uint8Array.from(arr)) - .transform(serialization.model.decode), - task: z.any().transform(serialization.task.deserializeFromJSON), - }) - .safeParseAsync(req.body); + // POST request to add a new task + this.#expressRouter.post("/", async (req, res) => { + const parsed = await z + .object({ + model: z + .array(z.number()) + .transform((arr) => Uint8Array.from(arr)) + .transform(serialization.model.decode), + task: z.any().transform(serialization.task.deserializeFromJSON), + }) + .safeParseAsync(req.body); - if (!parsed.success) { - debug("posted task isn't valid: %s", parsed.error); - res.status(400).end(); - return; - } - const { model, task } = parsed.data; + if (!parsed.success) { + debug("posted task isn't valid: %s", parsed.error); + res.status(400).end(); + return; + } + const { model, task } = parsed.data; - try { - await this.#taskSet.addTask(task, model); - } catch (e) { - debug("add task failed with: %o", e); - if (e instanceof Error && e.message === "already existing") - res.status(409).end(); - else res.status(500).end(); - return; - } + try { + await this.#taskSet.addTask(task, model); + } catch (e) { + debug("add task failed with: %o", e); + if (e instanceof Error && e.message === "already existing") + res.status(409).end(); + else res.status(500).end(); + return; + } - res.status(200).end("Successful task upload"); - }); + res.status(200).end("Successful task upload"); + }); - this.#taskSet.on("newTask", ([task]) => { - this.#expressRouter.get(`/${task.id}/:file`, (req, res) => - this.getLatestModel(task.id, req, res), - ); - }); - } + this.#taskSet.on("newTask", ([task]) => { + this.#expressRouter.get(`/${task.id}/:file`, (req, res) => + this.getLatestModel(task.id, req, res), + ); + }); + } - public get router(): express.Router { - return this.#expressRouter; - } + public get router(): express.Router { + return this.#expressRouter; + } - /** - * Request handler called when a client sends a GET request asking for the - * TFJS model files of a given task. The files consist of the model's - * architecture file model.json and its layer weights file weights.bin. - * It requires no prior connection to the server and is thus publicly available - * data. - * @param request received from client - * @param response sent to client - */ - private getLatestModel( - id: Task.ID, - request: Request<{ file: string }>, - response: Response, - ): void { - const validModelFiles = Set.of("model.json", "weights.bin"); + /** + * Request handler called when a client sends a GET request asking for the + * TFJS model files of a given task. The files consist of the model's + * architecture file model.json and its layer weights file weights.bin. + * It requires no prior connection to the server and is thus publicly available + * data. + * @param request received from client + * @param response sent to client + */ + private getLatestModel( + id: Task.ID, + request: Request<{ file: string }>, + response: Response, + ): void { + const validModelFiles = Set.of("model.json", "weights.bin"); - const file = request.params.file; - if (!validModelFiles.has(file)) { - response.status(404); - return; - } - const taskAndModel = this.#taskSet.tasks.find(([t, _]) => t.id === id); - if (taskAndModel === undefined) { - response.status(404); - return; - } - response.status(200).send(Buffer.from(taskAndModel[1])); - debug(`${file} download for task ${id} succeeded`); - } + const file = request.params.file; + if (!validModelFiles.has(file)) { + response.status(404); + return; + } + const taskAndModel = this.#taskSet.tasks.find(([t, _]) => t.id === id); + if (taskAndModel === undefined) { + response.status(404); + return; + } + response.status(200).send(Buffer.from(taskAndModel[1])); + debug(`${file} download for task ${id} succeeded`); + } } diff --git a/server/src/routes/training_router.ts b/server/src/routes/training_router.ts index ef8fd815e..28da1e932 100644 --- a/server/src/routes/training_router.ts +++ b/server/src/routes/training_router.ts @@ -1,10 +1,14 @@ -import express from 'express' -import type expressWS from 'express-ws' -import type { Task, DataType, Network } from '@epfml/discojs' -import { serialization } from '@epfml/discojs' +import express from "express"; +import type expressWS from "express-ws"; +import type { Task, DataType, Network } from "@epfml/discojs"; +import { serialization } from "@epfml/discojs"; -import type { TaskSet } from '../task_set.js' -import { TrainingController, FederatedController, DecentralizedController } from '../controllers/index.js' +import type { TaskSet } from "../task_set.js"; +import { + TrainingController, + FederatedController, + DecentralizedController, +} from "../controllers/index.js"; /** * The TrainingRouter handles client requests related the federated @@ -13,53 +17,55 @@ import { TrainingController, FederatedController, DecentralizedController } from * the actual logic to the task's Controller. */ export class TrainingRouter> { - readonly #expressRouter: expressWS.Router + readonly #expressRouter: expressWS.Router; - constructor(network: N, wsApplier: expressWS.Instance, taskSet: TaskSet) { - this.#expressRouter = express.Router() - wsApplier.applyTo(this.#expressRouter) + constructor(network: N, wsApplier: expressWS.Instance, taskSet: TaskSet) { + this.#expressRouter = express.Router(); + wsApplier.applyTo(this.#expressRouter); this.#expressRouter.get("/", (_, res) => { res.send(`Disco ${network} server\n`); }); - taskSet.on("newTask", async ([task, encodedModel]) => { - if (task.trainingInformation.scheme !== network) return; - const t = task as Task; + taskSet.on("newTask", async ([task, encodedModel]) => { + if (task.trainingInformation.scheme !== network) return; + const t = task as Task; - await this.onNewTask(t, encodedModel); - }); + await this.onNewTask(t, encodedModel); + }); } // The method called to use the TrainingRouter - public get router (): express.Router { - return this.#expressRouter + public get router(): express.Router { + return this.#expressRouter; } // Register the task and setup the controller to handle // websocket connections - private async onNewTask( - task: Task, - encodedModel: serialization.Encoded, - ): Promise { + private async onNewTask( + task: Task, + encodedModel: serialization.Encoded, + ): Promise { // The controller handles the actual logic of collaborative training // in its `handle` method. Each task has a dedicated controller which // handles the training logic of this task only - let taskController: TrainingController; - if (task.trainingInformation.scheme === "federated") { - const t = task as Task + let taskController: TrainingController; + if (task.trainingInformation.scheme === "federated") { + const t = task as Task; // The federated controller takes the initial model weights at initialization // so that it can send it to new clients - const model = serialization.model.decode(encodedModel) - const encodedWeights = await serialization.weights.encode((await model).weights) - taskController = new FederatedController(t, encodedWeights) + const model = serialization.model.decode(encodedModel); + const encodedWeights = await serialization.weights.encode( + (await model).weights, + ); + taskController = new FederatedController(t, encodedWeights); } else { - const t = task as Task + const t = task as Task; // In decentralized learning, the server (i.e. controller) never handles model weights - taskController = new DecentralizedController(t) - } + taskController = new DecentralizedController(t); + } this.#expressRouter.ws(`/${task.id}`, (ws) => taskController.handle(ws)); } diff --git a/server/src/server.ts b/server/src/server.ts index 06641cff0..7bbb1e917 100644 --- a/server/src/server.ts +++ b/server/src/server.ts @@ -5,14 +5,14 @@ import type * as http from "http"; import type { DataType, Network, TaskProvider } from "@epfml/discojs"; -import { TaskRouter, TrainingRouter } from './routes/index.js' +import { TaskRouter, TrainingRouter } from "./routes/index.js"; import { TaskSet } from "./task_set.js"; /** * The Disco Server, initializing an Express app * Its main goal is to provide the available tasks (DISCOllaboratives) - * and tasks' base models to clients. - * + * and tasks' base models to clients. + * * More info on Express apps: * https://developer.mozilla.org/en-US/docs/Learn/Server-side/Express_Nodejs/Introduction */ @@ -20,7 +20,9 @@ export class Server { readonly #taskSet = new TaskSet(); /** setup with given initial tasks */ - static async with(...tasks: TaskProvider[]): Promise { + static async with( + ...tasks: TaskProvider[] + ): Promise { const server = new Server(); await Promise.all(tasks.map((t) => server.#taskSet.addTask(t))); @@ -33,7 +35,7 @@ export class Server { * * @param port where to start, if not given, choose a random one * @returns a tuple with the server instance and the URL - * + * **/ async serve(port?: number): Promise<[http.Server, URL]> { const wsApplier = expressWS(express(), undefined, { @@ -46,17 +48,25 @@ export class Server { app.use(express.json({ limit: "50mb" })); app.use(express.urlencoded({ limit: "50mb", extended: false })); - const taskRouter = new TaskRouter(this.#taskSet) - const federatedRouter = new TrainingRouter('federated', wsApplier, this.#taskSet) - const decentralizedRouter = new TrainingRouter('decentralized', wsApplier, this.#taskSet) + const taskRouter = new TaskRouter(this.#taskSet); + const federatedRouter = new TrainingRouter( + "federated", + wsApplier, + this.#taskSet, + ); + const decentralizedRouter = new TrainingRouter( + "decentralized", + wsApplier, + this.#taskSet, + ); - app.get('/', (_, res, next) => { - res.send('The DISCO Server\n') - next() - }) - app.use('/federated', federatedRouter.router) - app.use('/decentralized', decentralizedRouter.router) - app.use('/tasks', taskRouter.router) + app.get("/", (_, res, next) => { + res.send("The DISCO Server\n"); + next(); + }); + app.use("/federated", federatedRouter.router); + app.use("/decentralized", decentralizedRouter.router); + app.use("/tasks", taskRouter.router); const server = await new Promise((resolve, reject) => { const ret = app.listen(port); diff --git a/server/src/task_set.ts b/server/src/task_set.ts index 5af2c9f0e..78b0b6d38 100644 --- a/server/src/task_set.ts +++ b/server/src/task_set.ts @@ -1,6 +1,6 @@ import { Map } from "immutable"; -import fs from 'node:fs/promises' -import '@tensorflow/tfjs-node' +import fs from "node:fs/promises"; +import "@tensorflow/tfjs-node"; import type { DataType, Network, Task, TaskProvider } from "@epfml/discojs"; import { EventEmitter, Model, serialization } from "@epfml/discojs"; @@ -9,36 +9,36 @@ type EncodedModel = serialization.Encoded; type TaskAndModel = [Task, EncodedModel]; /** - * The TaskSet essentially handles initializing a Task and + * The TaskSet essentially handles initializing a Task and * loading its associated EncodedModel. - * + * * We rely on a TaskSet to abstract the (asynchronous) logic of getting the model * when not provided. * Depending on the case, getting the model is done by reading the model files - * from disk if they exists, downloading them from a URL or - * initializing the model from its architecture definition. - * - * We work with EncodedModels rather than Models because they are sent encoded + * from disk if they exists, downloading them from a URL or + * initializing the model from its architecture definition. + * + * We work with EncodedModels rather than Models because they are sent encoded * to clients. Since the server doesn't need to use the Model, we * simply leave it already encoded and ready to be sent to clients - * - * Due to the asynchronous nature of `addTask`, TaskSet is an EventEmitter, - * by registering callbacks on new tasks and emitting a 'newTask' event + * + * Due to the asynchronous nature of `addTask`, TaskSet is an EventEmitter, + * by registering callbacks on new tasks and emitting a 'newTask' event * when a new task has been added. - * + * * Tasks are usually passed to TaskSet when booting the server - * and objects depending on tasks and models can subscribe to + * and objects depending on tasks and models can subscribe to * the 'newTask' event to run callbacks whenever a new Task and EncodedModel are initialized. */ export class TaskSet extends EventEmitter<{ newTask: TaskAndModel; }> { - // Keep track of previously initialized task-model pairs - #tasks = Map(); + // Keep track of previously initialized task-model pairs + #tasks = Map(); - get tasks(): Map { - return this.#tasks; - } + get tasks(): Map { + return this.#tasks; + } // send known tasks to new listener override on( @@ -50,21 +50,21 @@ export class TaskSet extends EventEmitter<{ /** * Method to add a new task and optionally its associated model. - * It accepts parameters in different formats and handles + * It accepts parameters in different formats and handles * shaping them into a Task and an EncodedModel. * The method emits a 'newTask' event with the resulting Task and EncodedModel. - * + * * If a Task and the EncodedModel is provided as parameters the method does change them - * Otherwise the method handles shaping the parameters into a Task and EncodedModel + * Otherwise the method handles shaping the parameters into a Task and EncodedModel * before emitting the event - * + * * @param taskOrProvider either a Task or TaskProvider * @param model optional model, can already be an EncodedModel, a Model or a URL for the model */ - async addTask( - taskOrProvider: Task | TaskProvider, - model?: Model | EncodedModel, - ): Promise { + async addTask( + taskOrProvider: Task | TaskProvider, + model?: Model | EncodedModel, + ): Promise { // get the task const task = "getTask" in taskOrProvider @@ -72,65 +72,65 @@ export class TaskSet extends EventEmitter<{ : taskOrProvider; // get the model - let encodedModel: EncodedModel + let encodedModel: EncodedModel; if (serialization.isEncoded(model)) { - encodedModel = model // don't do anything if already encoded + encodedModel = model; // don't do anything if already encoded } else { - let tfModel: Model - if (model === undefined) { + let tfModel: Model; + if (model === undefined) { // Get the model if nothing is provided - tfModel = await this.loadModelFromTask(taskOrProvider) + tfModel = await this.loadModelFromTask(taskOrProvider); } else if (model instanceof Model) { // Don't do anything if the model is already specified - tfModel = model + tfModel = model; } else { - throw new Error('invalid model') + throw new Error("invalid model"); } - encodedModel = await serialization.model.encode(tfModel) + encodedModel = await serialization.model.encode(tfModel); } - // Add the task-model pair to the set - if (this.#tasks.has(task.id)) throw new Error("already existing"); - this.#tasks = this.#tasks.set(task.id, [task, encodedModel]); - this.emit("newTask", [task, encodedModel]); + // Add the task-model pair to the set + if (this.#tasks.has(task.id)) throw new Error("already existing"); + this.#tasks = this.#tasks.set(task.id, [task, encodedModel]); + this.emit("newTask", [task, encodedModel]); } /** * Gets the model associated to a task. First checks if the model has been saved to disk. * Otherwise, initializes it from its architecture definition (and saves it to disk) - * + * * @param taskOrProvider either a Task or a TaskProvider * @returns a promise for the associated model */ - private async loadModelFromTask( - taskOrProvider: Task | TaskProvider, - ): Promise> { + private async loadModelFromTask( + taskOrProvider: Task | TaskProvider, + ): Promise> { const task = "getTask" in taskOrProvider ? await taskOrProvider.getTask() : taskOrProvider; - - const modelPath = `./models/${task.id}/` + + const modelPath = `./models/${task.id}/`; try { - const content = await fs.readFile(`${modelPath}/model.json`) + const content = await fs.readFile(`${modelPath}/model.json`); // cast as we trust the task ID - return await serialization.model.decode(content) + return await serialization.model.decode(content); } catch { // unable to read file (potentially doesn't exist), continuing } - + if ("id" in taskOrProvider) { // if the model isn't already saved to disk then we need the TaskProvider // to get the model architecture definition - throw new Error('saved model not found and no way to get it') + throw new Error("saved model not found and no way to get it"); } const model = await taskOrProvider.getModel(); // Save the model to disk - await fs.mkdir(modelPath, { recursive: true }) - const encoded = await serialization.model.encode(model) - await fs.writeFile(`${modelPath}/model.json`, encoded) - - return model + await fs.mkdir(modelPath, { recursive: true }); + const encoded = await serialization.model.encode(model); + await fs.writeFile(`${modelPath}/model.json`, encoded); + + return model; } } diff --git a/server/tests/client.spec.ts b/server/tests/client.spec.ts index e143a546f..4a0b435c6 100644 --- a/server/tests/client.spec.ts +++ b/server/tests/client.spec.ts @@ -1,103 +1,103 @@ import type * as http from "node:http"; import type { DataType, Network, TaskProvider } from "@epfml/discojs"; import { - aggregator as aggregators, - client as clients, - defaultTasks, + aggregator as aggregators, + client as clients, + defaultTasks, } from "@epfml/discojs"; import { afterEach, describe, expect, it } from "vitest"; import { Server } from "../src/index.js"; describe("decentralized client", () => { - let handle: http.Server; - async function startServer( - ...tasks: TaskProvider[] - ): Promise { - const server = await Server.with(...tasks); - - let url: URL; - [handle, url] = await server.serve(); - return url; - } - afterEach( - () => - new Promise((resolve, reject) => - handle?.close((e) => { - if (e !== undefined) reject(e); - else resolve(); - }), - ), - ); - - it("connects to valid task", async () => { - const url = await startServer(defaultTasks.cifar10); - - const client = new clients.decentralized.DecentralizedClient( - url, - await defaultTasks.cifar10.getTask(), - new aggregators.MeanAggregator(), - ); - - await client.connect(); - await client.disconnect(); - }); - - it("fails to connect to invalid task", async () => { - const url = await startServer(); // no task - - const client = new clients.decentralized.DecentralizedClient( - url, - await defaultTasks.cifar10.getTask(), - new aggregators.MeanAggregator(), - ); - - await expect(client.connect()).rejects.toThrow(); - }); + let handle: http.Server; + async function startServer( + ...tasks: TaskProvider[] + ): Promise { + const server = await Server.with(...tasks); + + let url: URL; + [handle, url] = await server.serve(); + return url; + } + afterEach( + () => + new Promise((resolve, reject) => + handle?.close((e) => { + if (e !== undefined) reject(e); + else resolve(); + }), + ), + ); + + it("connects to valid task", async () => { + const url = await startServer(defaultTasks.cifar10); + + const client = new clients.decentralized.DecentralizedClient( + url, + await defaultTasks.cifar10.getTask(), + new aggregators.MeanAggregator(), + ); + + await client.connect(); + await client.disconnect(); + }); + + it("fails to connect to invalid task", async () => { + const url = await startServer(); // no task + + const client = new clients.decentralized.DecentralizedClient( + url, + await defaultTasks.cifar10.getTask(), + new aggregators.MeanAggregator(), + ); + + await expect(client.connect()).rejects.toThrow(); + }); }); describe("federated client", () => { - let handle: http.Server; - async function startServer( - ...tasks: TaskProvider[] - ): Promise { - const server = await Server.with(...tasks); - - let url: URL; - [handle, url] = await server.serve(); - return url; - } - afterEach( - () => - new Promise((resolve, reject) => - handle?.close((e) => { - if (e !== undefined) reject(e); - else resolve(); - }), - ), - ); - - it("connects to valid task", async () => { - const url = await startServer(defaultTasks.titanic); - - const client = new clients.federated.FederatedClient( - url, - await defaultTasks.titanic.getTask(), - new aggregators.MeanAggregator(), - ); - - await client.connect(); - await client.disconnect(); - }); - - it("fails to connect to invalid task", async () => { - const url = await startServer(); // no task - - const client = new clients.federated.FederatedClient( - url, - await defaultTasks.titanic.getTask(), - new aggregators.MeanAggregator(), - ); - - await expect(client.connect()).rejects.toThrow(); - }); + let handle: http.Server; + async function startServer( + ...tasks: TaskProvider[] + ): Promise { + const server = await Server.with(...tasks); + + let url: URL; + [handle, url] = await server.serve(); + return url; + } + afterEach( + () => + new Promise((resolve, reject) => + handle?.close((e) => { + if (e !== undefined) reject(e); + else resolve(); + }), + ), + ); + + it("connects to valid task", async () => { + const url = await startServer(defaultTasks.titanic); + + const client = new clients.federated.FederatedClient( + url, + await defaultTasks.titanic.getTask(), + new aggregators.MeanAggregator(), + ); + + await client.connect(); + await client.disconnect(); + }); + + it("fails to connect to invalid task", async () => { + const url = await startServer(); // no task + + const client = new clients.federated.FederatedClient( + url, + await defaultTasks.titanic.getTask(), + new aggregators.MeanAggregator(), + ); + + await expect(client.connect()).rejects.toThrow(); + }); }); diff --git a/server/tests/e2e/decentralized.spec.ts b/server/tests/e2e/decentralized.spec.ts index 0a0fb29af..9a1fcba99 100644 --- a/server/tests/e2e/decentralized.spec.ts +++ b/server/tests/e2e/decentralized.spec.ts @@ -1,11 +1,11 @@ import type * as http from "node:http"; import type { DataType, RoundStatus, Task, TaskProvider } from "@epfml/discojs"; import { - aggregator as aggregators, - client as clients, - Disco, - defaultTasks, - WeightsContainer, + aggregator as aggregators, + client as clients, + Disco, + defaultTasks, + WeightsContainer, } from "@epfml/discojs"; import { List } from "immutable"; import { afterEach, describe, expect, it } from "vitest"; @@ -13,25 +13,27 @@ import { Server } from "../../src/index.js"; import { datasets, Queue } from "../utils.js"; async function WSIntoList(ws: WeightsContainer): Promise>> { - return List((await Promise.all(ws.weights.map(async (w) => await w.data()))).map( - (arr) => List(arr), - )); + return List( + (await Promise.all(ws.weights.map(async (w) => await w.data()))).map( + (arr) => List(arr), + ), + ); } async function expectWSToBeClose( - left: WeightsContainer, - right: WeightsContainer, + left: WeightsContainer, + right: WeightsContainer, ): Promise { - for (const tensors of (await WSIntoList(left)).zip(await WSIntoList(right))) - for (const [l, r] of tensors[0].zip(tensors[1])) - expect(l).to.be.closeTo(r, 1e-4); + for (const tensors of (await WSIntoList(left)).zip(await WSIntoList(right))) + for (const [l, r] of tensors[0].zip(tensors[1])) + expect(l).to.be.closeTo(r, 1e-4); } describe("end-to-end decentralized", { timeout: 50_000 }, () => { let handle: http.Server | undefined; - async function startServer( - task: TaskProvider, - ): Promise { + async function startServer( + task: TaskProvider, + ): Promise { const server = await Server.with(task); let url: URL; @@ -54,57 +56,63 @@ describe("end-to-end decentralized", { timeout: 50_000 }, () => { * with other ready peers. The input will vary with model architecture and training data. If secure is true, * the client will implement secure aggregation. If it is false, it will be a clear text client. */ - async function simulateClient( - url: URL, - aggregatorType: "mean" | "secure", - input: number[], - rounds: number, - ): Promise<[WeightsContainer, clients.Client<"decentralized">]> { - const task = await defaultTasks.cifar10.getTask(); - const aggregator = - aggregatorType === "mean" - ? new aggregators.MeanAggregator(0, 1, "relative") - : new aggregators.SecureAggregator(); - - const client = new clients.decentralized.DecentralizedClient(url, task, aggregator) - await client.connect() + async function simulateClient( + url: URL, + aggregatorType: "mean" | "secure", + input: number[], + rounds: number, + ): Promise<[WeightsContainer, clients.Client<"decentralized">]> { + const task = await defaultTasks.cifar10.getTask(); + const aggregator = + aggregatorType === "mean" + ? new aggregators.MeanAggregator(0, 1, "relative") + : new aggregators.SecureAggregator(); + + const client = new clients.decentralized.DecentralizedClient( + url, + task, + aggregator, + ); + await client.connect(); // Perform multiple training rounds - let weights = WeightsContainer.of(input) + let weights = WeightsContainer.of(input); for (let r = 0; r < rounds; r++) { - await client.onRoundBeginCommunication() - await new Promise((resolve) => setTimeout(resolve, 1_000)) - weights = await client.onRoundEndCommunication(weights) + await client.onRoundBeginCommunication(); + await new Promise((resolve) => setTimeout(resolve, 1_000)); + weights = await client.onRoundEndCommunication(weights); } - return [weights, client] + return [weights, client]; } /** * Creates three clients with different update values and returns the aggregated update value between all three clients. * The clients have model dimension of 4 model updates to share, which can be seen as their input parameter in makeClient. */ - async function reachConsensus ( + async function reachConsensus( url: URL, - aggregatorType: 'mean' | 'secure', - rounds = 1 + aggregatorType: "mean" | "secure", + rounds = 1, ): Promise { // Expect the clients to reach the mean consensus, for both the mean and secure aggregators const contributions = List.of( [0.001, 3, 40, 10], [0.002, 5, 30, 11], - [0.003, 13, 11, 12] - ) + [0.003, 13, 11, 12], + ); const actual = await Promise.all( contributions .map(async (w) => await simulateClient(url, aggregatorType, w, rounds)) .toArray(), ); - const consensuses = await Promise.all(actual.map(async ([consensus, client]) => { - // Disconnect clients once they reached consensus - await client.disconnect() - return consensus - })); + const consensuses = await Promise.all( + actual.map(async ([consensus, client]) => { + // Disconnect clients once they reached consensus + await client.disconnect(); + return consensus; + }), + ); const consensus = consensuses[0]; await Promise.all( @@ -135,43 +143,43 @@ describe("end-to-end decentralized", { timeout: 50_000 }, () => { }); it("peers emit expected events", { timeout: 100_000 }, async () => { - const baseTask = await defaultTasks.lusCovid.getTask(); - const task: Task<"image", "decentralized"> = { - ...baseTask, - trainingInformation: { - ...baseTask.trainingInformation, - scheme: "decentralized", - aggregationStrategy: "mean", - roundDuration: 1, - minNbOfParticipants: 2, - }, - }; - const url = await startServer({ - ...defaultTasks.lusCovid, - getTask: () => Promise.resolve(task), - }); - const dataset = await datasets.loadLusCOVID(); + const baseTask = await defaultTasks.lusCovid.getTask(); + const task: Task<"image", "decentralized"> = { + ...baseTask, + trainingInformation: { + ...baseTask.trainingInformation, + scheme: "decentralized", + aggregationStrategy: "mean", + roundDuration: 1, + minNbOfParticipants: 2, + }, + }; + const url = await startServer({ + ...defaultTasks.lusCovid, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadLusCOVID(); /** * Then at each round (each call to `disco.trainByRound`) the event cycle is: - * a) During onRoundBeingCommunication, + * a) During onRoundBeingCommunication, * 1. the peer notifies the server that they want to join the next round * 2. finishes by updating the status to "local training" * (without waiting for a server answer) * b) local training (the status remains "local training") - * c) During onRoundEndCommunication - * 1. the peer notifies the server that they are ready to share weights + * c) During onRoundEndCommunication + * 1. the peer notifies the server that they are ready to share weights * set status to "connecting to peers" * 2. wait for the server to answer with the current round's peers list * this is where the nb of participants is updated - * 3. establish peer-to-peer connections + * 3. establish peer-to-peer connections * 4. set status to "updating model" and exchange weight updates - * + * * Given this, it is important to note that calling disco.trainByRound().next() * for the first time will perform a) and then b) where it stops and yields the round logs. * Thus, c) isn't called and the weight sharing is not performed during this call to next(). * Calling next() again will then run c), as well as a) and b) again. - * + * * In this test the timeline is: * - User 1 joins the task by themselves * - User 2 joins @@ -185,90 +193,98 @@ describe("end-to-end decentralized", { timeout: 50_000 }, () => { const discoUser1 = new Disco(task, url, { preprocessOnce: true }); const statusUser1 = new Queue(); const nbParticipantsUser1 = new Queue(); - discoUser1.on("status", status => { statusUser1.put(status) }) - discoUser1.on("participants", (participants) => { nbParticipantsUser1.put(participants) }) - const generatorUser1 = discoUser1.trainByRound(dataset) - + discoUser1.on("status", (status) => { + statusUser1.put(status); + }); + discoUser1.on("participants", (participants) => { + nbParticipantsUser1.put(participants); + }); + const generatorUser1 = discoUser1.trainByRound(dataset); + // Have User 1 join the task and train locally for one round - const logUser1Round1 = await generatorUser1.next() - expect(logUser1Round1.done).to.be.false + const logUser1Round1 = await generatorUser1.next(); + expect(logUser1Round1.done).to.be.false; // User 1 did a) and b) so their status should be Training - expect(await statusUser1.next()).equal("local training") - expect(await nbParticipantsUser1.next()).equal(1) + expect(await statusUser1.next()).equal("local training"); + expect(await nbParticipantsUser1.next()).equal(1); if (logUser1Round1.done) - throw new Error("User 1 finished training at the 1st round") + throw new Error("User 1 finished training at the 1st round"); // participant list not updated yet (updated at step c)) - expect((logUser1Round1.value).participants).equal(1) + expect(logUser1Round1.value.participants).equal(1); // Calling next() a 2nd time makes User 1 go to c) where the peer should // stay stuck awaiting until another participant joins - const logUser1Round2Promise = generatorUser1.next() - expect(await statusUser1.next()).equal("connecting to peers") // tries to connect to peers - expect(await statusUser1.next()).equal("not enough participants") // but has to wait for more participants + const logUser1Round2Promise = generatorUser1.next(); + expect(await statusUser1.next()).equal("connecting to peers"); // tries to connect to peers + expect(await statusUser1.next()).equal("not enough participants"); // but has to wait for more participants /* USER 2 JOINS */ const discoUser2 = new Disco(task, url, { preprocessOnce: true }); const statusUser2 = new Queue(); const nbParticipantsUser2 = new Queue(); - discoUser2.on("status", status => { statusUser2.put(status) }) - discoUser2.on("participants", (participants) => { nbParticipantsUser2.put(participants) }) - const generatorUser2 = discoUser2.trainByRound(dataset) + discoUser2.on("status", (status) => { + statusUser2.put(status); + }); + discoUser2.on("participants", (participants) => { + nbParticipantsUser2.put(participants); + }); + const generatorUser2 = discoUser2.trainByRound(dataset); // Have User 2 join the task and train for one round - const logUser2Round1 = await generatorUser2.next() - expect(logUser2Round1.done).to.be.false + const logUser2Round1 = await generatorUser2.next(); + expect(logUser2Round1.done).to.be.false; if (logUser2Round1.done) - throw new Error("User 2 finished training at the 1st round") + throw new Error("User 2 finished training at the 1st round"); // round payload should contain the number of participants - expect((logUser2Round1.value).participants).equal(2) - expect(await nbParticipantsUser2.next()).equal(2) + expect(logUser2Round1.value.participants).equal(2); + expect(await nbParticipantsUser2.next()).equal(2); // Receive the EnoughParticipants message with the participants - expect(await nbParticipantsUser1.next()).equal(2) + expect(await nbParticipantsUser1.next()).equal(2); // User 2 did a) and b) - expect(await statusUser2.next()).equal("local training") + expect(await statusUser2.next()).equal("local training"); // User 1 is still in c) now waiting for user 2 to be ready to exchange weight updates - expect(await statusUser1.next()).equal("connecting to peers") + expect(await statusUser1.next()).equal("connecting to peers"); /* ROUND 2 */ - // The server should answer with the round's peers list. + // The server should answer with the round's peers list. // Peers then exchange updates and then start training locally with the new weights - const logUser2Round2 = await generatorUser2.next() - const logUser1Round2 = await logUser1Round2Promise // the promise can resolve now - expect(logUser1Round2.done).to.be.false - expect(logUser2Round2.done).to.be.false + const logUser2Round2 = await generatorUser2.next(); + const logUser1Round2 = await logUser1Round2Promise; // the promise can resolve now + expect(logUser1Round2.done).to.be.false; + expect(logUser2Round2.done).to.be.false; if (logUser1Round2.done || logUser2Round2.done) - throw new Error("User 1 or 2 finished training at the 2nd round") + throw new Error("User 1 or 2 finished training at the 2nd round"); // nb of participants should now be updated - expect((logUser1Round2.value).participants).equal(2) - expect((logUser2Round2.value).participants).equal(2) - expect(await nbParticipantsUser2.next()).equal(2) - expect(await nbParticipantsUser1.next()).equal(2) + expect(logUser1Round2.value.participants).equal(2); + expect(logUser2Round2.value.participants).equal(2); + expect(await nbParticipantsUser2.next()).equal(2); + expect(await nbParticipantsUser1.next()).equal(2); // User 1 and 2 did c), a) and b) - expect(await statusUser1.next()).equal("updating model") // second to last - expect(await statusUser1.next()).equal("local training") + expect(await statusUser1.next()).equal("updating model"); // second to last + expect(await statusUser1.next()).equal("local training"); + + expect(await statusUser2.next()).equal("connecting to peers"); // back to connecting when user 1 joins + expect(await statusUser2.next()).equal("updating model"); + expect(await statusUser2.next()).equal("local training"); - expect(await statusUser2.next()).equal("connecting to peers") // back to connecting when user 1 joins - expect(await statusUser2.next()).equal("updating model") - expect(await statusUser2.next()).equal("local training") - /* USER 1 LEAVES */ - await discoUser1.close() + await discoUser1.close(); // Disconnect updates the number of participants - expect(await nbParticipantsUser1.next()).equal(1) + expect(await nbParticipantsUser1.next()).equal(1); // User 2 receives the WaitingForMoreParticipants message - expect(await nbParticipantsUser2.next()).equal(1) + expect(await nbParticipantsUser2.next()).equal(1); // server notifies user 2 to wait - expect(await statusUser2.next()).equal("not enough participants") + expect(await statusUser2.next()).equal("not enough participants"); // Make user 2 go to c) - const logUser2Round3Promise = generatorUser2.next() + const logUser2Round3Promise = generatorUser2.next(); // await new Promise((res, _) => setTimeout(res, statusUpdateTime)) // Wait some time for the status to update // starts c) and waits for user 3 to join - expect(await statusUser2.next()).equal("connecting to peers") - expect(await statusUser2.next()).equal("not enough participants") + expect(await statusUser2.next()).equal("connecting to peers"); + expect(await statusUser2.next()).equal("not enough participants"); /* USER 3 JOINS */ @@ -276,54 +292,58 @@ describe("end-to-end decentralized", { timeout: 50_000 }, () => { const discoUser3 = new Disco(task, url, { preprocessOnce: true }); const statusUser3 = new Queue(); const nbParticipantsUser3 = new Queue(); - discoUser3.on("status", status => { statusUser3.put(status) }) - discoUser3.on("participants", (participants) => { nbParticipantsUser3.put(participants) }) - const generatorUser3 = discoUser3.trainByRound(dataset) + discoUser3.on("status", (status) => { + statusUser3.put(status); + }); + discoUser3.on("participants", (participants) => { + nbParticipantsUser3.put(participants); + }); + const generatorUser3 = discoUser3.trainByRound(dataset); // User 3 joins mid-training and trains one local round - const logUser3Round1 = await generatorUser3.next() - expect(logUser3Round1.done).to.be.false + const logUser3Round1 = await generatorUser3.next(); + expect(logUser3Round1.done).to.be.false; if (logUser3Round1.done) - throw new Error("User 3 finished training at the 1st round") - expect((logUser3Round1.value).participants).equal(2) - expect(await nbParticipantsUser3.next()).equal(2) + throw new Error("User 3 finished training at the 1st round"); + expect(logUser3Round1.value.participants).equal(2); + expect(await nbParticipantsUser3.next()).equal(2); // User 2 receives the EnoughParticipants message // User 2 is still in c) waiting for user 3 to share their local update - expect(await nbParticipantsUser2.next()).equal(2) - + expect(await nbParticipantsUser2.next()).equal(2); + // User 3 did a) and b) - expect(await statusUser3.next()).equal("local training") + expect(await statusUser3.next()).equal("local training"); // User 2 is still in c) waiting for user 3 to be ready to exchange waits - expect(await statusUser2.next()).equal("connecting to peers") - + expect(await statusUser2.next()).equal("connecting to peers"); + /* ROUND 3 */ // User 3 notifies the server that they are ready to exchange waits // then user 2 and 3 exchange weight updates - const logUser3Round3 = await generatorUser3.next() - const logUser2Round3 = await logUser2Round3Promise // the promise can resolve now + const logUser3Round3 = await generatorUser3.next(); + const logUser2Round3 = await logUser2Round3Promise; // the promise can resolve now if (logUser3Round3.done || logUser2Round3.done) - throw new Error("User 1 or 2 finished training at the 3rd round") - - expect(logUser2Round3.value.participants).equal(2) - expect(logUser3Round3.value.participants).equal(2) - expect(await nbParticipantsUser3.next()).equal(2) - expect(await nbParticipantsUser2.next()).equal(2) + throw new Error("User 1 or 2 finished training at the 3rd round"); + + expect(logUser2Round3.value.participants).equal(2); + expect(logUser3Round3.value.participants).equal(2); + expect(await nbParticipantsUser3.next()).equal(2); + expect(await nbParticipantsUser2.next()).equal(2); // both user 2 and 3 did c), a) and are now in b) - expect(await statusUser2.next()).equal("updating model") - expect(await statusUser2.next()).equal("local training") + expect(await statusUser2.next()).equal("updating model"); + expect(await statusUser2.next()).equal("local training"); + + expect(await statusUser3.next()).equal("connecting to peers"); + expect(await statusUser3.next()).equal("updating model"); + expect(await statusUser3.next()).equal("local training"); - expect(await statusUser3.next()).equal("connecting to peers") - expect(await statusUser3.next()).equal("updating model") - expect(await statusUser3.next()).equal("local training") - /* USER 2 AND 3 LEAVE */ - await discoUser2.close() - expect(await statusUser3.next()).equal("not enough participants") - expect(await nbParticipantsUser3.next()).equal(1) + await discoUser2.close(); + expect(await statusUser3.next()).equal("not enough participants"); + expect(await nbParticipantsUser3.next()).equal(1); - await discoUser3.close() + await discoUser3.close(); }); -}) +}); diff --git a/server/tests/e2e/federated.spec.ts b/server/tests/e2e/federated.spec.ts index 20ba132e2..acd076555 100644 --- a/server/tests/e2e/federated.spec.ts +++ b/server/tests/e2e/federated.spec.ts @@ -1,13 +1,13 @@ import type * as http from "node:http"; import type { - DataFormat, - DataType, - Dataset, - EpochLogs, - RoundStatus, - Task, - TaskProvider, - WeightsContainer, + DataFormat, + DataType, + Dataset, + EpochLogs, + RoundStatus, + Task, + TaskProvider, + WeightsContainer, } from "@epfml/discojs"; import { Disco, defaultTasks } from "@epfml/discojs"; import { List } from "immutable"; @@ -17,340 +17,343 @@ import { Queue, datasets } from "../utils.js"; // Array.fromAsync not yet widely used (2024) async function arrayFromAsync(iter: AsyncIterable): Promise { - const ret: T[] = []; - for await (const e of iter) { - // TODO trick to allow other Promises to run - // else one client might progress alone without communicating with others - // will be fixed when client orchestrations in the server is correctly done - await new Promise((resolve) => setTimeout(resolve, 10)); - - ret.push(e); - } - return ret; + const ret: T[] = []; + for await (const e of iter) { + // TODO trick to allow other Promises to run + // else one client might progress alone without communicating with others + // will be fixed when client orchestrations in the server is correctly done + await new Promise((resolve) => setTimeout(resolve, 10)); + + ret.push(e); + } + return ret; } describe("end-to-end federated", () => { - let handle: http.Server | undefined; - async function startServer( - task: TaskProvider, - ): Promise { - const server = await Server.with(task); - - let url: URL; - [handle, url] = await server.serve(); - return url; - } - afterEach( - () => - new Promise((resolve, reject) => - handle?.close((e) => { - if (e !== undefined) reject(e); - else resolve(); - handle = undefined; - }), - ), - ); - - async function runUser( - url: URL, - task: Task, - dataset: Dataset, - preprocessOnce = true, - ): Promise<[WeightsContainer, EpochLogs]> { - const disco = new Disco(task, url, { preprocessOnce }); - - const logs = List(await arrayFromAsync(disco.trainByRound(dataset))); - await disco.close(); - - expect(logs.first()?.epochs.first()?.training.loss).to.be.above( - logs.last()?.epochs.last()?.training.loss as number, - ); - - const lastEpoch = logs.last()?.epochs.last(); - if (lastEpoch === undefined) throw new Error("no epoch ran"); - return [disco.trainer.model.weights, lastEpoch]; - } - - it("three cifar10 users reach consensus", { timeout: 200_000 }, async () => { - const task = await defaultTasks.cifar10.getTask(); - const cifar10Task: Task<"image", "federated"> = { - ...task, - trainingInformation: { - ...task.trainingInformation, - scheme: "federated", - aggregationStrategy: "mean", - minNbOfParticipants: 2, - }, - }; - const url = await startServer({ - getModel: () => defaultTasks.cifar10.getModel(), - getTask: () => Promise.resolve(cifar10Task), - }); - const dataset = await datasets.loadCifar10(); - - const [[m1, l1], [m2, l2], [m3, l3]] = await Promise.all([ - runUser(url, cifar10Task, dataset), - runUser(url, cifar10Task, dataset), - runUser(url, cifar10Task, dataset), - ]); - - for (const lastEpoch of [l1, l2, l3]) { - expect(lastEpoch.training.accuracy).to.be.greaterThan(0.4); - expect(lastEpoch.validation?.accuracy).to.be.greaterThan(0.4); - } - assert.isTrue(m1.equals(m2) && m2.equals(m3)); - }); - - it("two titanic users reach consensus", { timeout: 50_000 }, async () => { - const task = await defaultTasks.titanic.getTask(); - task.trainingInformation = { - ...task.trainingInformation, - minNbOfParticipants: 2, - }; - const url = await startServer({ - ...defaultTasks.titanic, - getTask: () => Promise.resolve(task), - }); - const dataset = datasets.loadTitanic(); - - const [[m1, l1], [m2, l2]] = await Promise.all([ - runUser(url, task, dataset), - runUser(url, task, dataset), - ]); - - for (const lastEpoch of [l1, l2]) { - expect(lastEpoch.training.accuracy).to.be.greaterThan(0.4); - expect(lastEpoch.validation?.accuracy).to.be.greaterThan(0.4); - } - assert.isTrue(m1.equals(m2)); - }); - - it("two lus_covid users reach consensus", { timeout: 200_000 }, async () => { - const task = await defaultTasks.lusCovid.getTask(); - task.trainingInformation = { - ...task.trainingInformation, - epochs: 16, - roundDuration: 2, - minNbOfParticipants: 2, - }; - const url = await startServer({ - ...defaultTasks.lusCovid, - getTask: () => Promise.resolve(task), - }); - const dataset = await datasets.loadLusCOVID(); - - const [[m1, l1], [m2, l2]] = await Promise.all([ - runUser(url, task, dataset), - runUser(url, task, dataset), - ]); - - for (const lastEpoch of [l1, l2]) { - expect(lastEpoch.training.accuracy).to.be.greaterThan(0.4); - expect(lastEpoch.validation?.accuracy).to.be.greaterThan(0.4); - } - assert.isTrue(m1.equals(m2)); - }); - - it("two wikitext reach consensus", { timeout: 500_000 }, async () => { - const task = await defaultTasks.wikitext.getTask(); - task.trainingInformation = { - ...task.trainingInformation, - epochs: 2, - roundDuration: 2, - minNbOfParticipants: 2, - }; - const url = await startServer({ - ...defaultTasks.wikitext, - getTask: () => Promise.resolve(task), - }); - const dataset = datasets.loadWikitext(); - - const [r1, r2] = await Promise.all([ - runUser(url, task, dataset, false), - runUser(url, task, dataset, false), - ]); - assert.isTrue(r1[0].equals(r2[0])); - }); - - it("clients emit expected events", { timeout: 100_000 }, async () => { - const task = await defaultTasks.lusCovid.getTask(); - task.trainingInformation = { - ...task.trainingInformation, - roundDuration: 1, - minNbOfParticipants: 2, - }; - const url = await startServer({ - ...defaultTasks.lusCovid, - getTask: () => Promise.resolve(task), - }); - const dataset = await datasets.loadLusCOVID(); - - /** - * When disco.trainByRound is called for the first time, the client connects to the server - * which returns the latest model, current round and nb of participants. - * Then at each round the event cycle is: - * a) onRoundBeingCommunication which updates the status to "local training" - * b) local training (the status remains "local training") - * c) onRoundEndCommunication which sends the local update and - * receives the global weights while emitting the status UPDATE - * - * Given this, it is important to note that calling disco.trainByRound().next() - * for the first time will perform a) and then b) where it stops and yields the round logs. - * Thus, c) isn't done and the model aggregation by the server is not performed during this first call to next(). - * - * Calling next() again will then do c), and back to a) and b). - * - * In this test the timeline is: - * - User 1 joins the task by themselves - * - User 2 joins - * - User 1 leaves - * - User 3 joins - * - User 2 & 3 leave - */ - - // Create User 1 - const discoUser1 = new Disco(task, url, { preprocessOnce: true }); - const statusUser1 = new Queue(); - const nbParticipantsUser1 = new Queue(); - discoUser1.on("status", (status) => statusUser1.put(status)); - discoUser1.on("participants", (participants) => - nbParticipantsUser1.put(participants), - ); - const generatorUser1 = discoUser1.trainByRound(dataset); - - // Have User 1 join the task and train locally for one round - await generatorUser1.next(); - expect(await statusUser1.next()).equal("local training"); - expect(await nbParticipantsUser1.next()).equal(1); - - // Calling next() a 2nd time makes User 1 go to c) where the client should - // stay stuck awaiting until another participant joins - const logUser1Round2Promise = generatorUser1.next(); - expect(await statusUser1.next()).equal("not enough participants"); - - // Create User 2 - const discoUser2 = new Disco(task, url, { preprocessOnce: true }); - const statusUser2 = new Queue(); - const nbParticipantsUser2 = new Queue(); - discoUser2.on("status", (status) => statusUser2.put(status)); - discoUser2.on("participants", (participants) => - nbParticipantsUser2.put(participants), - ); - const generatorUser2 = discoUser2.trainByRound(dataset); - - // Have User 2 join the task and train for one round - await generatorUser2.next(); - // User 2 did a) and b) - expect(await statusUser1.next()).equal("local training"); - expect(await statusUser2.next()).equal("local training"); - // User 1 is still in c) now waiting for user 2 to share their local update - // and for the server to aggregate the local updates - expect(await statusUser1.next()).equal("updating model"); - // User 2 connects to the server which triggers the participant event - expect(await nbParticipantsUser2.next()).equal(2); - // Receive the EnoughParticipants message with the participants - expect(await nbParticipantsUser1.next()).equal(2); - - // Proceed with round 2 - - // the server should answer with the new global weights - // and users should train locally on the new weights - await Promise.all([logUser1Round2Promise, generatorUser2.next()]); - // User 1 and 2 did c), a) and b) - expect(await statusUser2.next()).equal("updating model"); - expect(await statusUser1.next()).equal("local training"); - expect(await statusUser2.next()).equal("local training"); - // Receive the server payload during c) along with the participants - expect(await nbParticipantsUser2.next()).equal(2); - expect(await nbParticipantsUser1.next()).equal(2); - - // Make user 2 go to c) - const logUser2Round3Promise = generatorUser2.next(); - expect(await statusUser2.next()).equal("updating model"); - - // Have user 1 quit the session - await discoUser1.close(); - // User 2 receives the WaitingForMoreParticipants message - expect(await statusUser2.next()).equal("not enough participants"); - expect(await nbParticipantsUser2.next()).equal(1); - - // Create User 3 - const discoUser3 = new Disco(task, url, { preprocessOnce: true }); - const statusUser3 = new Queue(); - const nbParticipantsUser3 = new Queue(); - discoUser3.on("status", (status) => statusUser3.put(status)); - discoUser3.on("participants", (participants) => - nbParticipantsUser3.put(participants), - ); - const generatorUser3 = discoUser3.trainByRound(dataset); - - // User 3 joins mid-training and trains one local round - await generatorUser3.next(); - expect(await statusUser3.next()).equal("local training"); - expect(await nbParticipantsUser3.next()).equal(2); - - // User 2 is still in c) waiting for user 3 to share their local update - // and for the server to aggregate the local updates - expect(await statusUser2.next()).equal("updating model"); - // User 2 receives the EnoughParticipants message - expect(await nbParticipantsUser2.next()).equal(2); - - // User 3 sends their weights to the server - await Promise.all([logUser2Round3Promise, generatorUser3.next()]); - expect(await statusUser3.next()).equal("updating model"); - - // the server should accept user 3's weights (should not be outdated) and aggregate the global weights - // both user 2 and 3 did c), a) and are now in b) - expect(await statusUser2.next()).equal("local training"); - expect(await statusUser3.next()).equal("local training"); - // User 2 and 3 finish c) - expect(await nbParticipantsUser3.next()).equal(2); - expect(await nbParticipantsUser2.next()).equal(2); - - await discoUser2.close(); - expect(await statusUser3.next()).equal("not enough participants"); - // WaitForMoreParticipants message - expect(await nbParticipantsUser3.next()).equal(1); - - await discoUser3.close(); - }); - - /** - * Test if federated learning task lus_covid operates correctly with differential privacy - */ - it("three lus_covid clients meet consensus with differential privacy", { timeout: 1_000_000 }, async () => { - const task = await defaultTasks.lusCovid.getTask(); - task.trainingInformation = { - ...task.trainingInformation, - epochs: 20, - roundDuration: 10, - minNbOfParticipants: 3, - aggregationStrategy: "mean", - privacy: { - differentialPrivacy: { - epsilon: 50, - delta: 1e-5, - clippingRadius: 10, - } - } - }; - const url = await startServer({ - ...defaultTasks.lusCovid, - getTask: () => Promise.resolve(task), - }); - const dataset = await datasets.loadLusCOVID(); - - const [[m1, l1], [m2, l2], [m3, l3]] = await Promise.all([ - runUser(url, task, dataset), - runUser(url, task, dataset), - runUser(url, task, dataset), - ]); - - for (const lastEpoch of [l1, l2, l3]) { - expect(lastEpoch.training.accuracy).to.be.greaterThan(0.4); - expect(lastEpoch.validation?.accuracy).to.be.greaterThan(0.4); - } - assert.isTrue(m1.equals(m2) && m2.equals(m3)); - }) + let handle: http.Server | undefined; + async function startServer( + task: TaskProvider, + ): Promise { + const server = await Server.with(task); + + let url: URL; + [handle, url] = await server.serve(); + return url; + } + afterEach( + () => + new Promise((resolve, reject) => + handle?.close((e) => { + if (e !== undefined) reject(e); + else resolve(); + handle = undefined; + }), + ), + ); + + async function runUser( + url: URL, + task: Task, + dataset: Dataset, + preprocessOnce = true, + ): Promise<[WeightsContainer, EpochLogs]> { + const disco = new Disco(task, url, { preprocessOnce }); + + const logs = List(await arrayFromAsync(disco.trainByRound(dataset))); + await disco.close(); + + expect(logs.first()?.epochs.first()?.training.loss).to.be.above( + logs.last()?.epochs.last()?.training.loss as number, + ); + + const lastEpoch = logs.last()?.epochs.last(); + if (lastEpoch === undefined) throw new Error("no epoch ran"); + return [disco.trainer.model.weights, lastEpoch]; + } + + it("three cifar10 users reach consensus", { timeout: 200_000 }, async () => { + const task = await defaultTasks.cifar10.getTask(); + const cifar10Task: Task<"image", "federated"> = { + ...task, + trainingInformation: { + ...task.trainingInformation, + scheme: "federated", + aggregationStrategy: "mean", + minNbOfParticipants: 2, + }, + }; + const url = await startServer({ + getModel: () => defaultTasks.cifar10.getModel(), + getTask: () => Promise.resolve(cifar10Task), + }); + const dataset = await datasets.loadCifar10(); + + const [[m1, l1], [m2, l2], [m3, l3]] = await Promise.all([ + runUser(url, cifar10Task, dataset), + runUser(url, cifar10Task, dataset), + runUser(url, cifar10Task, dataset), + ]); + + for (const lastEpoch of [l1, l2, l3]) { + expect(lastEpoch.training.accuracy).to.be.greaterThan(0.4); + expect(lastEpoch.validation?.accuracy).to.be.greaterThan(0.4); + } + assert.isTrue(m1.equals(m2) && m2.equals(m3)); + }); + + it("two titanic users reach consensus", { timeout: 50_000 }, async () => { + const task = await defaultTasks.titanic.getTask(); + task.trainingInformation = { + ...task.trainingInformation, + minNbOfParticipants: 2, + }; + const url = await startServer({ + ...defaultTasks.titanic, + getTask: () => Promise.resolve(task), + }); + const dataset = datasets.loadTitanic(); + + const [[m1, l1], [m2, l2]] = await Promise.all([ + runUser(url, task, dataset), + runUser(url, task, dataset), + ]); + + for (const lastEpoch of [l1, l2]) { + expect(lastEpoch.training.accuracy).to.be.greaterThan(0.4); + expect(lastEpoch.validation?.accuracy).to.be.greaterThan(0.4); + } + assert.isTrue(m1.equals(m2)); + }); + + it("two lus_covid users reach consensus", { timeout: 200_000 }, async () => { + const task = await defaultTasks.lusCovid.getTask(); + task.trainingInformation = { + ...task.trainingInformation, + epochs: 16, + roundDuration: 2, + minNbOfParticipants: 2, + }; + const url = await startServer({ + ...defaultTasks.lusCovid, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadLusCOVID(); + + const [[m1, l1], [m2, l2]] = await Promise.all([ + runUser(url, task, dataset), + runUser(url, task, dataset), + ]); + + for (const lastEpoch of [l1, l2]) { + expect(lastEpoch.training.accuracy).to.be.greaterThan(0.4); + expect(lastEpoch.validation?.accuracy).to.be.greaterThan(0.4); + } + assert.isTrue(m1.equals(m2)); + }); + + it("two wikitext reach consensus", { timeout: 500_000 }, async () => { + const task = await defaultTasks.wikitext.getTask(); + task.trainingInformation = { + ...task.trainingInformation, + epochs: 2, + roundDuration: 2, + minNbOfParticipants: 2, + }; + const url = await startServer({ + ...defaultTasks.wikitext, + getTask: () => Promise.resolve(task), + }); + const dataset = datasets.loadWikitext(); + + const [r1, r2] = await Promise.all([ + runUser(url, task, dataset, false), + runUser(url, task, dataset, false), + ]); + assert.isTrue(r1[0].equals(r2[0])); + }); + + it("clients emit expected events", { timeout: 100_000 }, async () => { + const task = await defaultTasks.lusCovid.getTask(); + task.trainingInformation = { + ...task.trainingInformation, + roundDuration: 1, + minNbOfParticipants: 2, + }; + const url = await startServer({ + ...defaultTasks.lusCovid, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadLusCOVID(); + + /** + * When disco.trainByRound is called for the first time, the client connects to the server + * which returns the latest model, current round and nb of participants. + * Then at each round the event cycle is: + * a) onRoundBeingCommunication which updates the status to "local training" + * b) local training (the status remains "local training") + * c) onRoundEndCommunication which sends the local update and + * receives the global weights while emitting the status UPDATE + * + * Given this, it is important to note that calling disco.trainByRound().next() + * for the first time will perform a) and then b) where it stops and yields the round logs. + * Thus, c) isn't done and the model aggregation by the server is not performed during this first call to next(). + * + * Calling next() again will then do c), and back to a) and b). + * + * In this test the timeline is: + * - User 1 joins the task by themselves + * - User 2 joins + * - User 1 leaves + * - User 3 joins + * - User 2 & 3 leave + */ + + // Create User 1 + const discoUser1 = new Disco(task, url, { preprocessOnce: true }); + const statusUser1 = new Queue(); + const nbParticipantsUser1 = new Queue(); + discoUser1.on("status", (status) => statusUser1.put(status)); + discoUser1.on("participants", (participants) => + nbParticipantsUser1.put(participants), + ); + const generatorUser1 = discoUser1.trainByRound(dataset); + + // Have User 1 join the task and train locally for one round + await generatorUser1.next(); + expect(await statusUser1.next()).equal("local training"); + expect(await nbParticipantsUser1.next()).equal(1); + + // Calling next() a 2nd time makes User 1 go to c) where the client should + // stay stuck awaiting until another participant joins + const logUser1Round2Promise = generatorUser1.next(); + expect(await statusUser1.next()).equal("not enough participants"); + + // Create User 2 + const discoUser2 = new Disco(task, url, { preprocessOnce: true }); + const statusUser2 = new Queue(); + const nbParticipantsUser2 = new Queue(); + discoUser2.on("status", (status) => statusUser2.put(status)); + discoUser2.on("participants", (participants) => + nbParticipantsUser2.put(participants), + ); + const generatorUser2 = discoUser2.trainByRound(dataset); + + // Have User 2 join the task and train for one round + await generatorUser2.next(); + // User 2 did a) and b) + expect(await statusUser1.next()).equal("local training"); + expect(await statusUser2.next()).equal("local training"); + // User 1 is still in c) now waiting for user 2 to share their local update + // and for the server to aggregate the local updates + expect(await statusUser1.next()).equal("updating model"); + // User 2 connects to the server which triggers the participant event + expect(await nbParticipantsUser2.next()).equal(2); + // Receive the EnoughParticipants message with the participants + expect(await nbParticipantsUser1.next()).equal(2); + + // Proceed with round 2 + + // the server should answer with the new global weights + // and users should train locally on the new weights + await Promise.all([logUser1Round2Promise, generatorUser2.next()]); + // User 1 and 2 did c), a) and b) + expect(await statusUser2.next()).equal("updating model"); + expect(await statusUser1.next()).equal("local training"); + expect(await statusUser2.next()).equal("local training"); + // Receive the server payload during c) along with the participants + expect(await nbParticipantsUser2.next()).equal(2); + expect(await nbParticipantsUser1.next()).equal(2); + + // Make user 2 go to c) + const logUser2Round3Promise = generatorUser2.next(); + expect(await statusUser2.next()).equal("updating model"); + + // Have user 1 quit the session + await discoUser1.close(); + // User 2 receives the WaitingForMoreParticipants message + expect(await statusUser2.next()).equal("not enough participants"); + expect(await nbParticipantsUser2.next()).equal(1); + + // Create User 3 + const discoUser3 = new Disco(task, url, { preprocessOnce: true }); + const statusUser3 = new Queue(); + const nbParticipantsUser3 = new Queue(); + discoUser3.on("status", (status) => statusUser3.put(status)); + discoUser3.on("participants", (participants) => + nbParticipantsUser3.put(participants), + ); + const generatorUser3 = discoUser3.trainByRound(dataset); + + // User 3 joins mid-training and trains one local round + await generatorUser3.next(); + expect(await statusUser3.next()).equal("local training"); + expect(await nbParticipantsUser3.next()).equal(2); + + // User 2 is still in c) waiting for user 3 to share their local update + // and for the server to aggregate the local updates + expect(await statusUser2.next()).equal("updating model"); + // User 2 receives the EnoughParticipants message + expect(await nbParticipantsUser2.next()).equal(2); + + // User 3 sends their weights to the server + await Promise.all([logUser2Round3Promise, generatorUser3.next()]); + expect(await statusUser3.next()).equal("updating model"); + + // the server should accept user 3's weights (should not be outdated) and aggregate the global weights + // both user 2 and 3 did c), a) and are now in b) + expect(await statusUser2.next()).equal("local training"); + expect(await statusUser3.next()).equal("local training"); + // User 2 and 3 finish c) + expect(await nbParticipantsUser3.next()).equal(2); + expect(await nbParticipantsUser2.next()).equal(2); + + await discoUser2.close(); + expect(await statusUser3.next()).equal("not enough participants"); + // WaitForMoreParticipants message + expect(await nbParticipantsUser3.next()).equal(1); + + await discoUser3.close(); + }); + + /** + * Test if federated learning task lus_covid operates correctly with differential privacy + */ + it( + "three lus_covid clients meet consensus with differential privacy", + { timeout: 1_000_000 }, + async () => { + const task = await defaultTasks.lusCovid.getTask(); + task.trainingInformation = { + ...task.trainingInformation, + epochs: 20, + roundDuration: 10, + minNbOfParticipants: 3, + aggregationStrategy: "mean", + privacy: { + differentialPrivacy: { + epsilon: 50, + delta: 1e-5, + clippingRadius: 10, + }, + }, + }; + const url = await startServer({ + ...defaultTasks.lusCovid, + getTask: () => Promise.resolve(task), + }); + const dataset = await datasets.loadLusCOVID(); + + const [[m1, l1], [m2, l2], [m3, l3]] = await Promise.all([ + runUser(url, task, dataset), + runUser(url, task, dataset), + runUser(url, task, dataset), + ]); + + for (const lastEpoch of [l1, l2, l3]) { + expect(lastEpoch.training.accuracy).to.be.greaterThan(0.4); + expect(lastEpoch.validation?.accuracy).to.be.greaterThan(0.4); + } + assert.isTrue(m1.equals(m2) && m2.equals(m3)); + }, + ); }); - diff --git a/server/tests/utils.ts b/server/tests/utils.ts index 651ecb283..be942ecf4 100644 --- a/server/tests/utils.ts +++ b/server/tests/utils.ts @@ -4,67 +4,67 @@ import { loadCSV, loadImagesInDir, loadText } from "@epfml/discojs-node"; const DATASET_DIR = path.join(__dirname, "..", "..", "datasets"); export const datasets = { - async loadCifar10() { - // TODO single label means model can't be wrong - return (await loadImagesInDir(path.join(DATASET_DIR, "CIFAR10"))).zip( - Repeat("cat"), - ); - }, - async loadLusCOVID() { - const [positive, negative] = [ - ( - await loadImagesInDir(path.join(DATASET_DIR, "lus_covid", "COVID+")) - ).zip(Repeat("COVID-Positive")), - ( - await loadImagesInDir(path.join(DATASET_DIR, "lus_covid", "COVID-")) - ).zip(Repeat("COVID-Negative")), - ]; - return positive.chain(negative); - }, - async loadSimpleFace() { - const [adult, child] = [ - ( - await loadImagesInDir(path.join(DATASET_DIR, "simple_face", "adult")) - ).zip(Repeat("adult")), - ( - await loadImagesInDir(path.join(DATASET_DIR, "simple_face", "child")) - ).zip(Repeat("child")), - ]; - return adult.chain(child); - }, - loadTitanic: () => loadCSV(path.join(DATASET_DIR, "titanic_train.csv")), - loadWikitext: () => - loadText(path.join(DATASET_DIR, "wikitext", "wiki.train.tokens")).chain( - loadText(path.join(DATASET_DIR, "wikitext", "wiki.valid.tokens")), - ), + async loadCifar10() { + // TODO single label means model can't be wrong + return (await loadImagesInDir(path.join(DATASET_DIR, "CIFAR10"))).zip( + Repeat("cat"), + ); + }, + async loadLusCOVID() { + const [positive, negative] = [ + ( + await loadImagesInDir(path.join(DATASET_DIR, "lus_covid", "COVID+")) + ).zip(Repeat("COVID-Positive")), + ( + await loadImagesInDir(path.join(DATASET_DIR, "lus_covid", "COVID-")) + ).zip(Repeat("COVID-Negative")), + ]; + return positive.chain(negative); + }, + async loadSimpleFace() { + const [adult, child] = [ + ( + await loadImagesInDir(path.join(DATASET_DIR, "simple_face", "adult")) + ).zip(Repeat("adult")), + ( + await loadImagesInDir(path.join(DATASET_DIR, "simple_face", "child")) + ).zip(Repeat("child")), + ]; + return adult.chain(child); + }, + loadTitanic: () => loadCSV(path.join(DATASET_DIR, "titanic_train.csv")), + loadWikitext: () => + loadText(path.join(DATASET_DIR, "wikitext", "wiki.train.tokens")).chain( + loadText(path.join(DATASET_DIR, "wikitext", "wiki.valid.tokens")), + ), }; export class Queue { - #content = List<[index: number, T]>(); - // keep track of what was added and asked for - #index = { head: 0, tail: 0 }; + #content = List<[index: number, T]>(); + // keep track of what was added and asked for + #index = { head: 0, tail: 0 }; - put(e: T) { - this.#content = this.#content.push([this.#index.tail, e]); - this.#index.tail++; - } + put(e: T) { + this.#content = this.#content.push([this.#index.tail, e]); + this.#index.tail++; + } - async next(): Promise { - const index = this.#index.head; - this.#index.head++; + async next(): Promise { + const index = this.#index.head; + this.#index.head++; - for (;;) { - const ret = this.#content.first(); - if (ret !== undefined && ret[0] > index) - throw new Error("assertion failed: head's index bigger than ours"); + for (;;) { + const ret = this.#content.first(); + if (ret !== undefined && ret[0] > index) + throw new Error("assertion failed: head's index bigger than ours"); - // check that it is intended for us - if (ret?.[0] === index) { - this.#content = this.#content.shift(); - return ret[1]; - } + // check that it is intended for us + if (ret?.[0] === index) { + this.#content = this.#content.shift(); + return ret[1]; + } - await new Promise((resolve) => setTimeout(resolve, 10)); - } - } + await new Promise((resolve) => setTimeout(resolve, 10)); + } + } } diff --git a/server/tests/validator.spec.ts b/server/tests/validator.spec.ts index 1275fcddd..8f3ebebc3 100644 --- a/server/tests/validator.spec.ts +++ b/server/tests/validator.spec.ts @@ -3,65 +3,77 @@ import { describe, expect, it } from "vitest"; import { datasets } from "./utils.js"; describe("validator", () => { - it("can read and predict randomly on simple_face", { timeout: 20_000 }, async () => { - const provider = defaultTasks.simpleFace; - const dataset = await datasets.loadSimpleFace(); + it( + "can read and predict randomly on simple_face", + { timeout: 20_000 }, + async () => { + const provider = defaultTasks.simpleFace; + const dataset = await datasets.loadSimpleFace(); - const validator = new Validator( - await provider.getTask(), - await provider.getModel(), - ); + const validator = new Validator( + await provider.getTask(), + await provider.getModel(), + ); - let hits = 0; - let size = 0; - for await (const correct of validator.test(dataset)) { - if (correct) hits++; - size++; - } + let hits = 0; + let size = 0; + for await (const correct of validator.test(dataset)) { + if (correct) hits++; + size++; + } - expect(hits / size).to.be.greaterThan(0.3); - }); + expect(hits / size).to.be.greaterThan(0.3); + }, + ); - it("can read and predict randomly on titanic", { timeout: 10_000 }, async () => { - const provider = defaultTasks.titanic; - const dataset = datasets.loadTitanic(); + it( + "can read and predict randomly on titanic", + { timeout: 10_000 }, + async () => { + const provider = defaultTasks.titanic; + const dataset = datasets.loadTitanic(); - const validator = new Validator( - await provider.getTask(), - await provider.getModel(), - ); + const validator = new Validator( + await provider.getTask(), + await provider.getModel(), + ); - let hits = 0; - let size = 0; - for await (const correct of validator.test(dataset)) { - if (correct) hits++; - size++; - } + let hits = 0; + let size = 0; + for await (const correct of validator.test(dataset)) { + if (correct) hits++; + size++; + } - expect(hits / size).to.be.greaterThan(0.3); - }); + expect(hits / size).to.be.greaterThan(0.3); + }, + ); - it("can read and predict randomly on lus_covid", { timeout: 50_000 }, async () => { - const task = await defaultTasks.lusCovid.getTask(); - task.trainingInformation = { - ...task.trainingInformation, - roundDuration: 2, - minNbOfParticipants: 2, - }; - const dataset = await datasets.loadLusCOVID(); + it( + "can read and predict randomly on lus_covid", + { timeout: 50_000 }, + async () => { + const task = await defaultTasks.lusCovid.getTask(); + task.trainingInformation = { + ...task.trainingInformation, + roundDuration: 2, + minNbOfParticipants: 2, + }; + const dataset = await datasets.loadLusCOVID(); - const validator = new Validator( - task, - await defaultTasks.lusCovid.getModel(), - ); + const validator = new Validator( + task, + await defaultTasks.lusCovid.getModel(), + ); - let hits = 0; - let size = 0; - for await (const correct of validator.test(dataset)) { - if (correct) hits++; - size++; - } + let hits = 0; + let size = 0; + for await (const correct of validator.test(dataset)) { + if (correct) hits++; + size++; + } - expect(hits / size).to.be.greaterThan(0.3); - }); + expect(hits / size).to.be.greaterThan(0.3); + }, + ); }); diff --git a/server/tsconfig.json b/server/tsconfig.json index 5cba070bf..6a234b504 100644 --- a/server/tsconfig.json +++ b/server/tsconfig.json @@ -17,4 +17,4 @@ "path": "tsconfig.vitest.json" } ] -} \ No newline at end of file +} diff --git a/server/tsconfig.lib.json b/server/tsconfig.lib.json index 806611d85..9bea26418 100644 --- a/server/tsconfig.lib.json +++ b/server/tsconfig.lib.json @@ -1,5 +1,5 @@ { - "extends": "../tsconfig.base.json", - "compilerOptions": { "outDir": "dist" }, - "include": ["src"] + "extends": "../tsconfig.base.json", + "compilerOptions": { "outDir": "dist" }, + "include": ["src"] } diff --git a/server/tsconfig.vitest.json b/server/tsconfig.vitest.json index baf86a82e..135cacb84 100644 --- a/server/tsconfig.vitest.json +++ b/server/tsconfig.vitest.json @@ -1,5 +1,5 @@ { - "extends": "../tsconfig.base.json", - "compilerOptions": { "noEmit": true }, - "include": ["src", "tests"] + "extends": "../tsconfig.base.json", + "compilerOptions": { "noEmit": true }, + "include": ["src", "tests"] } diff --git a/tsconfig.base.json b/tsconfig.base.json index ccf837eb0..a43c63956 100644 --- a/tsconfig.base.json +++ b/tsconfig.base.json @@ -19,4 +19,3 @@ "noImplicitOverride": true } } - diff --git a/vitest.config.ts b/vitest.config.ts index bec1b211c..536c237da 100644 --- a/vitest.config.ts +++ b/vitest.config.ts @@ -1,30 +1,30 @@ import { defineConfig } from "vitest/config"; export default defineConfig({ - test: { - setupFiles: "./testSetupImportTFJSNode.ts", + test: { + setupFiles: "./testSetupImportTFJSNode.ts", - projects: [ - { - extends: true, - test: { name: "discojs", include: ["discojs/**/*.spec.ts"] }, - }, - { - extends: true, - test: { name: "discojs-node", include: ["discojs-node/**/*.spec.ts"] }, - }, - { - extends: true, - test: { - name: "discojs-web", - include: ["discojs-web/**/*.spec.ts"], - environment: "jsdom", - }, - }, - { - extends: true, - test: { name: "server", include: ["server/tests/**/*.spec.ts"] }, - }, - ], - }, + projects: [ + { + extends: true, + test: { name: "discojs", include: ["discojs/**/*.spec.ts"] }, + }, + { + extends: true, + test: { name: "discojs-node", include: ["discojs-node/**/*.spec.ts"] }, + }, + { + extends: true, + test: { + name: "discojs-web", + include: ["discojs-web/**/*.spec.ts"], + environment: "jsdom", + }, + }, + { + extends: true, + test: { name: "server", include: ["server/tests/**/*.spec.ts"] }, + }, + ], + }, }); diff --git a/webapp/cypress/e2e/store/models.cy.ts b/webapp/cypress/e2e/store/models.cy.ts index 67cda1fab..3c7e27e4c 100644 --- a/webapp/cypress/e2e/store/models.cy.ts +++ b/webapp/cypress/e2e/store/models.cy.ts @@ -2,45 +2,49 @@ import { defaultTasks } from "@epfml/discojs"; import { setupServerWith } from "../../support/e2e"; beforeEach(() => - cy.wrap(async () => { - const root = await navigator.storage.getDirectory(); - try { - await root.removeEntry("models", { recursive: true }); - } catch (e) { - if (e instanceof DOMException && e.name === "NotFoundError") return; - throw e; - } - }), + cy.wrap(async () => { + const root = await navigator.storage.getDirectory(); + try { + await root.removeEntry("models", { recursive: true }); + } catch (e) { + if (e instanceof DOMException && e.name === "NotFoundError") return; + throw e; + } + }), ); -it("stores models", +it( + "stores models", { retries: 5 }, // can exhaust memory () => { - setupServerWith(defaultTasks.titanic); + setupServerWith(defaultTasks.titanic); - cy.visit("/evaluate"); - cy.contains("button", "download").click(); - cy.contains("button", "test").should("exist"); + cy.visit("/evaluate"); + cy.contains("button", "download").click(); + cy.contains("button", "test").should("exist"); - cy.reload(); - cy.contains("button", "test").should("exist"); -}); + cy.reload(); + cy.contains("button", "test").should("exist"); + }, +); -it("stores larger models", +it( + "stores larger models", { retries: 5 }, // can exhaust memory () => { - setupServerWith(defaultTasks.wikitext); + setupServerWith(defaultTasks.wikitext); - cy.visit("/evaluate"); - cy.contains("button", "download").click(); - cy.contains("button", "test") - .should("exist") - .then( - () => - // storage takes time and no user feedback - new Promise((resolve) => setTimeout(resolve, 300)), - ); + cy.visit("/evaluate"); + cy.contains("button", "download").click(); + cy.contains("button", "test") + .should("exist") + .then( + () => + // storage takes time and no user feedback + new Promise((resolve) => setTimeout(resolve, 300)), + ); - cy.reload(); - cy.contains("button", "test").should("exist"); -}); + cy.reload(); + cy.contains("button", "test").should("exist"); + }, +); diff --git a/webapp/cypress/e2e/task-creation.cy.ts b/webapp/cypress/e2e/task-creation.cy.ts index 4ea2c9ecf..309404415 100644 --- a/webapp/cypress/e2e/task-creation.cy.ts +++ b/webapp/cypress/e2e/task-creation.cy.ts @@ -10,7 +10,7 @@ it("submits with tabular task", () => { cy.visit("/create"); - cy.get('form').should('be.visible'); // Wait for the form to be fully loaded + cy.get("form").should("be.visible"); // Wait for the form to be fully loaded cy.get("input[name='id']").type("id"); cy.get("select[name='dataType']").select("tabular"); @@ -88,7 +88,7 @@ it("submits with tabular task", () => { outputColumn: "output", tensorBackend: "tfjs", }, - } satisfies Task<"tabular", "federated">); + } satisfies Task<"tabular", "federated">); }); async function getArtifacts( diff --git a/webapp/cypress/support/e2e.ts b/webapp/cypress/support/e2e.ts index f47cb9f59..34951d125 100644 --- a/webapp/cypress/support/e2e.ts +++ b/webapp/cypress/support/e2e.ts @@ -22,7 +22,7 @@ export function setupServerWith( ) .as("taskAndModels"); - cy.get, unknown]>>("@taskAndModels") + cy.get, unknown]>>("@taskAndModels") .then((taskAndModels) => taskAndModels.map(([t]) => serialization.task.serializeToJSON(t)), ) @@ -67,30 +67,30 @@ type BasicKeys = | "aggregationStrategy"; export function basicTask( - dataType: D, - info: Omit, BasicKeys>, + dataType: D, + info: Omit, BasicKeys>, ): Task { - return { - id: "task", - dataType, - trainingInformation: { - epochs: 1, - batchSize: 1, - roundDuration: 1, - validationSplit: 1, - tensorBackend: "tfjs", - scheme: "local", - aggregationStrategy: "mean", - ...info, - }, - displayInformation: { - title: "task", - summary: { preview: "preview", overview: "overview" }, - }, - // cast as typescript doesn't work well w/ generics - } as Task; + return { + id: "task", + dataType, + trainingInformation: { + epochs: 1, + batchSize: 1, + roundDuration: 1, + validationSplit: 1, + tensorBackend: "tfjs", + scheme: "local", + aggregationStrategy: "mean", + ...info, + }, + displayInformation: { + title: "task", + summary: { preview: "preview", overview: "overview" }, + }, + // cast as typescript doesn't work well w/ generics + } as Task; } before(() => { - localStorage.debug = "discojs*,webapp*"; + localStorage.debug = "discojs*,webapp*"; }); diff --git a/webapp/index.html b/webapp/index.html index 8b7ff734c..ac36c1a97 100644 --- a/webapp/index.html +++ b/webapp/index.html @@ -1,15 +1,21 @@ - + - - - - + + + + Disco
diff --git a/webapp/package.json b/webapp/package.json index 84ff83644..081726f5c 100644 --- a/webapp/package.json +++ b/webapp/package.json @@ -1,50 +1,50 @@ { - "name": "webapp", - "private": true, - "type": "module", - "scripts": { - "start": "vite", - "build": "vue-tsc --build && vite build", - "test": "npm run test:unit && npm run test:e2e", - "test:unit": "vitest --run", - "test:e2e": "VITE_SERVER_URL=http://server start-server-and-test start http://localhost:1351 'cypress run --e2e'" - }, - "dependencies": { - "@epfml/discojs": "*", - "@epfml/discojs-web": "*", - "@msgpack/msgpack": "3", - "chart.js": "4", - "cypress": "15", - "d3": "7", - "driver.js": "1", - "pinia": "3", - "pinia-plugin-persistedstate-2": "2", - "vee-validate": "5.0.0-beta.0", - "vue": "3", - "vue-chartjs": "5", - "vue-router": "5", - "vue-tippy": "6", - "vue-toast-notification": "3", - "yup": "1", - "zod": "4" - }, - "devDependencies": { - "@pinia/testing": "1", - "@tailwindcss/vite": "4", - "@tsconfig/node20": "20", - "@types/d3": "7", - "@types/jsdom": "28", - "@vitejs/plugin-vue": "6", - "@vue/test-utils": "2", - "@vue/tsconfig": "0.9", - "canvas": "3", - "jsdom": "29", - "start-server-and-test": "3", - "tailwindcss": "4", - "typescript": "5", - "vite": "8", - "vite-plugin-node-polyfills": "0.26", - "vue-tsc": "3", - "vue3-spinners": "1" - } + "name": "webapp", + "private": true, + "type": "module", + "scripts": { + "start": "vite", + "build": "vue-tsc --build && vite build", + "test": "npm run test:unit && npm run test:e2e", + "test:unit": "vitest --run", + "test:e2e": "VITE_SERVER_URL=http://server start-server-and-test start http://localhost:1351 'cypress run --e2e'" + }, + "dependencies": { + "@epfml/discojs": "*", + "@epfml/discojs-web": "*", + "@msgpack/msgpack": "3", + "chart.js": "4", + "cypress": "15", + "d3": "7", + "driver.js": "1", + "pinia": "3", + "pinia-plugin-persistedstate-2": "2", + "vee-validate": "5.0.0-beta.0", + "vue": "3", + "vue-chartjs": "5", + "vue-router": "5", + "vue-tippy": "6", + "vue-toast-notification": "3", + "yup": "1", + "zod": "4" + }, + "devDependencies": { + "@pinia/testing": "1", + "@tailwindcss/vite": "4", + "@tsconfig/node20": "20", + "@types/d3": "7", + "@types/jsdom": "28", + "@vitejs/plugin-vue": "6", + "@vue/test-utils": "2", + "@vue/tsconfig": "0.9", + "canvas": "3", + "jsdom": "29", + "start-server-and-test": "3", + "tailwindcss": "4", + "typescript": "5", + "vite": "8", + "vite-plugin-node-polyfills": "0.26", + "vue-tsc": "3", + "vue3-spinners": "1" + } } diff --git a/webapp/public/404.html b/webapp/public/404.html index f73c6b94a..bd27934a8 100644 --- a/webapp/public/404.html +++ b/webapp/public/404.html @@ -1,17 +1,15 @@ - + - - - + + + DISCO - - - + + - diff --git a/webapp/src/assets/css/styles.css b/webapp/src/assets/css/styles.css index d64bcfd28..691caabdb 100644 --- a/webapp/src/assets/css/styles.css +++ b/webapp/src/assets/css/styles.css @@ -6,26 +6,31 @@ font-size: small; } -.tippy-box[data-theme~="custom-dark"][data-placement^="top"]>.tippy-arrow::before { +.tippy-box[data-theme~="custom-dark"][data-placement^="top"] + > .tippy-arrow::before { border-top-color: #6f7174; } -.tippy-box[data-theme~="custom-dark"][data-placement^="bottom"]>.tippy-arrow::before { +.tippy-box[data-theme~="custom-dark"][data-placement^="bottom"] + > .tippy-arrow::before { border-bottom-color: #6f7174; } -.tippy-box[data-theme~="custom-dark"][data-placement^="left"]>.tippy-arrow::before { +.tippy-box[data-theme~="custom-dark"][data-placement^="left"] + > .tippy-arrow::before { border-left-color: #6f7174; } -.tippy-box[data-theme~="custom-dark"][data-placement^="right"]>.tippy-arrow::before { +.tippy-box[data-theme~="custom-dark"][data-placement^="right"] + > .tippy-arrow::before { border-right-color: #6f7174; } @font-face { font-family: AmpleSoftMedium; - src: url('../fonts/AmpleSoftMedium.woff2') format('woff2'), - url('../fonts/AmpleSoftMedium.woff') format('woff'); + src: + url("../fonts/AmpleSoftMedium.woff2") format("woff2"), + url("../fonts/AmpleSoftMedium.woff") format("woff"); } .cards-gap { @@ -37,4 +42,4 @@ .cards-gap { gap: 2rem; } -} \ No newline at end of file +} diff --git a/webapp/src/assets/css/tailwind.css b/webapp/src/assets/css/tailwind.css index 85fbe5c40..e66dc68c1 100644 --- a/webapp/src/assets/css/tailwind.css +++ b/webapp/src/assets/css/tailwind.css @@ -1,4 +1,4 @@ -@import 'tailwindcss'; +@import "tailwindcss"; @custom-variant dark (&:is(.dark *)); @@ -25,4 +25,4 @@ --color-heading-light: #334155; --color-heading-dark: #fff; -} \ No newline at end of file +} diff --git a/webapp/src/assets/gif/DecentralizedGIF.vue b/webapp/src/assets/gif/DecentralizedGIF.vue index 282f3b1a3..c5d9ddd33 100644 --- a/webapp/src/assets/gif/DecentralizedGIF.vue +++ b/webapp/src/assets/gif/DecentralizedGIF.vue @@ -1,10 +1,12 @@ diff --git a/webapp/src/assets/gif/DiscoGIF.vue b/webapp/src/assets/gif/DiscoGIF.vue index 6017d16bd..d2730e4f2 100644 --- a/webapp/src/assets/gif/DiscoGIF.vue +++ b/webapp/src/assets/gif/DiscoGIF.vue @@ -1,6 +1,6 @@ diff --git a/webapp/src/assets/gif/FederatedGIF.vue b/webapp/src/assets/gif/FederatedGIF.vue index 7e15f6bc2..bbd002ec3 100644 --- a/webapp/src/assets/gif/FederatedGIF.vue +++ b/webapp/src/assets/gif/FederatedGIF.vue @@ -1,10 +1,12 @@ diff --git a/webapp/src/assets/logos/AriadneLabsLogo.vue b/webapp/src/assets/logos/AriadneLabsLogo.vue index 4bf58dfed..2e12a8eea 100644 --- a/webapp/src/assets/logos/AriadneLabsLogo.vue +++ b/webapp/src/assets/logos/AriadneLabsLogo.vue @@ -1,30 +1,86 @@ diff --git a/webapp/src/assets/logos/DiscoLogo.vue b/webapp/src/assets/logos/DiscoLogo.vue index b7fdb5a74..03c49549b 100644 --- a/webapp/src/assets/logos/DiscoLogo.vue +++ b/webapp/src/assets/logos/DiscoLogo.vue @@ -16,9 +16,7 @@ " > - + diff --git a/webapp/src/assets/logos/MLOLogo.vue b/webapp/src/assets/logos/MLOLogo.vue index 5435fb6fa..ba768b55e 100644 --- a/webapp/src/assets/logos/MLOLogo.vue +++ b/webapp/src/assets/logos/MLOLogo.vue @@ -213,7 +213,7 @@ diff --git a/webapp/src/assets/logos/TensorflowLogo.vue b/webapp/src/assets/logos/TensorflowLogo.vue index 5e58bf4e0..67dc8730e 100644 --- a/webapp/src/assets/logos/TensorflowLogo.vue +++ b/webapp/src/assets/logos/TensorflowLogo.vue @@ -1,57 +1,114 @@ diff --git a/webapp/src/assets/svg/BinIcon.vue b/webapp/src/assets/svg/BinIcon.vue index 6d817db91..28c0f1b18 100644 --- a/webapp/src/assets/svg/BinIcon.vue +++ b/webapp/src/assets/svg/BinIcon.vue @@ -15,10 +15,10 @@ export default { props: { customClass: { - default: 'h-9 w-5', - type: String + default: "h-9 w-5", + type: String, }, - viewBox: { default: '0 0 17 17', type: String } - } -} + viewBox: { default: "0 0 17 17", type: String }, + }, +}; diff --git a/webapp/src/assets/svg/CreateIcon.vue b/webapp/src/assets/svg/CreateIcon.vue index b0643190f..bb32d6964 100644 --- a/webapp/src/assets/svg/CreateIcon.vue +++ b/webapp/src/assets/svg/CreateIcon.vue @@ -6,18 +6,22 @@ :viewBox="viewBox" stroke="currentColor" > - - + + diff --git a/webapp/src/assets/svg/DiscoParticlesIcon.vue b/webapp/src/assets/svg/DiscoParticlesIcon.vue index ef44491ec..20c3f3957 100644 --- a/webapp/src/assets/svg/DiscoParticlesIcon.vue +++ b/webapp/src/assets/svg/DiscoParticlesIcon.vue @@ -1,16 +1,32 @@ diff --git a/webapp/src/assets/svg/DownArrow.vue b/webapp/src/assets/svg/DownArrow.vue index e6d4f00ad..caae1cd39 100644 --- a/webapp/src/assets/svg/DownArrow.vue +++ b/webapp/src/assets/svg/DownArrow.vue @@ -13,8 +13,11 @@ diff --git a/webapp/src/assets/svg/EvaluateIcon.vue b/webapp/src/assets/svg/EvaluateIcon.vue index 4bb5fc8c2..74161d961 100644 --- a/webapp/src/assets/svg/EvaluateIcon.vue +++ b/webapp/src/assets/svg/EvaluateIcon.vue @@ -6,17 +6,20 @@ :viewBox="viewBox" stroke="currentColor" > - + diff --git a/webapp/src/assets/svg/HomeIcon.vue b/webapp/src/assets/svg/HomeIcon.vue index ac38abf13..a1c143641 100644 --- a/webapp/src/assets/svg/HomeIcon.vue +++ b/webapp/src/assets/svg/HomeIcon.vue @@ -18,10 +18,10 @@ export default { props: { customClass: { - default: 'w-7 h-7', - type: String + default: "w-7 h-7", + type: String, }, - viewBox: { default: '0 0 24 24', type: String } - } -} + viewBox: { default: "0 0 24 24", type: String }, + }, +}; diff --git a/webapp/src/assets/svg/InfoIcon.vue b/webapp/src/assets/svg/InfoIcon.vue index 6db859760..1cde5ac46 100644 --- a/webapp/src/assets/svg/InfoIcon.vue +++ b/webapp/src/assets/svg/InfoIcon.vue @@ -18,10 +18,10 @@ export default { props: { customClass: { - default: 'bi bi-info-circle w-7 h-7', - type: String + default: "bi bi-info-circle w-7 h-7", + type: String, }, - viewBox: { default: '-1 -1 18 18', type: String } - } -} + viewBox: { default: "-1 -1 18 18", type: String }, + }, +}; diff --git a/webapp/src/assets/svg/ModelExchangeIcon.vue b/webapp/src/assets/svg/ModelExchangeIcon.vue index 558fd12be..873f70d8e 100644 --- a/webapp/src/assets/svg/ModelExchangeIcon.vue +++ b/webapp/src/assets/svg/ModelExchangeIcon.vue @@ -5,16 +5,19 @@ :class="customClass" :viewBox="viewBox" > - - - + + diff --git a/webapp/src/assets/svg/ModelIcon.vue b/webapp/src/assets/svg/ModelIcon.vue index a718f3d5a..e962a0f6b 100644 --- a/webapp/src/assets/svg/ModelIcon.vue +++ b/webapp/src/assets/svg/ModelIcon.vue @@ -17,8 +17,8 @@ diff --git a/webapp/src/assets/svg/MoonIcon.vue b/webapp/src/assets/svg/MoonIcon.vue index 8310baeeb..061e61a73 100644 --- a/webapp/src/assets/svg/MoonIcon.vue +++ b/webapp/src/assets/svg/MoonIcon.vue @@ -5,7 +5,7 @@ stroke="currentColor" :class="customClass" :viewBox="viewBox" - stroke-width="2" + stroke-width="2" > diff --git a/webapp/src/assets/svg/PeopleIcon.vue b/webapp/src/assets/svg/PeopleIcon.vue index 9086a9a23..e3c25a2d5 100644 --- a/webapp/src/assets/svg/PeopleIcon.vue +++ b/webapp/src/assets/svg/PeopleIcon.vue @@ -20,10 +20,10 @@ export default { props: { customClass: { - default: 'w-12 h-12 text-gray-300', - type: String + default: "w-12 h-12 text-gray-300", + type: String, }, - viewBox: { default: '-6 -3 24 24', type: String } - } -} + viewBox: { default: "-6 -3 24 24", type: String }, + }, +}; diff --git a/webapp/src/assets/svg/PlugIcon.vue b/webapp/src/assets/svg/PlugIcon.vue index 149e487ea..22b34e576 100644 --- a/webapp/src/assets/svg/PlugIcon.vue +++ b/webapp/src/assets/svg/PlugIcon.vue @@ -1,24 +1,28 @@ diff --git a/webapp/src/assets/svg/StackIcon.vue b/webapp/src/assets/svg/StackIcon.vue index 41cc99709..7989b7f09 100644 --- a/webapp/src/assets/svg/StackIcon.vue +++ b/webapp/src/assets/svg/StackIcon.vue @@ -16,13 +16,13 @@ diff --git a/webapp/src/assets/svg/SunIcon.vue b/webapp/src/assets/svg/SunIcon.vue index 7a988d99d..913f30e8b 100644 --- a/webapp/src/assets/svg/SunIcon.vue +++ b/webapp/src/assets/svg/SunIcon.vue @@ -1,27 +1,28 @@ - - \ No newline at end of file + + + + + + + diff --git a/webapp/src/assets/svg/TimerIcon.vue b/webapp/src/assets/svg/TimerIcon.vue index 1411d7669..40e902417 100644 --- a/webapp/src/assets/svg/TimerIcon.vue +++ b/webapp/src/assets/svg/TimerIcon.vue @@ -15,10 +15,10 @@ export default { props: { customClass: { - default: 'w-12 h-12 text-gray-300 dark:text-primary-dark', - type: String + default: "w-12 h-12 text-gray-300 dark:text-primary-dark", + type: String, }, - viewBox: { default: '-6 -3 24 24', type: String } - } -} + viewBox: { default: "-6 -3 24 24", type: String }, + }, +}; diff --git a/webapp/src/assets/svg/UpArrow.vue b/webapp/src/assets/svg/UpArrow.vue index 529bb9ff1..12d42a74d 100644 --- a/webapp/src/assets/svg/UpArrow.vue +++ b/webapp/src/assets/svg/UpArrow.vue @@ -13,8 +13,11 @@ diff --git a/webapp/src/components/App.vue b/webapp/src/components/App.vue index 87b9e3a78..2d3417553 100644 --- a/webapp/src/components/App.vue +++ b/webapp/src/components/App.vue @@ -2,52 +2,25 @@
- + - + diff --git a/webapp/src/components/containers/ButtonsCard.vue b/webapp/src/components/containers/ButtonsCard.vue index 0da414961..65666e7df 100644 --- a/webapp/src/components/containers/ButtonsCard.vue +++ b/webapp/src/components/containers/ButtonsCard.vue @@ -1,7 +1,6 @@ diff --git a/webapp/src/components/containers/IconCardHeader.vue b/webapp/src/components/containers/IconCardHeader.vue index 32cf43b2b..a5db16a91 100644 --- a/webapp/src/components/containers/IconCardHeader.vue +++ b/webapp/src/components/containers/IconCardHeader.vue @@ -1,14 +1,14 @@ diff --git a/webapp/src/components/pages/TaskList.vue b/webapp/src/components/pages/TaskList.vue index d59461b9d..443c6a7e9 100644 --- a/webapp/src/components/pages/TaskList.vue +++ b/webapp/src/components/pages/TaskList.vue @@ -1,110 +1,129 @@ diff --git a/webapp/src/components/progress_bars/ProgressIcon.vue b/webapp/src/components/progress_bars/ProgressIcon.vue index ac225bde7..b22f52bcd 100644 --- a/webapp/src/components/progress_bars/ProgressIcon.vue +++ b/webapp/src/components/progress_bars/ProgressIcon.vue @@ -4,20 +4,31 @@
diff --git a/webapp/src/components/progress_bars/TestingButtons.vue b/webapp/src/components/progress_bars/TestingButtons.vue index 2d232e8d8..832fdf3a5 100644 --- a/webapp/src/components/progress_bars/TestingButtons.vue +++ b/webapp/src/components/progress_bars/TestingButtons.vue @@ -3,16 +3,10 @@ v-show="showPrev || showNext" class="mx-auto flex gap-4 lg:gap-8 mt-8 lg:mt-12 justify-center" > - + previous - + next
diff --git a/webapp/src/components/progress_bars/TrainingBar.vue b/webapp/src/components/progress_bars/TrainingBar.vue index 4f5284d42..ce6b4e86d 100644 --- a/webapp/src/components/progress_bars/TrainingBar.vue +++ b/webapp/src/components/progress_bars/TrainingBar.vue @@ -4,12 +4,11 @@ v-if="title !== undefined && displayTitle" class="flex flex-wrap font-disco text-3xl justify-center" > - {{ title }} + {{ + title + }}
-
+