From 275537f90fbd76a34a38257d6f17d8bf956a3b22 Mon Sep 17 00:00:00 2001 From: runwangdl Date: Mon, 13 Apr 2026 22:22:03 +0000 Subject: [PATCH 1/6] [NE16] Add GAP9_w_NE16 platform: NE16 accelerator Engine on GAP9 Mirrors the Siracusa_w_neureka pattern. NE16Platform extends GAP9Platform with engines=[NE16Engine, GAP9ClusterEngine]; NE16Deployer extends GAP9Deployer (reuses ClDma transformers via GAP9Bindings). New Target: Deeploy/Targets/NE16/ (Platform, Engine, Bindings, Parsers, Tiler, Deployer, Templates, TileConstraints, TopologyOptimizationPasses). The _weightEncode function is ported from pulp-nnx/test/Ne16Weight.py (single CIN_SUBTILE=16 mode, no 1x1 vs 3x3 split). ConvTemplate subtile constants set per ne16_task_defs.h (output 3x3, weight stride bytes PW=16 DW/Dense=144). New test infrastructure: - DeeployTest/deeployRunner_tiled_gap9_w_ne16.py - DeeployTest/test_gap9_ne16_tiled_config.py (PW/DW/Dense RQ Conv) DeeployTest wiring: - testUtils/platformMapping.py: register GAP9_w_NE16 in the platforms list, mapPlatform, setupMemoryPlatform, mapDeployer. - testMVP.py: include GAP9_w_NE16 in the EngineColoringDeployerWrapper branch (without it NE16AdjustWeightMemoryLayoutPass never fires and parsing backtracks to exhaustion). - testUtils/core/execution.py: build the GAP9 SDK 'image' target for GAP9_w_NE16 too (so chip.soc.mram.bin is produced before gvsoc run). - CMakeLists.txt, DeeployTest/CMakeLists.txt: accept GAP9_w_NE16 alongside GAP9 in the platform branches. - TargetLibraries/GAP9/CMakeLists.txt: for GAP9_w_NE16 platform, add_subdirectory on pulp-nnx with USE_NE16=ON and link it into deeploygap9. Fix: Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py referenced an undefined symbol float32_tPtr from Deeploy.AbstractDataTypes; define it locally via PointerClass(float32_t) to unblock the import chain reached by NE16Platform. 
Verified on gvsoc gap9.evk: PW 1x1 RQ (Regular_RQ): 0/1152 errors, 901917 cycles DW 3x3 RQ (DW_2D_RQ): 0/1280 errors, 27339 cycles (--enable-3x3) Dense 3x3 (Regular_2D_RQ): 0/6372 errors, 244595 cycles (--enable-3x3) --- .../workflows/_runner-gap9-w-ne16-tiled.yml | 54 +++ .../ci-platform-gap9-w-ne16-tiled.yml | 46 ++ CMakeLists.txt | 12 +- Deeploy/Targets/NE16/Bindings.py | 72 ++++ Deeploy/Targets/NE16/Deployer.py | 40 ++ Deeploy/Targets/NE16/Engine.py | 81 ++++ .../MemoryLevelAnnotationPasses.py | 45 ++ .../NE16/OptimizationPasses/__init__.py | 5 + Deeploy/Targets/NE16/Parsers.py | 203 +++++++++ Deeploy/Targets/NE16/Platform.py | 63 +++ .../NE16/Templates/AllocateTemplate.py | 18 + .../Targets/NE16/Templates/ConvTemplate.py | 398 ++++++++++++++++++ Deeploy/Targets/NE16/Templates/__init__.py | 5 + .../TileConstraints/NE16DenseConstraint.py | 268 ++++++++++++ .../NE16DepthwiseConstraint.py | 265 ++++++++++++ .../NE16PointwiseConstraint.py | 298 +++++++++++++ .../NE16/TileConstraints/RequantHelpers.py | 53 +++ .../Targets/NE16/TileConstraints/__init__.py | 5 + Deeploy/Targets/NE16/Tiler.py | 29 ++ .../NE16/TopologyOptimizationPasses/Passes.py | 278 ++++++++++++ .../TopologyOptimizationPasses/__init__.py | 5 + Deeploy/Targets/NE16/__init__.py | 5 + .../PULPOpen/Templates/FloatGemmTemplate.py | 5 +- DeeployTest/CMakeLists.txt | 2 +- .../Integer/Conv/Dense_2D_RQ/inputs.npz | Bin 0 -> 8456 bytes .../Integer/Conv/Dense_2D_RQ/network.onnx | Bin 0 -> 9817 bytes .../Integer/Conv/Dense_2D_RQ/outputs.npz | Bin 0 -> 8458 bytes DeeployTest/conftest.py | 1 + .../deeployRunner_tiled_gap9_w_ne16.py | 22 + DeeployTest/testMVP.py | 2 +- DeeployTest/testUtils/core/execution.py | 2 +- DeeployTest/testUtils/platformMapping.py | 26 +- DeeployTest/test_gap9_ne16_tiled_config.py | 38 ++ DeeployTest/test_platforms.py | 106 +++++ TargetLibraries/GAP9/CMakeLists.txt | 16 + 35 files changed, 2457 insertions(+), 11 deletions(-) create mode 100644 .github/workflows/_runner-gap9-w-ne16-tiled.yml 
create mode 100644 .github/workflows/ci-platform-gap9-w-ne16-tiled.yml create mode 100644 Deeploy/Targets/NE16/Bindings.py create mode 100644 Deeploy/Targets/NE16/Deployer.py create mode 100644 Deeploy/Targets/NE16/Engine.py create mode 100644 Deeploy/Targets/NE16/OptimizationPasses/MemoryLevelAnnotationPasses.py create mode 100644 Deeploy/Targets/NE16/OptimizationPasses/__init__.py create mode 100644 Deeploy/Targets/NE16/Parsers.py create mode 100644 Deeploy/Targets/NE16/Platform.py create mode 100644 Deeploy/Targets/NE16/Templates/AllocateTemplate.py create mode 100644 Deeploy/Targets/NE16/Templates/ConvTemplate.py create mode 100644 Deeploy/Targets/NE16/Templates/__init__.py create mode 100644 Deeploy/Targets/NE16/TileConstraints/NE16DenseConstraint.py create mode 100644 Deeploy/Targets/NE16/TileConstraints/NE16DepthwiseConstraint.py create mode 100644 Deeploy/Targets/NE16/TileConstraints/NE16PointwiseConstraint.py create mode 100644 Deeploy/Targets/NE16/TileConstraints/RequantHelpers.py create mode 100644 Deeploy/Targets/NE16/TileConstraints/__init__.py create mode 100644 Deeploy/Targets/NE16/Tiler.py create mode 100644 Deeploy/Targets/NE16/TopologyOptimizationPasses/Passes.py create mode 100644 Deeploy/Targets/NE16/TopologyOptimizationPasses/__init__.py create mode 100644 Deeploy/Targets/NE16/__init__.py create mode 100644 DeeployTest/Tests/Kernels/Integer/Conv/Dense_2D_RQ/inputs.npz create mode 100644 DeeployTest/Tests/Kernels/Integer/Conv/Dense_2D_RQ/network.onnx create mode 100644 DeeployTest/Tests/Kernels/Integer/Conv/Dense_2D_RQ/outputs.npz create mode 100644 DeeployTest/deeployRunner_tiled_gap9_w_ne16.py create mode 100644 DeeployTest/test_gap9_ne16_tiled_config.py diff --git a/.github/workflows/_runner-gap9-w-ne16-tiled.yml b/.github/workflows/_runner-gap9-w-ne16-tiled.yml new file mode 100644 index 0000000000..bff85b1b95 --- /dev/null +++ b/.github/workflows/_runner-gap9-w-ne16-tiled.yml @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: 2026 ETH Zurich and 
University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +--- +name: _runner-gap9-w-ne16-tiled + +"on": + workflow_call: + inputs: + runner: + required: true + type: string + docker-image: + required: true + type: string + pytest-markers: + required: true + type: string + +jobs: + test-runner-gap9-w-ne16-tiled: + runs-on: ${{ inputs.runner }} + container: + image: ${{ inputs.docker-image }} + steps: + - name: Checkout Repo + uses: actions/checkout@v4 + with: + submodules: recursive + - name: Build Deeploy + shell: bash + run: | + source /app/install/gap9-sdk/.gap9-venv/bin/activate + source /app/install/gap9-sdk/configs/gap9_evk_audio.sh || true + pip install -e . || true + deactivate + - name: Cache ccache + uses: actions/cache/restore@v4 + with: + path: /app/.ccache + key: ccache-gap9 + - name: Run Test + run: | + source /app/install/gap9-sdk/.gap9-venv/bin/activate + source /app/install/gap9-sdk/configs/gap9_evk_audio.sh || true + export GVSOC_INSTALL_DIR=/app/install/gap9-sdk/install/workstation + export GAP_RISCV_GCC_TOOLCHAIN=/app/install/gcc/gap9 + cd DeeployTest + mkdir -p /app/.ccache + export CCACHE_DIR=/app/.ccache + pytest test_platforms.py -v -m "${{ inputs.pytest-markers }}" + deactivate + shell: bash diff --git a/.github/workflows/ci-platform-gap9-w-ne16-tiled.yml b/.github/workflows/ci-platform-gap9-w-ne16-tiled.yml new file mode 100644 index 0000000000..5f45bbafeb --- /dev/null +++ b/.github/workflows/ci-platform-gap9-w-ne16-tiled.yml @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: 2026 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +--- +name: CI • GAP9 + NE16 (Tiled) + +"on": + push: + branches: + - "**" + tags: + - "v*.*.*" + pull_request: + workflow_dispatch: + inputs: + docker_image_deeploy: + description: "Deeploy Image to use" + required: false + default: "ghcr.io/pulp-platform/deeploy-gap9:latest" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + 
select-env: + uses: ./.github/workflows/_select-env.yml + with: + docker_image_deeploy: ${{ github.event.inputs.docker_image_deeploy || github.repository == 'pulp-platform/Deeploy' && 'ghcr.io/pulp-platform/deeploy-gap9:latest'}} + + gap9-w-ne16-kernels-tiled-singlebuffer-L2: + needs: select-env + uses: ./.github/workflows/_runner-gap9-w-ne16-tiled.yml + with: + runner: ${{ needs.select-env.outputs.runner }} + docker-image: ${{ needs.select-env.outputs.image }} + pytest-markers: "gap9_w_ne16_tiled and kernels and singlebuffer and l2" + + gap9-w-ne16-kernels-tiled-doublebuffer-L2: + needs: select-env + uses: ./.github/workflows/_runner-gap9-w-ne16-tiled.yml + with: + runner: ${{ needs.select-env.outputs.runner }} + docker-image: ${{ needs.select-env.outputs.image }} + pytest-markers: "gap9_w_ne16_tiled and kernels and doublebuffer and l2" diff --git a/CMakeLists.txt b/CMakeLists.txt index 0e07d64a9e..3fcbb7800b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,8 +20,8 @@ if(TOOLCHAIN STREQUAL GCC) set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) endif() -set(platform MemPool CACHE STRING "Platform (MemPool, SoftHier, QEMU, Siracusa, Siracusa_w_neureka, PULP-Open, GAP9, Generic, Snitch)") -set_property(CACHE platform PROPERTY STRINGS MemPool SoftHier QEMU Siracusa Siracusa_w_neureka PULP-Open GAP9 Generic Snitch) +set(platform MemPool CACHE STRING "Platform (MemPool, SoftHier, QEMU, Siracusa, Siracusa_w_neureka, PULP-Open, GAP9, GAP9_w_NE16, Generic, Snitch)") +set_property(CACHE platform PROPERTY STRINGS MemPool SoftHier QEMU Siracusa Siracusa_w_neureka PULP-Open GAP9 GAP9_w_NE16 Generic Snitch) if(platform STREQUAL MemPool) message(STATUS "Building for platform 'MemPool'") @@ -33,8 +33,8 @@ elseif(platform STREQUAL Siracusa_w_neureka) message(STATUS "Building for platform 'Siracusa_w_neureka'") elseif(platform STREQUAL PULPOpen) message(STATUS "Building for platform 'PULP-Open'") -elseif(platform STREQUAL GAP9) - message(STATUS "Building for platform 'GAP9'") 
+elseif(platform STREQUAL GAP9 OR platform STREQUAL GAP9_w_NE16) + message(STATUS "Building for platform '${platform}'") set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) # Select SDK config based on simulator type @@ -62,7 +62,7 @@ endif() # Import useful functions / macros include(${CMAKE_CURRENT_LIST_DIR}/cmake/Util.cmake) # Only if not GAP9 -if(NOT platform STREQUAL GAP9) +if(NOT platform STREQUAL GAP9 AND NOT platform STREQUAL GAP9_w_NE16) include(${CMAKE_CURRENT_LIST_DIR}/cmake/common.cmake) endif() include(${CMAKE_CURRENT_LIST_DIR}/cmake/simulation.cmake) @@ -231,7 +231,7 @@ if(platform STREQUAL Siracusa OR platform STREQUAL Siracusa_w_neureka OR platfor endif() -if(platform STREQUAL GAP9) +if(platform STREQUAL GAP9 OR platform STREQUAL GAP9_w_NE16) project(${TESTNAME} LANGUAGES C ASM) include(${CMAKE_CURRENT_LIST_DIR}/cmake/gap9/gap9_gvsoc.cmake) include(${CMAKE_CURRENT_LIST_DIR}/cmake/gap9/gap9_board.cmake) diff --git a/Deeploy/Targets/NE16/Bindings.py b/Deeploy/Targets/NE16/Bindings.py new file mode 100644 index 0000000000..58db14aee3 --- /dev/null +++ b/Deeploy/Targets/NE16/Bindings.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from Deeploy.AbstractDataTypes import PointerClass +from Deeploy.CommonExtensions.DataTypes import int8_t, int32_t, uint8_t +from Deeploy.DeeployTypes import NodeBinding +from Deeploy.Targets.GAP9.Bindings import GAP9ClusterTransformer as ClusterTransformer +from Deeploy.Targets.Generic.TypeCheckers import ConvChecker +from Deeploy.Targets.NE16.Templates.ConvTemplate import NE16DenseConv2D_Template, NE16DWConv2D_Template, \ + NE16PWConv2D_Template, NE16RqntDenseConv2D_Template, NE16RqntDWConv2D_Template, NE16RqntPWConv2D_Template +from Deeploy.Targets.PULPOpen.TypeCheckers import PULPConvChecker + +NE16RQSPWConv2DBindings = [ + NodeBinding( + PULPConvChecker( + [PointerClass(data_in_type), + PointerClass(weight_type), + 
PointerClass(int32_t), + PointerClass(int32_t)], [PointerClass(data_out_type)]), NE16RqntPWConv2D_Template, ClusterTransformer) + for data_in_type in [uint8_t, int8_t] + for data_out_type in [uint8_t, int8_t] + for weight_type in [uint8_t, int8_t] +] +NE16PWConv2DBindings = [ + NodeBinding( + ConvChecker( + [PointerClass(data_in_type), PointerClass(weight_type), + PointerClass(int32_t)], [PointerClass(int32_t)]), NE16PWConv2D_Template, ClusterTransformer) + for data_in_type in [uint8_t, int8_t] + for weight_type in [uint8_t, int8_t] +] + +NE16RQSDWConv2DBindings = [ + NodeBinding( + PULPConvChecker( + [PointerClass(data_in_type), + PointerClass(weight_type), + PointerClass(int32_t), + PointerClass(int32_t)], [PointerClass(data_out_type)]), NE16RqntDWConv2D_Template, ClusterTransformer) + for data_in_type in [uint8_t, int8_t] + for data_out_type in [uint8_t, int8_t] + for weight_type in [uint8_t, int8_t] +] +NE16DWConv2DBindings = [ + NodeBinding( + ConvChecker( + [PointerClass(data_in_type), PointerClass(weight_type), + PointerClass(int32_t)], [PointerClass(int32_t)]), NE16DWConv2D_Template, ClusterTransformer) + for data_in_type in [uint8_t, int8_t] + for weight_type in [uint8_t, int8_t] +] + +NE16RQSDenseConv2DBindings = [ + NodeBinding( + PULPConvChecker( + [PointerClass(data_in_type), + PointerClass(weight_type), + PointerClass(int32_t), + PointerClass(int32_t)], [PointerClass(data_out_type)]), NE16RqntDenseConv2D_Template, ClusterTransformer) + for data_in_type in [uint8_t, int8_t] + for data_out_type in [uint8_t, int8_t] + for weight_type in [uint8_t, int8_t] +] +NE16DenseConv2DBindings = [ + NodeBinding( + ConvChecker( + [PointerClass(data_in_type), PointerClass(weight_type), + PointerClass(int32_t)], [PointerClass(int32_t)]), NE16DenseConv2D_Template, ClusterTransformer) + for data_in_type in [uint8_t, int8_t] + for weight_type in [uint8_t, int8_t] +] diff --git a/Deeploy/Targets/NE16/Deployer.py b/Deeploy/Targets/NE16/Deployer.py new file mode 100644 index 
0000000000..7f5d6c9748 --- /dev/null +++ b/Deeploy/Targets/NE16/Deployer.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Dict, Type + +import onnx_graphsurgeon as gs + +from Deeploy.AbstractDataTypes import Pointer +from Deeploy.CommonExtensions.OptimizationPasses.TopologyOptimizationPasses.LoweringOptimizationPasses import \ + NCHWtoNHWCPass, PULPNCHWtoNHWCPass +from Deeploy.DeeployTypes import DeploymentPlatform, TopologyOptimizer +from Deeploy.Targets.GAP9.Deployer import GAP9Deployer +from Deeploy.Targets.NE16.TopologyOptimizationPasses.Passes import ConvEngineDiscolorationPass, NE16OptimizationPass + + +class NE16Deployer(GAP9Deployer): + + def __init__(self, + graph: gs.Graph, + deploymentPlatform: DeploymentPlatform, + inputTypes: Dict[str, Type[Pointer]], + loweringOptimizer: TopologyOptimizer, + scheduler: Callable = lambda graph: list(graph.nodes), + name: str = 'DeeployNetwork', + default_channels_first = False, + deeployStateDir: str = "DeeployStateDir", + inputOffsets = {}): + super().__init__(graph, deploymentPlatform, inputTypes, loweringOptimizer, scheduler, name, + default_channels_first, deeployStateDir, inputOffsets) + + if self.Platform.engines[0].enable3x3: + for idx in range(len(self.loweringOptimizer.passes)): + if isinstance(self.loweringOptimizer.passes[idx], PULPNCHWtoNHWCPass): + self.loweringOptimizer.passes[idx] = NCHWtoNHWCPass(self.default_channels_first) + + self.loweringOptimizer.passes += [ + ConvEngineDiscolorationPass(), + NE16OptimizationPass(self.default_channels_first, "NE16") + ] diff --git a/Deeploy/Targets/NE16/Engine.py b/Deeploy/Targets/NE16/Engine.py new file mode 100644 index 0000000000..48f5bca284 --- /dev/null +++ b/Deeploy/Targets/NE16/Engine.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + 
+import onnx_graphsurgeon as gs + +from Deeploy.DeeployTypes import DeploymentEngine, NodeMapper +from Deeploy.Targets.Generic.Layers import ConvLayer +from Deeploy.Targets.NE16.Parsers import NE16DenseConv2DParser, NE16DWConv2DParser, NE16PWConv2DParser, \ + NE16RQSDenseConv2DParser, NE16RQSDWConv2DParser, NE16RQSPWConv2DParser +from Deeploy.Targets.NE16.Tiler import NE16DenseConv2DTilingReadyBindings, NE16DWConv2DTilingReadyBindings, \ + NE16PWConv2DTilingReadyBindings, NE16RQSDenseConv2DTilingReadyBindings, NE16RQSDWConv2DTilingReadyBindings, \ + NE16RQSPWConv2DTilingReadyBindings +from Deeploy.Targets.PULPOpen.Layers import PULPRQSConvLayer + +NE16RqntPWConv2DMapper = NodeMapper(NE16RQSPWConv2DParser(), NE16RQSPWConv2DTilingReadyBindings) +NE16PWConv2DMapper = NodeMapper(NE16PWConv2DParser(), NE16PWConv2DTilingReadyBindings) + +NE16RqntDWConv2DMapper = NodeMapper(NE16RQSDWConv2DParser(), NE16RQSDWConv2DTilingReadyBindings) +NE16DWConv2DMapper = NodeMapper(NE16DWConv2DParser(), NE16DWConv2DTilingReadyBindings) + +NE16RqntDenseConv2DMapper = NodeMapper(NE16RQSDenseConv2DParser(), NE16RQSDenseConv2DTilingReadyBindings) +NE16DenseConv2DMapper = NodeMapper(NE16DenseConv2DParser(), NE16DenseConv2DTilingReadyBindings) + +NE16Mapping = { + 'RequantizedConv': PULPRQSConvLayer([NE16RqntPWConv2DMapper, NE16RqntDWConv2DMapper, NE16RqntDenseConv2DMapper]), + 'Conv': ConvLayer([NE16PWConv2DMapper, NE16DWConv2DMapper, NE16DenseConv2DMapper]), +} + +_includeList = ["pulp_nnx_ne16.h", "pulp_nnx_util.h", "ne16_pulp_bsp.h", "ne16.h", "ne16_task.h"] + +_ne16InitCode = r""" +ne16_pulp_conf_t conf = {.max_stall = 8}; +ne16_nnx_init(ne16_pulp_get_dev(), &conf); +""" + + +class NE16Engine(DeploymentEngine): + + def __init__(self, + name: str, + Mapping = NE16Mapping, + initCode: str = _ne16InitCode, + includeList: List[str] = _includeList, + enable3x3: bool = False, + enableStrides: bool = False) -> None: + super().__init__(name, Mapping, initCode, includeList) + + self.enable3x3 = 
enable3x3 + self.enableStrides = enableStrides + + def isDenseConv(self, node) -> bool: + return node.op in ["Conv", "RequantizedConv"] and \ + isinstance(node.inputs[1], gs.Constant) and \ + node.attrs['kernel_shape'] == [3, 3] and \ + node.attrs['dilations'] == [1, 1] and \ + node.attrs['group'] == 1 and \ + (node.attrs['strides'] == [1, 1] or self.enableStrides) + + def isPWConv(self, node) -> bool: + return node.op in ["Conv", "RequantizedConv"] and \ + isinstance(node.inputs[1], gs.Constant) and \ + node.attrs['kernel_shape'] == [1, 1] and \ + node.attrs['dilations'] == [1, 1] and \ + (node.attrs['strides'] == [1, 1] or self.enableStrides) + + def isDWConv(self, node) -> bool: + return node.op in ["Conv", "RequantizedConv"] and \ + isinstance(node.inputs[1], gs.Constant) and \ + node.attrs['kernel_shape'] == [3, 3] and \ + node.attrs['dilations'] == [1, 1] and \ + node.attrs['group'] != 1 and \ + (node.attrs['strides'] == [1, 1] or self.enableStrides) + + def canExecute(self, node: gs.Node) -> bool: + if self.enable3x3: + return self.isPWConv(node) or self.isDWConv(node) or self.isDenseConv(node) + else: + return self.isPWConv(node) diff --git a/Deeploy/Targets/NE16/OptimizationPasses/MemoryLevelAnnotationPasses.py b/Deeploy/Targets/NE16/OptimizationPasses/MemoryLevelAnnotationPasses.py new file mode 100644 index 0000000000..b6a530a319 --- /dev/null +++ b/Deeploy/Targets/NE16/OptimizationPasses/MemoryLevelAnnotationPasses.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: 2023 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Tuple + +import numpy as np +import onnx_graphsurgeon as gs + +from Deeploy.CommonExtensions.OptimizationPasses.PassClasses import SequentialPass +from Deeploy.DeeployTypes import ConstantBuffer, NetworkContext +from Deeploy.MemoryLevelExtension.MemoryLevels import MemoryLevel + + +class AnnotateNE16WeightMemoryLevel(SequentialPass): + + def __init__(self, ne16EngineName: str, 
weightMemoryLevel: MemoryLevel): + self._weightMemoryLevel = weightMemoryLevel + self.ne16EngineName = ne16EngineName + super().__init__() + + def apply(self, ctxt: NetworkContext, graph: gs.Graph) -> Tuple[NetworkContext, gs.Graph]: + + def _ne16WeightBufferSize(buffer: ConstantBuffer) -> int: + return int(np.prod(buffer.shape)) # Weights are encoded as bytes so no need to check for typeWidth + + weightMemoryOccupation = 0 + + # Current weight memory occupation + for buffer in {**ctxt.globalObjects, **ctxt.localObjects}.values(): + if hasattr(buffer, "_memoryLevel") and buffer._memoryLevel == self._weightMemoryLevel.name: + weightMemoryOccupation += _ne16WeightBufferSize(buffer) + + ne16Nodes = [node for node in graph.nodes if node.attrs["engine"] == self.ne16EngineName] + for node in ne16Nodes: + if node.op in ["Conv", "RequantizedConv"]: + + if not (ctxt.is_local(node.inputs[1].name) or ctxt.is_global(node.inputs[1].name)): + continue + + buffer = ctxt.lookup(node.inputs[1].name) + if weightMemoryOccupation + _ne16WeightBufferSize(buffer) < self._weightMemoryLevel.size: + buffer._memoryLevel = self._weightMemoryLevel.name + weightMemoryOccupation += _ne16WeightBufferSize(buffer) + return ctxt, graph diff --git a/Deeploy/Targets/NE16/OptimizationPasses/__init__.py b/Deeploy/Targets/NE16/OptimizationPasses/__init__.py new file mode 100644 index 0000000000..be436b64a3 --- /dev/null +++ b/Deeploy/Targets/NE16/OptimizationPasses/__init__.py @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from . 
import * diff --git a/Deeploy/Targets/NE16/Parsers.py b/Deeploy/Targets/NE16/Parsers.py new file mode 100644 index 0000000000..3d157114fc --- /dev/null +++ b/Deeploy/Targets/NE16/Parsers.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Tuple + +import onnx_graphsurgeon as gs + +from Deeploy.DeeployTypes import NetworkContext +from Deeploy.Targets.Generic.Parsers import Conv2DParser, ConvParser, RQSParserInterface + + +class NE16Conv2DBaseParser(Conv2DParser): + + def parseNode(self, node: gs.Node) -> bool: + if not super().parseNode(node): + return False + + if not all([ + # No dilation support + self.operatorRepresentation['dilations'] == [1, 1], + # Channels have to be last + 'channels_first' in self.operatorRepresentation and not self.operatorRepresentation['channels_first'], + # Expect "weight_offset" attribute in the node + "weight_offset" in node.attrs, + ]): + return False + + self.operatorRepresentation['padding_y_top'] = int(self.operatorRepresentation['pads'][0]) + self.operatorRepresentation['padding_x_left'] = int(self.operatorRepresentation['pads'][1]) + self.operatorRepresentation['padding_y_bottom'] = int(self.operatorRepresentation['pads'][2]) + self.operatorRepresentation['padding_x_right'] = int(self.operatorRepresentation['pads'][3]) + self.operatorRepresentation['weight_offset'] = int(node.attrs["weight_offset"]) + + return True + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True) -> Tuple[NetworkContext, bool]: + # LMACAN: Cannot reuse the Conv2DParser's parserNodeCtxt because it requires the weight shape + # to be of length 4 whereas ne16 does a specific weight encoding so the shape + # ends up being equal to 3. 
+ newCtxt, ret = ConvParser.parseNodeCtxt(self, ctxt, node, channels_first) + + if not ret: + return ctxt, False + + # LMACAN: c/p of Conv2DParser's parserNodeCtxt but with a different weight shape check + # and enforcing that the channels_first is false + data_in = newCtxt.lookup(self.operatorRepresentation['data_in']) + data_out = newCtxt.lookup(self.operatorRepresentation['data_out']) + weight = newCtxt.lookup(self.operatorRepresentation['weight']) + + if not all([ + channels_first == False, + len(data_in.shape) == 4, + # LMACAN: weight shape should be equal to 3 because we have to do the ne16's + # special weight encoding. Dense 3x3 uses rank 4, + # PW/DW use rank 3. + len(weight.shape) in (3, 4), + ]): + return newCtxt, False + + self.operatorRepresentation['batch'] = data_in.shape[0] + self.operatorRepresentation['dim_im_in_x'] = data_in.shape[1] + self.operatorRepresentation['dim_im_in_y'] = data_in.shape[2] + self.operatorRepresentation['ch_im_in'] = data_in.shape[3] + self.operatorRepresentation['dim_im_out_x'] = data_out.shape[1] + self.operatorRepresentation['dim_im_out_y'] = data_out.shape[2] + self.operatorRepresentation['ch_im_out'] = data_out.shape[3] + + # No requantization + self.operatorRepresentation['mul'] = 'NULL' + self.operatorRepresentation['add'] = 'NULL' + self.operatorRepresentation['shift'] = 'NULL' + + return newCtxt, True + + +class NE16DWConv2DParser(NE16Conv2DBaseParser): + + def parseNode(self, node: gs.Node) -> bool: + if not super().parseNode(node): + return False + + # After NE16 weight encoding for DW, the encoded weight shape no longer + # carries cout==group (all channels are packed into the cinMinor + # dimension). Trust the ONNX `group` attribute alone: for DW, + # group > 1 AND group == channel_out AND kernel_shape == [3,3]. 
+ if not all([ + self.operatorRepresentation['kernel_shape'] == [3, 3], + self.operatorRepresentation['group'] > 1, + ]): + return False + + return True + + +class NE16RQSDWConv2DParser(NE16DWConv2DParser, RQSParserInterface): + + def parseNode(self, node: gs.Node) -> bool: + ret = all([ + RQSParserInterface.parseNode(self, node), + NE16DWConv2DParser.parseNode(self, node), + ]) + + return ret + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True) -> Tuple[NetworkContext, bool]: + newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first) + + if not ret: + return ctxt, False + + inputs = ['data_in', 'weight', 'mul', 'add'] + for idx, inputNode in enumerate(node.inputs): + self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name + + return newCtxt, True + + +class NE16PWConv2DParser(NE16Conv2DBaseParser): + + def parseNode(self, node: gs.Node) -> bool: + if not super().parseNode(node): + return False + + if not all([ + self.operatorRepresentation['kernel_shape'] == [1, 1], + self.operatorRepresentation['group'] == 1, + ]): + return False + + return True + + +class NE16RQSPWConv2DParser(NE16PWConv2DParser, RQSParserInterface): + + def parseNode(self, node: gs.Node) -> bool: + ret = all([ + RQSParserInterface.parseNode(self, node), + NE16PWConv2DParser.parseNode(self, node), + ]) + return ret + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True) -> Tuple[NetworkContext, bool]: + newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first) + + if not ret: + return ctxt, False + + inputs = ['data_in', 'weight', 'mul', 'add'] + for idx, inputNode in enumerate(node.inputs): + self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name + + return newCtxt, True + + +class NE16DenseConv2DParser(NE16Conv2DBaseParser): + + def parseNode(self, node: gs.Node) -> bool: + if not super().parseNode(node): + return False + + if not all([ + 
self.operatorRepresentation['kernel_shape'] == [3, 3], + self.operatorRepresentation['group'] == 1, + ]): + return False + + return True + + +class NE16RQSDenseConv2DParser(NE16DenseConv2DParser, RQSParserInterface): + + def parseNode(self, node: gs.Node) -> bool: + ret = all([ + RQSParserInterface.parseNode(self, node), + NE16DenseConv2DParser.parseNode(self, node), + ]) + return ret + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True) -> Tuple[NetworkContext, bool]: + newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first) + + if not ret: + return ctxt, False + + inputs = ['data_in', 'weight', 'mul', 'add'] + for idx, inputNode in enumerate(node.inputs): + self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name + + return newCtxt, True diff --git a/Deeploy/Targets/NE16/Platform.py b/Deeploy/Targets/NE16/Platform.py new file mode 100644 index 0000000000..2c6fddf8e5 --- /dev/null +++ b/Deeploy/Targets/NE16/Platform.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +from Deeploy.CommonExtensions.OptimizationPasses.TopologyOptimizationPasses.LoweringOptimizationPasses import \ + RequantizedGemmToPwPass +from Deeploy.DeeployTypes import TopologyOptimizer +from Deeploy.MemoryLevelExtension.MemoryLevels import MemoryHierarchy, MemoryLevel +from Deeploy.Targets.GAP9.Platform import GAP9ClusterEngine, GAP9ConstantBuffer, GAP9Platform, GAP9StructBuffer, \ + GAP9TransientBuffer, GAP9VariableBuffer, MemoryGAP9Platform, MemoryGAP9PlatformWrapper +from Deeploy.Targets.NE16.Engine import NE16Engine +from Deeploy.Targets.PULPOpen.Platform import PULPOptimizer + +NE16Optimizer = TopologyOptimizer([ + *PULPOptimizer.passes, + RequantizedGemmToPwPass(), +], name = "NE16Optimizer") + + +class NE16Platform(GAP9Platform): + + def __init__(self, + engines = None, + variableBuffer = 
GAP9VariableBuffer, + constantBuffer = GAP9ConstantBuffer, + structBuffer = GAP9StructBuffer, + transientBuffer = GAP9TransientBuffer) -> None: + if engines is None: + engines = [NE16Engine("NE16"), GAP9ClusterEngine("GAP9Cluster")] + super().__init__(engines, variableBuffer, constantBuffer, structBuffer, transientBuffer) + + +class MemoryNE16Platform(MemoryGAP9Platform): + + def __init__(self, + memoryHierarchy: MemoryHierarchy, + defaultTargetMemoryLevel: MemoryLevel, + weightMemoryLevel: Optional[MemoryLevel] = None, + engines = None, + variableBuffer = GAP9VariableBuffer, + constantBuffer = GAP9ConstantBuffer, + structBuffer = GAP9StructBuffer, + transientBuffer = GAP9TransientBuffer) -> None: + if engines is None: + engines = [NE16Engine("NE16"), GAP9ClusterEngine("GAP9Cluster")] + super().__init__(memoryHierarchy, defaultTargetMemoryLevel, engines, variableBuffer, constantBuffer, + structBuffer, transientBuffer) + self.weightMemoryLevel = weightMemoryLevel + + +class MemoryNE16PlatformWrapper(MemoryGAP9PlatformWrapper): + + def __init__(self, + platform: NE16Platform, + memoryHierarchy: MemoryHierarchy, + defaultTargetMemoryLevel: MemoryLevel, + weightMemoryLevel: Optional[MemoryLevel] = None): + assert isinstance(platform, NE16Platform), \ + f"Given platform is not an instance of NE16Platform. 
Platform type: {type(platform).__name__}" + super().__init__(platform, memoryHierarchy, defaultTargetMemoryLevel) + self.weightMemoryLevel = weightMemoryLevel diff --git a/Deeploy/Targets/NE16/Templates/AllocateTemplate.py b/Deeploy/Targets/NE16/Templates/AllocateTemplate.py new file mode 100644 index 0000000000..502b5af578 --- /dev/null +++ b/Deeploy/Targets/NE16/Templates/AllocateTemplate.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: 2023 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from Deeploy.DeeployTypes import NodeTemplate + +ne16GenericGlobalInitTemplate = NodeTemplate(""" +% if _memoryLevel == "L1": +static PI_L1 ${type.referencedType.typeName} ${name}[${size}] = {${values}};\n +% elif _memoryLevel == "L2" or _memoryLevel is None: +static PI_L2 ${type.referencedType.typeName} ${name}[${size}] = {${values}};\n +% elif _memoryLevel == "L3": +// ${name} is allocated in L3 \n +static PI_L2 ${type.referencedType.typeName}* ${name}; +% elif _memoryLevel == "WeightMemory_SRAM": +static __attribute__((section(".weightmem_sram"))) ${type.referencedType.typeName} ${name}[${size}] = {${values}};\n +% endif +""") diff --git a/Deeploy/Targets/NE16/Templates/ConvTemplate.py b/Deeploy/Targets/NE16/Templates/ConvTemplate.py new file mode 100644 index 0000000000..337f5e10c4 --- /dev/null +++ b/Deeploy/Targets/NE16/Templates/ConvTemplate.py @@ -0,0 +1,398 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from abc import abstractmethod +from typing import Dict, List, Tuple + +import numpy as np + +from Deeploy.DeeployTypes import ConstantBuffer, NetworkContext, NodeTemplate, OperatorRepresentation + + +def _getNumTiles(fullDim: int, tileDim: int) -> int: + return int(np.ceil(fullDim / tileDim)) + + +def _getBorderTileSize(fullDim: int, tileDim: int) -> int: + return fullDim % tileDim if fullDim % tileDim > 0 else tileDim + + +def ioStridesFromDimensions(width: int, 
channel: int, bits: int) -> Tuple[int, int]: + """stridesFromDimensions + Returns strides in bytes. + """ + width_stride = channel * bits // 8 + height_stride = width * width_stride + return height_stride, width_stride + + +def getNormQuantConf0(use_relu: bool, layerwise_output_shift: int, scale_bits: int, use_bias: bool, + use_shift: bool) -> int: + conf0 = 0 + conf0 |= 1 << 4 # Use Normalization and quantization + if scale_bits == 32: + conf0 |= 2 << 12 + conf0 |= layerwise_output_shift << 16 + if not use_relu: + conf0 |= 1 << 23 + if use_shift: + conf0 |= 1 << 24 + if use_bias: + conf0 |= 1 << 25 + return conf0 + + +def getInputAddrOffset(width_in: int, width_in_stride: int, padding_top: int, padding_left: int) -> int: + return (padding_top * width_in + padding_left) * width_in_stride + + +class NE16ConvTemplate(NodeTemplate): + + def __init__(self, templateStr: str): + super().__init__(templateStr) + + @classmethod + @abstractmethod + def getCounters( + cls, channel_in: int, height_out: int, width_out: int, channel_out: int, padding_bottom: int, + padding_right: int, + operatorRepresentation: OperatorRepresentation) -> Tuple[int, int, int, int, int, int, int, int, int, int]: + pass + + @classmethod + @abstractmethod + def getWeightStrides(cls, channel_in: int) -> Tuple[int, int, int]: + pass + + @classmethod + @abstractmethod + def getConf0(cls, output_bits: int, weight_bits: int, input_signed: bool, use_wmem: bool) -> int: + pass + + def alignToContext(self, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: + data_in: ConstantBuffer = ctxt.lookup(operatorRepresentation['data_in']) + data_out: ConstantBuffer = ctxt.lookup(operatorRepresentation['data_out']) + weight: ConstantBuffer = ctxt.lookup(operatorRepresentation['weight']) + + operatorRepresentation['input_signed'] = data_in._type.referencedType.typeMin < 0 + operatorRepresentation['use_relu'] = data_out._type.referencedType.typeMin >= 0 + + 
operatorRepresentation['input_bits'] = data_in._type.referencedType.typeWidth + operatorRepresentation['output_bits'] = data_out._type.referencedType.typeWidth + operatorRepresentation['weight_bits'] = weight._type.referencedType.typeWidth + + operatorRepresentation["input_typeWidth_bytes"] = int(np.ceil(data_in._type.referencedType.typeWidth / 8)) + operatorRepresentation["output_typeWidth_bytes"] = int(np.ceil(data_out._type.referencedType.typeWidth / 8)) + + operatorRepresentation["weight_addr_offset"] = 0 + + operatorRepresentation["use_wmem"] = hasattr(weight, + "_memoryLevel") and weight._memoryLevel == "WeightMemory_SRAM" + + dim_im_in_x_stride, dim_im_in_y_stride = ioStridesFromDimensions(operatorRepresentation["dim_im_in_y"], + operatorRepresentation["ch_im_in"], + operatorRepresentation["input_bits"]) + operatorRepresentation["dim_im_in_y_stride"] = dim_im_in_y_stride + operatorRepresentation["dim_im_in_x_stride"] = dim_im_in_x_stride + + dim_im_out_x_stride, dim_im_out_y_stride = ioStridesFromDimensions(operatorRepresentation["dim_im_out_y"], + operatorRepresentation["ch_im_out"], + operatorRepresentation["output_bits"]) + operatorRepresentation["dim_im_out_y_stride"] = dim_im_out_y_stride + operatorRepresentation["dim_im_out_x_stride"] = dim_im_out_x_stride + + operatorRepresentation["input_addr_offset"] = getInputAddrOffset(operatorRepresentation["dim_im_in_y"], + operatorRepresentation["dim_im_in_y_stride"], + operatorRepresentation["padding_y_top"], + operatorRepresentation["padding_x_left"]) + + nKo, nKi, nHo, nWo, bKo, bKi, bHo, bWo, bHi, bWi = self.getCounters( + operatorRepresentation["ch_im_in"], operatorRepresentation["dim_im_out_x"], + operatorRepresentation["dim_im_out_y"], operatorRepresentation["ch_im_out"], + operatorRepresentation["padding_y_bottom"], operatorRepresentation["padding_x_right"], + operatorRepresentation) + + operatorRepresentation["nKo"] = nKo + operatorRepresentation["nKi"] = nKi + operatorRepresentation["nHo"] = nHo + 
operatorRepresentation["nWo"] = nWo + operatorRepresentation["bKo"] = bKo + operatorRepresentation["bKi"] = bKi + operatorRepresentation["bHo"] = bHo + operatorRepresentation["bWo"] = bWo + operatorRepresentation["bHi"] = bHi + operatorRepresentation["bWi"] = bWi + + weightStrideD0, weightStrideD1, weightStrideD2 = self.getWeightStrides(operatorRepresentation["ch_im_in"]) + + operatorRepresentation["weightStrideD0"] = weightStrideD0 + operatorRepresentation["weightStrideD1"] = weightStrideD1 + operatorRepresentation["weightStrideD2"] = weightStrideD2 + + operatorRepresentation["conf0"] = self.getConf0(operatorRepresentation["output_bits"], + operatorRepresentation["weight_bits"], + operatorRepresentation["input_signed"], + operatorRepresentation["use_wmem"]) + + operatorRepresentation["wmem_addr_offset"] = 0x10400000 if operatorRepresentation["use_wmem"] else 0 + + operatorRepresentation["ne16_kernel_shape"] = self.NE16_KERNEL_SHAPE + operatorRepresentation["ne16_depthwise"] = self.NE16_IS_DEPTHWISE + operatorRepresentation["ne16_subtile_output_channel"] = self.NE16_SUBTILE_OUTPUT_CHANNEL + + # If requantized + if operatorRepresentation["mul"] != "NULL": + mulBuff = ctxt.lookup(operatorRepresentation["mul"]) + mulBits = mulBuff._type.referencedType.typeWidth + operatorRepresentation["conf0"] |= getNormQuantConf0(operatorRepresentation["use_relu"], + operatorRepresentation["log2D"], mulBits, "add" + in operatorRepresentation, False) + return ctxt, operatorRepresentation, [] + + +class NE162DPWConvTemplate(NE16ConvTemplate): + + NE16_KERNEL_SHAPE = 1 + NE16_IS_DEPTHWISE = 0 + NE16_SUBTILE_OUTPUT_CHANNEL = 32 + + def __init__(self, templateStr: str): + super().__init__(templateStr) + + @classmethod + def getCounters( + cls, channel_in: int, height_out: int, width_out: int, channel_out: int, padding_bottom: int, + padding_right: int, + operatorRepresentation: OperatorRepresentation) -> Tuple[int, int, int, int, int, int, int, int, int, int]: + # NE16 subtiles: 
INPUT_CHANNEL=16, OUTPUT_HxW=3x3, OUTPUT_CHANNEL=32 + n_channel_out_subtiles = _getNumTiles(channel_out, 32) + n_channel_in_subtiles = _getNumTiles(channel_in, 16) + n_height_out_subtiles = _getNumTiles(height_out, 3) + n_width_out_subtiles = _getNumTiles(width_out, 3) + + channel_out_border = _getBorderTileSize(channel_out, 32) + channel_in_border = _getBorderTileSize(channel_in, 16) + height_out_border = _getBorderTileSize(height_out, 3) + width_out_border = _getBorderTileSize(width_out, 3) + height_in_border = height_out_border - padding_bottom + width_in_border = width_out_border - padding_right + + return (n_channel_out_subtiles, n_channel_in_subtiles, n_height_out_subtiles, n_width_out_subtiles, + channel_out_border, channel_in_border, height_out_border, width_out_border, height_in_border, + width_in_border) + + @classmethod + def getWeightStrides(cls, channel_in: int) -> Tuple[int, int, int]: + # NE16 PW 1x1: per (cout, cinMajor) block = bits * H*W * cinMinorBytes + # = 8 * 1 * 2 = 16 bytes for 8-bit weights with CIN_SUBTILE=16 + n_channel_in = _getNumTiles(channel_in, 16) + _NE16_PW_WEIGHT_BYTES = 16 # bits * HW * cinMinorBytes = 8*1*2 + return _NE16_PW_WEIGHT_BYTES, _NE16_PW_WEIGHT_BYTES * n_channel_in, 0 + + @classmethod + def getConf0(cls, output_bits: int, weight_bits: int, input_signed: bool, use_wmem: bool) -> int: + conf0 = 0 + conf0 |= weight_bits - 1 + conf0 |= 2 << 5 # PW MODE + if use_wmem: + conf0 |= 1 << 9 + conf0 |= 1 << 15 # Layerwise weight offset mode + if output_bits == 32: + conf0 |= 2 << 21 + if input_signed: + conf0 |= 1 << 26 + return conf0 + + +class NE162DDWConvTemplate(NE16ConvTemplate): + + NE16_KERNEL_SHAPE = 3 + NE16_IS_DEPTHWISE = 1 + # For DW, hardware replicates input channels as output channels, so the + # output-channel subtile size equals the input-channel subtile (16). 
+ NE16_SUBTILE_OUTPUT_CHANNEL = 16 + + def __init__(self, templateStr: str): + super().__init__(templateStr) + + @classmethod + def getCounters( + cls, channel_in: int, height_out: int, width_out: int, channel_out: int, padding_bottom: int, + padding_right: int, + operatorRepresentation: OperatorRepresentation) -> Tuple[int, int, int, int, int, int, int, int, int, int]: + _ = operatorRepresentation # operatorRepresentation not accessed for now because it's just for pointwise kernels + + # NE16 DW 3x3: CIN_SUBTILE=16 single mode, output 3x3 + n_channel_out_subtiles = _getNumTiles(channel_out, 16) + n_channel_in_subtiles = n_channel_out_subtiles + n_height_out_subtiles = _getNumTiles(height_out, 3) + n_width_out_subtiles = _getNumTiles(width_out, 3) + + channel_out_border = _getBorderTileSize(channel_out, 16) + channel_in_border = channel_out_border + height_out_border = _getBorderTileSize(height_out, 3) + width_out_border = _getBorderTileSize(width_out, 3) + height_in_border = height_out_border + 2 - padding_bottom + width_in_border = width_out_border + 2 - padding_right + + return (n_channel_out_subtiles, n_channel_in_subtiles, n_height_out_subtiles, n_width_out_subtiles, + channel_out_border, channel_in_border, height_out_border, width_out_border, height_in_border, + width_in_border) + + @classmethod + def getWeightStrides(cls, channel_in: int) -> Tuple[int, int, int]: + # Match ne16_task_set_strides for depthwise 3x3: + # d0 = NE16_FILTER_SIZE * NE16_FILTER_SIZE * weight_d0_stride + # = 3 * 3 * 2 = 18 + # d1 = 0 (DW has no cin-major striding from the HW's perspective). 
+ _NE16_FILTER_SIZE = 3 + _NE16_WEIGHT_D0_STRIDE_MODE8 = 2 + d0 = _NE16_FILTER_SIZE * _NE16_FILTER_SIZE * _NE16_WEIGHT_D0_STRIDE_MODE8 + return d0, 0, 0 + + @classmethod + def getConf0(cls, output_bits: int, weight_bits: int, input_signed: bool, use_wmem: bool) -> int: + conf0 = 0 + conf0 |= weight_bits - 1 + conf0 |= 1 << 5 # DW MODE + if use_wmem: + conf0 |= 1 << 9 + conf0 |= 1 << 15 # Layerwise weight offset mode + if output_bits == 32: + conf0 |= 2 << 21 + if input_signed: + conf0 |= 1 << 26 + return conf0 + + +class NE162DDenseConvTemplate(NE16ConvTemplate): + + NE16_KERNEL_SHAPE = 3 + NE16_IS_DEPTHWISE = 0 + NE16_SUBTILE_OUTPUT_CHANNEL = 32 + + def __init__(self, templateStr: str): + super().__init__(templateStr) + + @classmethod + def getCounters( + cls, channel_in: int, height_out: int, width_out: int, channel_out: int, padding_bottom: int, + padding_right: int, + operatorRepresentation: OperatorRepresentation) -> Tuple[int, int, int, int, int, int, int, int, int, int]: + _ = operatorRepresentation # operatorRepresentation not accessed for now because it's just for pointwise kernels + + # NE16 Dense 3x3: CIN_SUBTILE=16, OUTPUT 3x3x32 + n_channel_out_subtiles = _getNumTiles(channel_out, 32) + n_channel_in_subtiles = _getNumTiles(channel_in, 16) + n_height_out_subtiles = _getNumTiles(height_out, 3) + n_width_out_subtiles = _getNumTiles(width_out, 3) + + channel_out_border = _getBorderTileSize(channel_out, 32) + channel_in_border = _getBorderTileSize(channel_in, 16) + height_out_border = _getBorderTileSize(height_out, 3) + width_out_border = _getBorderTileSize(width_out, 3) + height_in_border = height_out_border + 2 - padding_bottom + width_in_border = width_out_border + 2 - padding_right + + return (n_channel_out_subtiles, n_channel_in_subtiles, n_height_out_subtiles, n_width_out_subtiles, + channel_out_border, channel_in_border, height_out_border, width_out_border, height_in_border, + width_in_border) + + @classmethod + def getWeightStrides(cls, channel_in: 
int) -> Tuple[int, int, int]: + # Match ne16_task_set_strides for dense 3x3 (non-DW): + # d0 = NE16_FILTER_SIZE * NE16_FILTER_SIZE * weight_d0_stride = 18 + # d1 = NE16_FILTER_SIZE * NE16_FILTER_SIZE * weight_d0_stride * qw * num_k_in + # = 18 * 8 * num_k_in + _NE16_FILTER_SIZE = 3 + _NE16_WEIGHT_D0_STRIDE_MODE8 = 2 + _QW = 8 + n_channel_in = _getNumTiles(channel_in, 16) + d0 = _NE16_FILTER_SIZE * _NE16_FILTER_SIZE * _NE16_WEIGHT_D0_STRIDE_MODE8 + d1 = d0 * _QW * n_channel_in + return d0, d1, 0 + + @classmethod + def getConf0(cls, output_bits: int, weight_bits: int, input_signed: bool, use_wmem: bool) -> int: + conf0 = 0 + conf0 |= weight_bits - 1 + if use_wmem: + conf0 |= 1 << 9 + conf0 |= 1 << 15 # Layerwise weight offset mode + if output_bits == 32: + conf0 |= 2 << 21 + if input_signed: + conf0 |= 1 << 26 + return conf0 + + +NE16TaskInitTemplateStr = """ +// N-EUREKA Task Init +ne16_task_t task = { + .data = (ne16_task_data_t) { + .weights_addr = (uint32_t)${weight} - ${wmem_addr_offset} + ${weight_addr_offset}, + .infeat_addr = (uint32_t)${data_in} - ${input_addr_offset}, + .outfeat_addr = (uint32_t)${data_out}, + .scale_addr = (uint32_t)${mul}, + .scale_shift_addr = (uint32_t)${shift}, + .scale_bias_addr = (uint32_t)${add}, + .cfg = (ne16_cfg_t) { + .input_stride = (ne16_stride_t) { + .d0 = ${dim_im_in_y_stride}, + .d1 = ${dim_im_in_x_stride}, + .d2 = 0 + }, + .output_stride = (ne16_stride_t) { + .d0 = NE16_OUTPUT_BANDWIDTH_BYTES, + .d1 = ${dim_im_out_y_stride}, + .d2 = ${dim_im_out_x_stride} + }, + .weights_stride = (ne16_stride_t) { + .d0 = ${weightStrideD0}, + .d1 = ${weightStrideD1}, + .d2 = ${weightStrideD2} + }, + .subtile = (ne16_subtile_t) { + .number = { + .KoKi = nnx_concat_half(${nKo}, ${nKi}), + .HoWo = nnx_concat_half(${nHo}, ${nWo}) + }, + .remainder = { + .KoKi = nnx_concat_half(${bKo}, ${bKi}), + .HoWo = nnx_concat_half(${bHo}, ${bWo}), + .HiWi = nnx_concat_half(${bHi}, ${bWi}) + } + }, + .padding = (${padding_y_top} << 28) + 
(${padding_x_right} << 24) + (${padding_y_bottom} << 20) + (${padding_x_left} << 16), + .weight_offset_factor = ${weight_offset}, + .filter_mask = 0, + .conf0 = ${conf0}, + } + } +}; +// NE16 top-level task struct fields (required by HAL helpers and NE16 HW for +// non-1x1 paths). Kept consistent with ne16_task_set_op_to_conv/_set_bits. +task.weight_d0_stride = NE16_WEIGHT_D0_STRIDE_MODE8; +task.qw = ${weight_bits}; +task.subtile_output_channel = ${ne16_subtile_output_channel}; +task.kernel_shape = ${ne16_kernel_shape}; +task.depthwise = ${ne16_depthwise}; +""" + +NE16TaskExecutionTemplateStr = """ +// N-EUREKA Task Execution +ne16_nnx_dispatch_wait(ne16_pulp_get_dev()); +ne16_nnx_dispatch(ne16_pulp_get_dev(), &task); +ne16_nnx_resolve_wait(ne16_pulp_get_dev(), &task); +""" + +NE16RqntPWConv2D_Template = NE162DPWConvTemplate(NE16TaskInitTemplateStr + NE16TaskExecutionTemplateStr) +NE16PWConv2D_Template = NE162DPWConvTemplate(NE16TaskInitTemplateStr + NE16TaskExecutionTemplateStr) + +NE16RqntDWConv2D_Template = NE162DDWConvTemplate(NE16TaskInitTemplateStr + NE16TaskExecutionTemplateStr) +NE16DWConv2D_Template = NE162DDWConvTemplate(NE16TaskInitTemplateStr + NE16TaskExecutionTemplateStr) + +NE16RqntDenseConv2D_Template = NE162DDenseConvTemplate(NE16TaskInitTemplateStr + NE16TaskExecutionTemplateStr) +NE16DenseConv2D_Template = NE162DDenseConvTemplate(NE16TaskInitTemplateStr + NE16TaskExecutionTemplateStr) diff --git a/Deeploy/Targets/NE16/Templates/__init__.py b/Deeploy/Targets/NE16/Templates/__init__.py new file mode 100644 index 0000000000..be436b64a3 --- /dev/null +++ b/Deeploy/Targets/NE16/Templates/__init__.py @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from . 
import * diff --git a/Deeploy/Targets/NE16/TileConstraints/NE16DenseConstraint.py b/Deeploy/Targets/NE16/TileConstraints/NE16DenseConstraint.py new file mode 100644 index 0000000000..4288e0f1de --- /dev/null +++ b/Deeploy/Targets/NE16/TileConstraints/NE16DenseConstraint.py @@ -0,0 +1,268 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Tuple + +from Deeploy.AbstractDataTypes import PointerClass +from Deeploy.CommonExtensions.DataTypes import uint8_t, uint16_t, uint32_t +from Deeploy.DeeployTypes import NetworkContext, OperatorRepresentation, VariableBuffer +from Deeploy.Targets.NE16.Templates.ConvTemplate import NE162DDenseConvTemplate, getInputAddrOffset, \ + ioStridesFromDimensions +from Deeploy.Targets.NE16.TileConstraints.RequantHelpers import requantAddGeometricalConstraint, requantLoadSchedule +from Deeploy.Targets.PULPOpen.TileConstraints.ConvTileConstraint import Conv2DTileConstraint +from Deeploy.TilingExtension.MemoryConstraints import NodeMemoryConstraint +from Deeploy.TilingExtension.TileConstraint import TileConstraint +from Deeploy.TilingExtension.TilerModel import PerformanceHint, TilerModel +from Deeploy.TilingExtension.TilingCodegen import AbsoluteHyperRectangle, HyperRectangle, TilingSchedule, \ + VariableReplacementScheme, calculateFlatOffsetInBytes + + +class NE16DenseConv2DTileConstraint(TileConstraint): + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + inputBufferName = parseDict['data_in'] + weightBufferName = parseDict['weight'] + outputBufferName = parseDict['data_out'] + + strides = parseDict["strides"] + padding = parseDict["pads"] + dilation = parseDict["dilations"] + + for bufferName in [inputBufferName, weightBufferName, outputBufferName]: + tilerModel.addTensorDimToModel(ctxt, bufferName) + + inputBatchVar = tilerModel.getTensorDimVar(tensorName = 
inputBufferName, dimIdx = 0) + inputHeightVar = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = 1) + inputWidthVar = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = 2) + inputChannelVar = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = 3) + + weightOutChannelVar = tilerModel.getTensorDimVar(tensorName = weightBufferName, dimIdx = 0) + weightInChannelMajorVar = tilerModel.getTensorDimVar(tensorName = weightBufferName, dimIdx = 1) + weightBitsVar = tilerModel.getTensorDimVar(tensorName = weightBufferName, dimIdx = 2) + weightBandwidthVar = tilerModel.getTensorDimVar(tensorName = weightBufferName, dimIdx = 3) + + outputBatchVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 0) + outputHeightVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 1) + outputWidthVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 2) + outputChannelVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 3) + + # Map output dims to inputs dims + tilerModel.addConstraint(outputBatchVar == inputBatchVar) + + weightBuffer = ctxt.lookup(weightBufferName) + if hasattr(weightBuffer, "_memoryLevel") and weightBuffer._memoryLevel == "WeightMemory_SRAM": + tilerModel.addConstraint(weightOutChannelVar == weightOutChannelVar.Max()) + else: + tilerModel.addConstraint(weightOutChannelVar == outputChannelVar) + + inputBuffer = ctxt.lookup(inputBufferName) + + effectiveHeight = inputHeightVar + ((padding[0] + padding[2]) * (inputHeightVar == inputBuffer.shape[1])) + effectiveWidth = inputWidthVar + ((padding[1] + padding[3]) * (inputWidthVar == inputBuffer.shape[2])) + + tilerModel.addConstraint((outputHeightVar == (effectiveHeight - (3 - 1) - 1) // strides[0] + 1)) + tilerModel.addConstraint((outputWidthVar == (effectiveWidth - (3 - 1) - 1) // strides[1] + 1)) + + return tilerModel + + @staticmethod + def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: 
NetworkContext) -> TilerModel: + inputHeightVar = tilerModel.getTensorDimVar(tensorName = parseDict['data_in'], dimIdx = 1) + inputWidthVar = tilerModel.getTensorDimVar(tensorName = parseDict['data_in'], dimIdx = 2) + inputChannelVar = tilerModel.getTensorDimVar(tensorName = parseDict['data_in'], dimIdx = 3) + + strides = parseDict["strides"] + + tilerModel.addConstraint((inputHeightVar % strides[0]) == 0) + tilerModel.addConstraint((inputWidthVar % strides[1]) == 0) + + tilerModel.addConstraint(inputChannelVar == inputChannelVar.Max()) + + tilerModel.addConstraint(inputHeightVar == inputHeightVar.Max(), strategy = PerformanceHint(1)) + tilerModel.addConstraint(inputWidthVar == inputWidthVar.Max(), strategy = PerformanceHint(1)) + + tilerModel.addConstraint(inputHeightVar >= parseDict['dim_kernel_x']) + tilerModel.addConstraint(inputWidthVar >= parseDict['dim_kernel_y']) + + return tilerModel + + @classmethod + def serializeTilingSolution( + cls, tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle], + targetMemLevel: str, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[VariableReplacementScheme, TilingSchedule]: + outputCubes = [cube.rectangle for cube in absoluteOutputCubes] + + addrNames = ['data_in', 'data_out'] + inputBaseOffsets, outputBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel, + operatorRepresentation, addrNames) + + varWeight = operatorRepresentation['weight'] + varOut = operatorRepresentation['data_out'] + + inputInCubes = [] + replacements: Dict[str, List[int]] = { + "padding_y_top": [], + "padding_y_bottom": [], + "padding_x_left": [], + "padding_x_right": [], + "dim_im_in_x_stride": [], + "dim_im_in_y_stride": [], + "dim_im_out_x_stride": [], + "dim_im_out_y_stride": [], + "input_addr_offset": [], + "nKo": [], + "nKi": [], + "nHo": [], + "nWo": [], + "bKo": [], + "bKi": [], + "bHo": [], + "bWo": [], + "bHi": [], + "bWi": [], + } + + replacementTypes = { + 
"padding_y_top": PointerClass(uint8_t), + "padding_y_bottom": PointerClass(uint8_t), + "padding_x_left": PointerClass(uint8_t), + "padding_x_right": PointerClass(uint8_t), + "dim_im_in_x_stride": PointerClass(uint32_t), + "dim_im_in_y_stride": PointerClass(uint32_t), + "dim_im_out_x_stride": PointerClass(uint32_t), + "dim_im_out_y_stride": PointerClass(uint32_t), + "input_addr_offset": PointerClass(uint32_t), + "nKo": PointerClass(uint16_t), + "nKi": PointerClass(uint16_t), + "nHo": PointerClass(uint16_t), + "nWo": PointerClass(uint16_t), + "bKo": PointerClass(uint16_t), + "bKi": PointerClass(uint16_t), + "bHo": PointerClass(uint16_t), + "bWo": PointerClass(uint16_t), + "bHi": PointerClass(uint16_t), + "bWi": PointerClass(uint16_t), + } + + weightH = operatorRepresentation['dim_kernel_y'] + weightW = operatorRepresentation['dim_kernel_x'] + weightC = operatorRepresentation['ch_im_in'] + + pads = operatorRepresentation['pads'] + strides = operatorRepresentation['strides'] + + outputBuffer = ctxt.lookup(varOut) + assert isinstance(outputBuffer, VariableBuffer) + + for cube in outputCubes: + (BatchOffset, HOffset, WOffset, COffset) = cube.offset + (BatchSize, HSize, WSize, CSize) = cube.dims + + InCube, padding_tuple = Conv2DTileConstraint.computeInputCube((weightH, weightW), pads, strides, weightC, + cube, outputBuffer.shape) + padding_left, padding_right, padding_top, padding_bottom = padding_tuple + + replacements['padding_y_top'].append(padding_top) + replacements['padding_y_bottom'].append(padding_bottom) + replacements['padding_x_left'].append(padding_left) + replacements['padding_x_right'].append(padding_right) + + inBSize, inHSize, inWSize, inCSize = InCube.dims + + dim_im_in_x_stride, dim_im_in_y_stride = ioStridesFromDimensions(inWSize, inCSize, + operatorRepresentation["input_bits"]) + replacements['dim_im_in_x_stride'].append(dim_im_in_x_stride) + replacements['dim_im_in_y_stride'].append(dim_im_in_y_stride) + dim_im_out_x_stride, dim_im_out_y_stride = 
ioStridesFromDimensions(WSize, CSize, + operatorRepresentation["output_bits"]) + replacements['dim_im_out_x_stride'].append(dim_im_out_x_stride) + replacements['dim_im_out_y_stride'].append(dim_im_out_y_stride) + + replacements['input_addr_offset'].append( + getInputAddrOffset(inWSize, dim_im_in_y_stride, padding_top, padding_left)) + + nKo, nKi, nHo, nWo, bKo, bKi, bHo, bWo, bHi, bWi = NE162DDenseConvTemplate.getCounters( + inCSize, HSize, WSize, CSize, padding_bottom, padding_right, operatorRepresentation) + + replacements["nKo"].append(nKo) + replacements["nKi"].append(nKi) + replacements["nHo"].append(nHo) + replacements["nWo"].append(nWo) + replacements["bKo"].append(bKo) + replacements["bKi"].append(bKi) + replacements["bHo"].append(bHo) + replacements["bWo"].append(bWo) + replacements["bHi"].append(bHi) + replacements["bWi"].append(bWi) + + inputInCubes.append(InCube) + + inputLoadSchedule = [] + outputLoadSchedule = [] + + for a in inputInCubes: + inputLoadSchedule.append({"data_in": a}) + + for out in outputCubes: + outputLoadSchedule.append({"data_out": out}) + + weightBuffer = ctxt.lookup(varWeight) + assert isinstance(weightBuffer, VariableBuffer) + weightShape = weightBuffer.shape + + if hasattr(weightBuffer, "_memoryLevel") and weightBuffer._memoryLevel == "WeightMemory_SRAM": + replacements['weight_addr_offset'] = [] + replacementTypes['weight_addr_offset'] = PointerClass(uint32_t) + for absoluteCube in absoluteOutputCubes: + COffset, CSize = absoluteCube.absoluteOffset[-1], absoluteCube.rectangle.dims[-1] + WeightCube = HyperRectangle((COffset, 0, 0), (CSize, weightShape[-2], weightShape[-1])) + replacements['weight_addr_offset'].append(calculateFlatOffsetInBytes(WeightCube, weightBuffer)) + else: + inputWeightBaseOffsets, outputWeightBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel, + operatorRepresentation, ['weight']) + inputBaseOffsets.update(inputWeightBaseOffsets) + outputBaseOffsets.update(outputWeightBaseOffsets) + + for 
cube, load in zip(outputCubes, inputLoadSchedule): + COffset, CSize = cube.offset[-1], cube.dims[-1] + load['weight'] = HyperRectangle((COffset, 0, 0), (CSize, weightShape[-2], weightShape[-1])) + + tilingSchedule = TilingSchedule(inputBaseOffsets, outputBaseOffsets, inputLoadSchedule, outputLoadSchedule) + variableReplacementSchedule = VariableReplacementScheme(replacements, replacementTypes) + + return variableReplacementSchedule, tilingSchedule + + +class NE16RQSDenseConv2DTileConstraint(NE16DenseConv2DTileConstraint): + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + tilerModel = NE16DenseConv2DTileConstraint.addGeometricalConstraint(tilerModel, parseDict, ctxt) + return requantAddGeometricalConstraint(tilerModel, parseDict, ctxt) + + @classmethod + def serializeTilingSolution( + cls, tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle], + targetMemLevel: str, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[VariableReplacementScheme, TilingSchedule]: + variableReplacementSchedule, tilingSchedule = super().serializeTilingSolution( + tilingSolution, absoluteOutputCubes, targetMemLevel, ctxt, operatorRepresentation) + + addrNames = ['mul', 'add'] + inputRequantBaseOffsets, _ = cls.extractBaseAddr(tilingSolution, targetMemLevel, operatorRepresentation, + addrNames) + newInputBaseOffsets = {**tilingSchedule.inputBaseOffsets, **inputRequantBaseOffsets} + + requantSchedule = requantLoadSchedule(absoluteOutputCubes, ctxt, operatorRepresentation) + newInputLoadSchedule = [{ + **load, + **rqLoad + } for load, rqLoad in zip(tilingSchedule.inputLoadSchedule, requantSchedule)] + + newTilingSchedule = TilingSchedule(newInputBaseOffsets, tilingSchedule.outputBaseOffsets, newInputLoadSchedule, + tilingSchedule.outputLoadSchedule) + + return variableReplacementSchedule, newTilingSchedule diff --git 
a/Deeploy/Targets/NE16/TileConstraints/NE16DepthwiseConstraint.py b/Deeploy/Targets/NE16/TileConstraints/NE16DepthwiseConstraint.py new file mode 100644 index 0000000000..b9221b74f3 --- /dev/null +++ b/Deeploy/Targets/NE16/TileConstraints/NE16DepthwiseConstraint.py @@ -0,0 +1,265 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Tuple + +from Deeploy.AbstractDataTypes import PointerClass +from Deeploy.CommonExtensions.DataTypes import uint8_t, uint16_t, uint32_t +from Deeploy.DeeployTypes import NetworkContext, OperatorRepresentation, VariableBuffer +from Deeploy.Targets.NE16.Templates.ConvTemplate import NE162DDWConvTemplate, getInputAddrOffset, \ + ioStridesFromDimensions +from Deeploy.Targets.NE16.TileConstraints.RequantHelpers import requantAddGeometricalConstraint, requantLoadSchedule +from Deeploy.Targets.PULPOpen.TileConstraints.ConvTileConstraint import Conv2DTileConstraint +from Deeploy.TilingExtension.MemoryConstraints import NodeMemoryConstraint +from Deeploy.TilingExtension.TileConstraint import TileConstraint +from Deeploy.TilingExtension.TilerModel import PerformanceHint, TilerModel +from Deeploy.TilingExtension.TilingCodegen import AbsoluteHyperRectangle, HyperRectangle, TilingSchedule, \ + VariableReplacementScheme + + +class NE16DWConv2DTileConstraint(TileConstraint): + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + inputBufferName = parseDict['data_in'] + weightBufferName = parseDict['weight'] + outputBufferName = parseDict['data_out'] + + strides = parseDict["strides"] + padding = parseDict["pads"] + dilation = parseDict["dilations"] + + for bufferName in [inputBufferName, weightBufferName, outputBufferName]: + tilerModel.addTensorDimToModel(ctxt, bufferName) + + inputBatchVar = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = 0) + inputHeightVar = 
tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = 1) + inputWidthVar = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = 2) + inputChannelVar = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = 3) + + weightOutChannelVar = tilerModel.getTensorDimVar(tensorName = weightBufferName, dimIdx = 0) + + outputBatchVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 0) + outputHeightVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 1) + outputWidthVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 2) + outputChannelVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 3) + + # Map output dims to inputs dims + tilerModel.addConstraint(outputBatchVar == inputBatchVar) + tilerModel.addConstraint(outputChannelVar == inputChannelVar) + + weightBuffer = ctxt.lookup(weightBufferName) + # NE16 DW weight is packed as a single (1, 1, packed_bytes) block + # containing all output channels (up to NE16_SUBTILE_INPUT_CHANNEL=16). + # Keep the outermost dim fixed at its full (=1) value regardless of + # the output channel tiling. 
    @staticmethod
    def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
        """Add solver-guidance constraints for the NE16 depthwise convolution.

        Keeps input tiles aligned to the convolution stride and nudges the
        solver (via PerformanceHint) towards leaving the spatial dimensions
        untiled.

        Parameters
        ----------
        tilerModel : TilerModel
            Constraint model being populated.
        parseDict : Dict
            Operator representation; 'data_in' and 'strides' are read here.
        ctxt : NetworkContext
            Unused in this method, kept for the TileConstraint interface.

        Returns
        -------
        TilerModel
            The same model with the policy constraints added.
        """
        inputHeightVar = tilerModel.getTensorDimVar(tensorName = parseDict['data_in'], dimIdx = 1)
        inputWidthVar = tilerModel.getTensorDimVar(tensorName = parseDict['data_in'], dimIdx = 2)

        strides = parseDict["strides"]

        # Tile edges must fall on stride boundaries so that no output pixel
        # depends on two different input tiles.
        tilerModel.addConstraint((inputHeightVar % strides[0]) == 0)
        tilerModel.addConstraint((inputWidthVar % strides[1]) == 0)

        # Soft preference: keep the full spatial extent in one tile when the
        # memory budget allows it.
        tilerModel.addConstraint(inputHeightVar == inputHeightVar.Max(), strategy = PerformanceHint(1))
        tilerModel.addConstraint(inputWidthVar == inputWidthVar.Max(), strategy = PerformanceHint(1))

        return tilerModel

    @classmethod
    def serializeTilingSolution(
            cls, tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle],
            targetMemLevel: str, ctxt: NetworkContext,
            operatorRepresentation: OperatorRepresentation) -> Tuple[VariableReplacementScheme, TilingSchedule]:
        """Translate the solved depthwise-conv tiling into per-tile template
        replacements and DMA load schedules.

        For every output tile this computes the matching input tile (via
        ``Conv2DTileConstraint.computeInputCube``), the per-tile padding, the
        NHWC byte strides, and the NE16 subtile counters produced by
        ``NE162DDWConvTemplate.getCounters``.

        The depthwise weight is a single packed block (see the geometrical
        constraint): it is either left in ``WeightMemory_SRAM`` with a zero
        per-tile offset, or DMA-loaded whole for every tile.
        """
        outputCubes = [cube.rectangle for cube in absoluteOutputCubes]

        addrNames = ['data_in', 'data_out']
        inputBaseOffsets, outputBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel,
                                                                  operatorRepresentation, addrNames)

        varWeight = operatorRepresentation['weight']
        varOut = operatorRepresentation['data_out']

        inputInCubes = []
        # One entry per output tile for every template placeholder.
        replacements: Dict[str, List[int]] = {
            "padding_y_top": [],
            "padding_y_bottom": [],
            "padding_x_left": [],
            "padding_x_right": [],
            "dim_im_in_x_stride": [],
            "dim_im_in_y_stride": [],
            "dim_im_out_x_stride": [],
            "dim_im_out_y_stride": [],
            "input_addr_offset": [],
            "nKo": [],
            "nKi": [],
            "nHo": [],
            "nWo": [],
            "bKo": [],
            "bKi": [],
            "bHo": [],
            "bWo": [],
            "bHi": [],
            "bWi": [],
        }

        replacementTypes = {
            "padding_y_top": PointerClass(uint8_t),
            "padding_y_bottom": PointerClass(uint8_t),
            "padding_x_left": PointerClass(uint8_t),
            "padding_x_right": PointerClass(uint8_t),
            "dim_im_in_x_stride": PointerClass(uint32_t),
            "dim_im_in_y_stride": PointerClass(uint32_t),
            "dim_im_out_x_stride": PointerClass(uint32_t),
            "dim_im_out_y_stride": PointerClass(uint32_t),
            "input_addr_offset": PointerClass(uint32_t),
            "nKo": PointerClass(uint16_t),
            "nKi": PointerClass(uint16_t),
            "nHo": PointerClass(uint16_t),
            "nWo": PointerClass(uint16_t),
            "bKo": PointerClass(uint16_t),
            "bKi": PointerClass(uint16_t),
            "bHo": PointerClass(uint16_t),
            "bWo": PointerClass(uint16_t),
            "bHi": PointerClass(uint16_t),
            "bWi": PointerClass(uint16_t),
        }

        weightH = operatorRepresentation['dim_kernel_y']
        weightW = operatorRepresentation['dim_kernel_x']
        weightC = operatorRepresentation['ch_im_in']

        pads = operatorRepresentation['pads']
        strides = operatorRepresentation['strides']

        outputBuffer = ctxt.lookup(varOut)
        assert isinstance(outputBuffer, VariableBuffer)

        for cube in outputCubes:
            # Output cubes are NHWC.
            (BatchOffset, HOffset, WOffset, COffset) = cube.offset
            (BatchSize, HSize, WSize, CSize) = cube.dims

            # Derive the input tile and the padding that applies to this tile
            # (border tiles keep the operator padding, interior tiles get 0).
            InCube, padding_tuple = Conv2DTileConstraint.computeInputCube((weightH, weightW), pads, strides, weightC,
                                                                          cube,
                                                                          ctxt.lookup(varOut).shape)
            padding_left, padding_right, padding_top, padding_bottom = padding_tuple

            replacements['padding_y_top'].append(padding_top)
            replacements['padding_y_bottom'].append(padding_bottom)
            replacements['padding_x_left'].append(padding_left)
            replacements['padding_x_right'].append(padding_right)

            inBSize, inHSize, inWSize, inCSize = InCube.dims

            # NHWC byte strides for this tile's input and output.
            dim_im_in_x_stride, dim_im_in_y_stride = ioStridesFromDimensions(inWSize, inCSize,
                                                                             operatorRepresentation["input_bits"])
            replacements['dim_im_in_x_stride'].append(dim_im_in_x_stride)
            replacements['dim_im_in_y_stride'].append(dim_im_in_y_stride)
            dim_im_out_x_stride, dim_im_out_y_stride = ioStridesFromDimensions(WSize, CSize,
                                                                               operatorRepresentation["output_bits"])
            replacements['dim_im_out_x_stride'].append(dim_im_out_x_stride)
            replacements['dim_im_out_y_stride'].append(dim_im_out_y_stride)

            # Pointer pre-offset that skips the virtual (padded) border.
            replacements['input_addr_offset'].append(
                getInputAddrOffset(inWSize, dim_im_in_y_stride, padding_top, padding_left))

            # NE16 HWPE subtile iteration counters and border sizes.
            nKo, nKi, nHo, nWo, bKo, bKi, bHo, bWo, bHi, bWi = NE162DDWConvTemplate.getCounters(
                inCSize, HSize, WSize, CSize, padding_bottom, padding_right, operatorRepresentation)

            replacements["nKo"].append(nKo)
            replacements["nKi"].append(nKi)
            replacements["nHo"].append(nHo)
            replacements["nWo"].append(nWo)
            replacements["bKo"].append(bKo)
            replacements["bKi"].append(bKi)
            replacements["bHo"].append(bHo)
            replacements["bWo"].append(bWo)
            replacements["bHi"].append(bHi)
            replacements["bWi"].append(bWi)

            inputInCubes.append(InCube)

        inputLoadSchedule = []
        outputLoadSchedule = []

        for a in inputInCubes:
            inputLoadSchedule.append({"data_in": a})

        for out in outputCubes:
            outputLoadSchedule.append({"data_out": out})

        weightBuffer = ctxt.lookup(varWeight)
        assert isinstance(weightBuffer, VariableBuffer)
        weightShape = weightBuffer.shape

        if hasattr(weightBuffer, "_memoryLevel") and weightBuffer._memoryLevel == "WeightMemory_SRAM":
            # Weights stay resident in weight memory: no DMA, zero offset.
            replacements['weight_addr_offset'] = []
            replacementTypes['weight_addr_offset'] = PointerClass(uint32_t)
            for _ in absoluteOutputCubes:
                # DW weight is a single packed block — no per-cout offset.
                replacements['weight_addr_offset'].append(0)
        else:
            inputWeightBaseOffsets, outputWeightBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel,
                                                                                  operatorRepresentation, ['weight'])
            inputBaseOffsets.update(inputWeightBaseOffsets)
            outputBaseOffsets.update(outputWeightBaseOffsets)

            # DW weight is a single packed (1, 1, packed_bytes) block used
            # across all output-channel tiles — same cube every iteration.
            for _cube, load in zip(outputCubes, inputLoadSchedule):
                load['weight'] = HyperRectangle((0, 0, 0), (weightShape[0], weightShape[1], weightShape[2]))

        tilingSchedule = TilingSchedule(inputBaseOffsets, outputBaseOffsets, inputLoadSchedule, outputLoadSchedule)
        variableReplacementSchedule = VariableReplacementScheme(replacements, replacementTypes)

        return variableReplacementSchedule, tilingSchedule
newInputLoadSchedule = [{ + **load, + **rqLoad + } for load, rqLoad in zip(tilingSchedule.inputLoadSchedule, requantSchedule)] + + newTilingSchedule = TilingSchedule(newInputBaseOffsets, tilingSchedule.outputBaseOffsets, newInputLoadSchedule, + tilingSchedule.outputLoadSchedule) + + return variableReplacementSchedule, newTilingSchedule diff --git a/Deeploy/Targets/NE16/TileConstraints/NE16PointwiseConstraint.py b/Deeploy/Targets/NE16/TileConstraints/NE16PointwiseConstraint.py new file mode 100644 index 0000000000..a32137a826 --- /dev/null +++ b/Deeploy/Targets/NE16/TileConstraints/NE16PointwiseConstraint.py @@ -0,0 +1,298 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Tuple + +from Deeploy.AbstractDataTypes import PointerClass +from Deeploy.CommonExtensions.DataTypes import uint8_t, uint16_t, uint32_t +from Deeploy.DeeployTypes import NetworkContext, OperatorRepresentation, VariableBuffer +from Deeploy.Targets.NE16.Templates.ConvTemplate import NE162DPWConvTemplate, getInputAddrOffset, \ + ioStridesFromDimensions +from Deeploy.Targets.NE16.TileConstraints.RequantHelpers import requantAddGeometricalConstraint, requantLoadSchedule +from Deeploy.Targets.PULPOpen.TileConstraints.ConvTileConstraint import Conv2DTileConstraint +from Deeploy.TilingExtension.MemoryConstraints import NodeMemoryConstraint +from Deeploy.TilingExtension.TileConstraint import TileConstraint +from Deeploy.TilingExtension.TilerModel import PerformanceHint, TilerModel +from Deeploy.TilingExtension.TilingCodegen import AbsoluteHyperRectangle, HyperRectangle, TilingSchedule, \ + VariableReplacementScheme, calculateFlatOffsetInBytes + + +class NE16PWConv2DTileConstraint(TileConstraint): + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + inputBufferName = parseDict['data_in'] + weightBufferName = parseDict['weight'] + 
outputBufferName = parseDict['data_out'] + + for bufferName in [inputBufferName, weightBufferName, outputBufferName]: + tilerModel.addTensorDimToModel(ctxt, bufferName) + + inputBatchVar = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = 0) + inputHeightVar = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = 1) + inputWidthVar = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = 2) + + weightOutChannelVar = tilerModel.getTensorDimVar(tensorName = weightBufferName, dimIdx = 0) + + outputBatchVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 0) + outputHeightVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 1) + outputWidthVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 2) + outputChannelVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 3) + + # Map output dims to inputs dims + tilerModel.addConstraint(outputBatchVar == inputBatchVar) + tilerModel.addConstraint(outputHeightVar == inputHeightVar) + tilerModel.addConstraint(outputWidthVar == inputWidthVar) + + weightBuffer = ctxt.lookup(weightBufferName) + if hasattr(weightBuffer, "_memoryLevel") and weightBuffer._memoryLevel == "WeightMemory_SRAM": + tilerModel.addConstraint(weightOutChannelVar == weightOutChannelVar.Max()) + else: + tilerModel.addConstraint(weightOutChannelVar == outputChannelVar) + + tilerModel.addConstraint(inputHeightVar >= 1) + tilerModel.addConstraint(inputWidthVar >= 1) + + return tilerModel + + @staticmethod + def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + + # Get to-be-tiled tensor's buffers + inputBuffer = ctxt.lookup(name = parseDict['data_in']) + weightBuffer = ctxt.lookup(name = parseDict['weight']) + outputBuffer = ctxt.lookup(name = parseDict['data_out']) + + inputHeightVar = tilerModel.getTensorDimVar(tensorName = inputBuffer.name, dimIdx = 1) + inputWidthVar = 
tilerModel.getTensorDimVar(tensorName = inputBuffer.name, dimIdx = 2) + inputChannelVar = tilerModel.getTensorDimVar(tensorName = inputBuffer.name, dimIdx = 3) + + weightOutChannelVar = tilerModel.getTensorDimVar(tensorName = weightBuffer.name, dimIdx = 0) + weightInChannelMajorVar = tilerModel.getTensorDimVar(tensorName = weightBuffer.name, dimIdx = 1) + weightBandwidthVar = tilerModel.getTensorDimVar(tensorName = weightBuffer.name, dimIdx = 2) + + outputHeightVar = tilerModel.getTensorDimVar(tensorName = outputBuffer.name, dimIdx = 1) + outputWidthVar = tilerModel.getTensorDimVar(tensorName = outputBuffer.name, dimIdx = 2) + outputChannelVar = tilerModel.getTensorDimVar(tensorName = outputBuffer.name, dimIdx = 3) + + strides = parseDict["strides"] + padding = parseDict["pads"] + + # LMACAN: Force full input channel to avoid partial results + tilerModel.addConstraint(inputChannelVar == inputChannelVar.Max()) + tilerModel.addConstraint(weightInChannelMajorVar == weightInChannelMajorVar.Max()) + tilerModel.addConstraint(weightBandwidthVar == weightBandwidthVar.Max()) + + tilerModel.addConstraint((inputHeightVar % strides[0]) == 0) + tilerModel.addConstraint((inputWidthVar % strides[1]) == 0) + + # N-EUREKA tile constraints to align with N-EUREKA's hardware subtiling + if parseDict["dim_im_out_x"] > 6: + tilerModel.addTileSizeDivisibleConstraint(parseDict, + "dim_im_out_x", + outputHeightVar, + 6, + strategy = PerformanceHint(priority = 3)) + else: + tilerModel.addConstraint(outputHeightVar == outputHeightVar.Max(), strategy = PerformanceHint(priority = 3)) + + if parseDict["dim_im_out_y"] > 6: + tilerModel.addTileSizeDivisibleConstraint(parseDict, + "dim_im_out_y", + outputWidthVar, + 6, + strategy = PerformanceHint(priority = 2)) + else: + tilerModel.addConstraint(outputWidthVar == outputWidthVar.Max(), strategy = PerformanceHint(priority = 2)) + + if parseDict["ch_im_out"] > 32: + tilerModel.addTileSizeDivisibleConstraint(parseDict, + "ch_im_out", + 
outputChannelVar, + 32, + strategy = PerformanceHint(priority = 1)) + else: + tilerModel.addConstraint(outputChannelVar == outputChannelVar.Max(), + strategy = PerformanceHint(priority = 1)) + + return tilerModel + + @classmethod + def serializeTilingSolution( + cls, tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle], + targetMemLevel: str, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[VariableReplacementScheme, TilingSchedule]: + outputCubes = [cube.rectangle for cube in absoluteOutputCubes] + + addrNames = ['data_in', 'data_out'] + inputBaseOffsets, outputBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel, + operatorRepresentation, addrNames) + + varWeight = operatorRepresentation['weight'] + varOut = operatorRepresentation['data_out'] + + inputInCubes = [] + replacements: Dict[str, List[int]] = { + "padding_y_top": [], + "padding_y_bottom": [], + "padding_x_left": [], + "padding_x_right": [], + "dim_im_in_x_stride": [], + "dim_im_in_y_stride": [], + "dim_im_out_x_stride": [], + "dim_im_out_y_stride": [], + "input_addr_offset": [], + "nKo": [], + "nKi": [], + "nHo": [], + "nWo": [], + "bKo": [], + "bKi": [], + "bHo": [], + "bWo": [], + "bHi": [], + "bWi": [], + } + + replacementTypes = { + "padding_y_top": PointerClass(uint8_t), + "padding_y_bottom": PointerClass(uint8_t), + "padding_x_left": PointerClass(uint8_t), + "padding_x_right": PointerClass(uint8_t), + "dim_im_in_x_stride": PointerClass(uint32_t), + "dim_im_in_y_stride": PointerClass(uint32_t), + "dim_im_out_x_stride": PointerClass(uint32_t), + "dim_im_out_y_stride": PointerClass(uint32_t), + "input_addr_offset": PointerClass(uint32_t), + "nKo": PointerClass(uint16_t), + "nKi": PointerClass(uint16_t), + "nHo": PointerClass(uint16_t), + "nWo": PointerClass(uint16_t), + "bKo": PointerClass(uint16_t), + "bKi": PointerClass(uint16_t), + "bHo": PointerClass(uint16_t), + "bWo": PointerClass(uint16_t), + "bHi": 
PointerClass(uint16_t), + "bWi": PointerClass(uint16_t), + } + + weightH = operatorRepresentation['dim_kernel_y'] + weightW = operatorRepresentation['dim_kernel_x'] + weightC = operatorRepresentation['ch_im_in'] + + pads = operatorRepresentation['pads'] + strides = operatorRepresentation['strides'] + + outputBuffer = ctxt.lookup(varOut) + assert isinstance(outputBuffer, VariableBuffer) + + for cube in outputCubes: + (BatchOffset, HOffset, WOffset, COffset) = cube.offset + (BatchSize, HSize, WSize, CSize) = cube.dims + + InCube, padding_tuple = Conv2DTileConstraint.computeInputCube((weightH, weightW), pads, strides, weightC, + cube, outputBuffer.shape) + padding_left, padding_right, padding_top, padding_bottom = padding_tuple + + replacements['padding_y_top'].append(padding_top) + replacements['padding_y_bottom'].append(padding_bottom) + replacements['padding_x_left'].append(padding_left) + replacements['padding_x_right'].append(padding_right) + + inBSize, inHSize, inWSize, inCSize = InCube.dims + + dim_im_in_x_stride, dim_im_in_y_stride = ioStridesFromDimensions(inWSize, inCSize, + operatorRepresentation["input_bits"]) + replacements['dim_im_in_x_stride'].append(dim_im_in_x_stride) + replacements['dim_im_in_y_stride'].append(dim_im_in_y_stride) + dim_im_out_x_stride, dim_im_out_y_stride = ioStridesFromDimensions(WSize, CSize, + operatorRepresentation["output_bits"]) + replacements['dim_im_out_x_stride'].append(dim_im_out_x_stride) + replacements['dim_im_out_y_stride'].append(dim_im_out_y_stride) + + replacements['input_addr_offset'].append( + getInputAddrOffset(inWSize, dim_im_in_y_stride, padding_top, padding_left)) + + nKo, nKi, nHo, nWo, bKo, bKi, bHo, bWo, bHi, bWi = NE162DPWConvTemplate.getCounters( + inCSize, HSize, WSize, CSize, padding_bottom, padding_right, operatorRepresentation) + + replacements["nKo"].append(nKo) + replacements["nKi"].append(nKi) + replacements["nHo"].append(nHo) + replacements["nWo"].append(nWo) + replacements["bKo"].append(bKo) + 
replacements["bKi"].append(bKi) + replacements["bHo"].append(bHo) + replacements["bWo"].append(bWo) + replacements["bHi"].append(bHi) + replacements["bWi"].append(bWi) + + inputInCubes.append(InCube) + + inputLoadSchedule = [] + outputLoadSchedule = [] + + for a in inputInCubes: + inputLoadSchedule.append({"data_in": a}) + + for out in outputCubes: + outputLoadSchedule.append({"data_out": out}) + + weightBuffer = ctxt.lookup(varWeight) + assert isinstance(weightBuffer, VariableBuffer) + weightShape = weightBuffer.shape + + if hasattr(weightBuffer, "_memoryLevel") and weightBuffer._memoryLevel == "WeightMemory_SRAM": + replacements['weight_addr_offset'] = [] + replacementTypes['weight_addr_offset'] = PointerClass(uint32_t) + for absoluteCube in absoluteOutputCubes: + COffset, CSize = absoluteCube.absoluteOffset[-1], absoluteCube.rectangle.dims[-1] + WeightCube = HyperRectangle((COffset, 0, 0), (CSize, weightShape[-2], weightShape[-1])) + replacements['weight_addr_offset'].append(calculateFlatOffsetInBytes(WeightCube, weightBuffer)) + else: + inputWeightBaseOffsets, outputWeightBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel, + operatorRepresentation, ['weight']) + inputBaseOffsets.update(inputWeightBaseOffsets) + outputBaseOffsets.update(outputWeightBaseOffsets) + + for cube, load in zip(outputCubes, inputLoadSchedule): + COffset, CSize = cube.offset[-1], cube.dims[-1] + load['weight'] = HyperRectangle((COffset, 0, 0), (CSize, weightShape[-2], weightShape[-1])) + + tilingSchedule = TilingSchedule(inputBaseOffsets, outputBaseOffsets, inputLoadSchedule, outputLoadSchedule) + variableReplacementSchedule = VariableReplacementScheme(replacements, replacementTypes) + + return variableReplacementSchedule, tilingSchedule + + +class NE16RQSPWConv2DTileConstraint(NE16PWConv2DTileConstraint): + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + tilerModel = 
def requantAddGeometricalConstraint(tilerModel: TilerModel, operatorRepresentation: OperatorRepresentation,
                                    ctxt: NetworkContext) -> TilerModel:
    """Couple the per-channel requant tensors ('mul', 'add') to the output
    channel tiling.

    Registers both tensors with the tiler model and constrains their
    innermost dimension to tile in lockstep with the NHWC output-channel
    dimension (dimIdx 3 of 'data_out'), so every output tile is shipped with
    matching requant parameters.

    Returns the same ``tilerModel`` with the constraints added.
    """
    outputBufferName = operatorRepresentation['data_out']
    mulBufferName = operatorRepresentation['mul']
    addBufferName = operatorRepresentation['add']

    # Add I/O dimensions to the model as variables
    for bufferName in [mulBufferName, addBufferName]:
        tilerModel.addTensorDimToModel(ctxt, bufferName)

    outputChannelVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = 3)

    # mul/add may have extra leading singleton dims; always bind their LAST
    # dimension, which holds the per-channel values.
    addBuffer = ctxt.lookup(addBufferName)
    addChannelVar = tilerModel.getTensorDimVar(tensorName = addBufferName, dimIdx = len(addBuffer.shape) - 1)
    mulBuffer = ctxt.lookup(mulBufferName)
    mulChannelVar = tilerModel.getTensorDimVar(tensorName = mulBufferName, dimIdx = len(mulBuffer.shape) - 1)

    tilerModel.addConstraint(outputChannelVar == addChannelVar)
    tilerModel.addConstraint(outputChannelVar == mulChannelVar)

    return tilerModel


def requantLoadSchedule(
    absoluteOutputCubes: List[AbsoluteHyperRectangle],
    ctxt: NetworkContext,
    operatorRepresentation: OperatorRepresentation,
) -> List[Dict[str, HyperRectangle]]:
    """Build the per-tile load rectangles for the 'mul' and 'add' tensors.

    For each output tile (NHWC), selects the channel slice
    [COffset, COffset + CSize) of each requant tensor, keeping any leading
    dimensions at size 1. The result aligns index-wise with the output tiles
    and is merged into the input load schedule by the callers.
    """
    outputCubes = [cube.rectangle for cube in absoluteOutputCubes]

    shapeMul = ctxt.lookup(operatorRepresentation["mul"]).shape
    shapeAdd = ctxt.lookup(operatorRepresentation["add"]).shape

    schedule = []
    for cube in outputCubes:
        (_, _, _, COffset) = cube.offset
        (_, _, _, CSize) = cube.dims
        # Offset/size only in the last (channel) dim; 0/1 everywhere else.
        MulCube = HyperRectangle((0,) * (len(shapeMul) - 1) + (COffset,), (1,) * (len(shapeMul) - 1) + (CSize,))
        AddCube = HyperRectangle((0,) * (len(shapeAdd) - 1) + (COffset,), (1,) * (len(shapeAdd) - 1) + (CSize,))
        schedule.append({"mul": MulCube, "add": AddCube})

    return schedule
ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from . import * diff --git a/Deeploy/Targets/NE16/Tiler.py b/Deeploy/Targets/NE16/Tiler.py new file mode 100644 index 0000000000..2bc53a441a --- /dev/null +++ b/Deeploy/Targets/NE16/Tiler.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + + +from Deeploy.Targets.NE16.Bindings import NE16DenseConv2DBindings, NE16DWConv2DBindings, NE16PWConv2DBindings, \ + NE16RQSDenseConv2DBindings, NE16RQSDWConv2DBindings, NE16RQSPWConv2DBindings +from Deeploy.Targets.NE16.TileConstraints.NE16DenseConstraint import NE16DenseConv2DTileConstraint, \ + NE16RQSDenseConv2DTileConstraint +from Deeploy.Targets.NE16.TileConstraints.NE16DepthwiseConstraint import NE16DWConv2DTileConstraint, \ + NE16RQSDWConv2DTileConstraint +from Deeploy.Targets.NE16.TileConstraints.NE16PointwiseConstraint import NE16PWConv2DTileConstraint, \ + NE16RQSPWConv2DTileConstraint +from Deeploy.TilingExtension.TilerExtension import TilingReadyNodeBindings + +NE16RQSPWConv2DTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = NE16RQSPWConv2DBindings, + tileConstraint = NE16RQSPWConv2DTileConstraint()) +NE16PWConv2DTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = NE16PWConv2DBindings, + tileConstraint = NE16PWConv2DTileConstraint()) + +NE16RQSDWConv2DTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = NE16RQSDWConv2DBindings, + tileConstraint = NE16RQSDWConv2DTileConstraint()) +NE16DWConv2DTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = NE16DWConv2DBindings, + tileConstraint = NE16DWConv2DTileConstraint()) + +NE16RQSDenseConv2DTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = NE16RQSDenseConv2DBindings, + tileConstraint = NE16RQSDenseConv2DTileConstraint()) +NE16DenseConv2DTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = NE16DenseConv2DBindings, + tileConstraint = 
NE16DenseConv2DTileConstraint()) diff --git a/Deeploy/Targets/NE16/TopologyOptimizationPasses/Passes.py b/Deeploy/Targets/NE16/TopologyOptimizationPasses/Passes.py new file mode 100644 index 0000000000..60a747aa00 --- /dev/null +++ b/Deeploy/Targets/NE16/TopologyOptimizationPasses/Passes.py @@ -0,0 +1,278 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +import itertools +import math +from functools import partial +from typing import Generator, List, Tuple + +import numpy as np +import numpy.typing as npt +import onnx_graphsurgeon as gs + +from Deeploy.CommonExtensions.OptimizationPasses.Matchers import Match, NonBranchingMatcher +from Deeploy.CommonExtensions.OptimizationPasses.PassClasses import ReplaceSequentialPatternPass, SequentialPass, \ + contextagnostic +from Deeploy.CommonExtensions.OptimizationPasses.TopologyOptimizationPasses.LoweringOptimizationPasses import \ + RemoveGlobalOutputReshapePass, _createReshape +from Deeploy.EngineExtension.OptimizationPasses.TopologyOptimizationPasses.EngineColoringPasses import \ + EngineDiscolorationPass +from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import ReshapeConstOptPass, ReshapeMergePass + + +def _weightEncode(weight: npt.NDArray[np.uint8], bits: int, depthwise: bool = False) -> npt.NDArray[np.uint8]: + """NE16 weight encoder, ported from pulp-nnx/test/Ne16Weight.py. + + Expected weight shape: (cout, cin, H, W). + Output layout: (cout, cinMajor, Bits, H*W, cinMinorBytes) where + CIN_SUBTILE = 16 (single mode, no 1x1 vs 3x3 split like Neureka). 
+ """ + _NE16_CIN_SUBTILE = 16 + + if depthwise: + weight = weight.transpose(1, 0, 2, 3) # Swap cout and cin + + cout, cin, height, width = weight.shape + + # Pad cin to be divisible with CIN_SUBTILE + if cin % _NE16_CIN_SUBTILE != 0: + cinPad = _NE16_CIN_SUBTILE - cin % _NE16_CIN_SUBTILE + weight = np.pad( + weight, + ((0, 0), (0, cinPad), (0, 0), (0, 0)), + "constant", + constant_values = 0, + ) + cin = cin + cinPad + + cinMajor = cin // _NE16_CIN_SUBTILE + cinMinor = _NE16_CIN_SUBTILE + + # (cout, cinMajor, cinMinor, H*W, 1) + weight = weight.reshape(cout, cinMajor, cinMinor, height * width, 1) + # (cout, cinMajor, cinMinor, H*W, Bits) + weight = np.unpackbits(weight, axis = -1, count = bits, bitorder = "little") + # (cout, cinMajor, Bits, H*W, cinMinor) + weight = weight.transpose(0, 1, 4, 3, 2) + # Pack cinMinor bits into bytes — 16 bits = 2 bytes + weight = weight.reshape(-1, 8) + weight = np.packbits(weight, axis = -1, bitorder = "little") + cinMinorBytes = cinMinor // 8 + # Layout rank varies by conv kind: + # - Dense 3x3 (!depthwise, kernel 3x3): rank 4 + # (cout, cinMajor, Bits, H*W*cinMinorBytes) + # — NE16DenseConstraint tiles over weight.shape[3]. + # - PW 1x1 and DW 3x3: rank 3 + # (cout, cinMajor, Bits*H*W*cinMinorBytes) + # — NE16{Pointwise,Depthwise}Constraint don't need a bits dim. 
+ if not depthwise and height == 3 and width == 3: + return weight.reshape(cout, cinMajor, bits, height * width * cinMinorBytes) + return weight.reshape(cout, cinMajor, bits * height * width * cinMinorBytes) + + +def _ne16_adjust_weight_memory_layout_fun(graph: gs.Graph, match: Match, name: str, default_channels_first: bool, + ne16EngineName: str): + matched_nodes = list(match.nodes_map.values()) + node = matched_nodes[0] + + if not ("engine" in node.attrs and node.attrs["engine"] == ne16EngineName): + return graph + + weightTensor = node.inputs[1] + + if not isinstance(weightTensor, gs.Constant): + return graph + + # Adjust N-EUREKA's weights + values = weightTensor.values + + # Extract weight offset and translate weights by the offset + weight_offset = values.min() + values = values - weight_offset + node.attrs["weight_offset"] = weight_offset + + if "channels_first" in node.attrs: + channels_first = node.attrs["channels_first"] + else: + channels_first = default_channels_first + + # Weight encode expects channels-first (cout, cin_per_group, H, W) + if not channels_first: + values = values.transpose(0, 3, 1, 2) + + bits = 8 # Support only 8 bit weights for now + if node.attrs['group'] == 1: + weightTensor.values = _weightEncode(values.astype(np.uint8), bits, depthwise = False) + else: + # Depthwise: Deeploy's NHWC pass leaves weight as + # (cin_per_group=1, cout=group, H, W) after the transpose above; + # Ne16Weight.py's encode expects standard (cout, cin_per_group, H, W) + # — swap axes 0/1 before encoding so the result is a single packed + # (1, 1, packed_bytes) block across up to NE16_SUBTILE_INPUT_CHANNEL=16 + # parallel output channels. 
+ values = values.transpose(1, 0, 2, 3) + weightTensor.values = _weightEncode(values.astype(np.uint8), bits, depthwise = True) + weightTensor.name = f"{name}_{weightTensor.name}" + + return graph + + +@contextagnostic +class NE16AdjustWeightMemoryLayoutPass(ReplaceSequentialPatternPass): + + def __init__(self, default_channels_first: bool, ne16EngineName: str): + graph = gs.Graph() + _input = gs.Variable(name = 'input_1') + output = graph.layer(inputs = [_input], outputs = ['out'], op = 'RequantizedConv|Conv', name = 'node') + graph.outputs.append(output) + graph.inputs.append(_input) + + super().__init__( + graph, + partial(_ne16_adjust_weight_memory_layout_fun, + default_channels_first = default_channels_first, + ne16EngineName = ne16EngineName), "_NE16_ADJUST_WEIGHT_MEMORY_LAYOUT_PASS", + NonBranchingMatcher(regex_op = True)) + + +def _findAllMultiplicands(x: int) -> List[int]: + multiplicands = [] + tmpX = x + for i in range(2, math.ceil(math.sqrt(x))): # Ceil cause range doesn't include the last number + while tmpX % i == 0: + multiplicands.append(i) + tmpX = tmpX / i + + if x // math.prod(multiplicands) > 1: + multiplicands.append(x // math.prod(multiplicands)) + + return multiplicands + + +def _findAllReshapeOptions(dim: int) -> Generator[Tuple[int, int], None, None]: + multiplicands = _findAllMultiplicands(dim) + for combLen in range(1, 1 + (len(multiplicands) // 2)): + for comb in itertools.combinations(multiplicands, combLen): + a = math.prod(comb) + b = dim // a + yield a, b + + +def _nSubtiles(dims: Tuple[int, int]): + return math.ceil(dims[0] / 6) * math.ceil(dims[1] / 6) + + +def _findLowestNumberOfSubtilesReshapeOptions(dim: int) -> List[Tuple[int, int]]: + lowestNumberOfSubtiles = dim + bestOptions: List[Tuple[int, int]] = [(dim, 1)] + for option in _findAllReshapeOptions(dim): + nSubtiles = _nSubtiles(option) + if nSubtiles < lowestNumberOfSubtiles: + lowestNumberOfSubtiles = nSubtiles + bestOptions = [option] + elif nSubtiles == 
lowestNumberOfSubtiles: + bestOptions.append(option) + return bestOptions + + +def _bestReshapeOption(dim: int) -> Tuple[int, int]: + smallestDim = dim + biggestDim = 1 + for option in _findLowestNumberOfSubtilesReshapeOptions(dim): + if option[0] < smallestDim: + smallestDim = option[0] + biggestDim = option[1] + elif option[1] < smallestDim: + smallestDim = option[1] + biggestDim = option[0] + return biggestDim, smallestDim + + +def _ne16_reshape_pointwise_convolution_fun(graph: gs.Graph, match: Match, name: str, default_channels_first: bool, + ne16EngineName: str): + matched_nodes = list(match.nodes_map.values()) + node = matched_nodes[0] + + if not ("engine" in node.attrs and node.attrs["engine"] == ne16EngineName): + return graph + + if not (node.attrs["kernel_shape"] == [1, 1]): + return graph + + if "channels_first" in node.attrs: + channels_first = node.attrs["channels_first"] + else: + channels_first = default_channels_first + + def extractSpatialDims(shape: List[int]) -> List[int]: + if channels_first: + return shape[-2:] + else: + return shape[-3:-1] + + def replaceSpatialDims(shape: List[int], newSpatialDims: Tuple[int, int]) -> List[int]: + if channels_first: + return shape[:-2] + list(newSpatialDims) + else: + return shape[:-3] + list(newSpatialDims) + shape[-1:] + + _input = node.inputs[0] + spatialDims = extractSpatialDims(_input.shape) + newSpatialDims = _bestReshapeOption(math.prod(spatialDims)) + newInputShape = replaceSpatialDims(_input.shape, newSpatialDims) + + inputReshapeNode, reshapedInput = _createReshape(_input, name, newInputShape) + graph.nodes.append(inputReshapeNode) + node.inputs[0] = reshapedInput + + output = node.outputs[0] + newOutputShape = replaceSpatialDims(output.shape, newSpatialDims) + reshapedOutput = gs.Variable(output.name + "_Reshaped", dtype = output.dtype, shape = newOutputShape) + outputReshapeNode, _ = _createReshape(reshapedOutput, name, output.shape, output) + graph.nodes.append(outputReshapeNode) + 
node.outputs[0] = reshapedOutput + + return graph + + +@contextagnostic +class NE16ReshapePointwiseConvolutionPass(ReplaceSequentialPatternPass): + """Reshape pointwise convolution's spatial dimensions so that they work better for N-EUREKA's hardware tiling""" + + def __init__(self, default_channels_first: bool, ne16EngineName: str): + graph = gs.Graph() + _input = gs.Variable(name = 'input_1') + output = graph.layer(inputs = [_input], outputs = ['out'], op = 'RequantizedConv|Conv', name = 'node') + graph.outputs.append(output) + graph.inputs.append(_input) + + super().__init__( + graph, + partial(_ne16_reshape_pointwise_convolution_fun, + default_channels_first = default_channels_first, + ne16EngineName = ne16EngineName), "_NE16_RESHAPE_POINTWISE_CONVOLUTION_PASS", + NonBranchingMatcher(regex_op = True)) + + +class ConvEngineDiscolorationPass(EngineDiscolorationPass): + + def __init__(self): + pattern = gs.Graph() + _input = gs.Variable(name = 'input') + output = pattern.layer(inputs = [_input], outputs = ['output'], op = 'RequantizedConv|Conv', name = 'conv') + pattern.outputs.append(output) + pattern.inputs.append(_input) + super().__init__(pattern, "_CONV_ENGINE_DISCOLORATION_PASS", matcher = NonBranchingMatcher(regex_op = True)) + + +@contextagnostic +class NE16OptimizationPass(SequentialPass): + + def __init__(self, default_channels_first: bool, ne16EngineName: str): + super().__init__(NE16AdjustWeightMemoryLayoutPass(default_channels_first, ne16EngineName), + NE16ReshapePointwiseConvolutionPass(default_channels_first, ne16EngineName), + ReshapeMergePass(), + ReshapeConstOptPass(), + RemoveGlobalOutputReshapePass(), + name_prefix = '') diff --git a/Deeploy/Targets/NE16/TopologyOptimizationPasses/__init__.py b/Deeploy/Targets/NE16/TopologyOptimizationPasses/__init__.py new file mode 100644 index 0000000000..be436b64a3 --- /dev/null +++ b/Deeploy/Targets/NE16/TopologyOptimizationPasses/__init__.py @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and 
University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from . import * diff --git a/Deeploy/Targets/NE16/__init__.py b/Deeploy/Targets/NE16/__init__.py new file mode 100644 index 0000000000..be436b64a3 --- /dev/null +++ b/Deeploy/Targets/NE16/__init__.py @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from . import * diff --git a/Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py b/Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py index 59499706e5..280cb4ff6e 100644 --- a/Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py +++ b/Deeploy/Targets/PULPOpen/Templates/FloatGemmTemplate.py @@ -4,7 +4,10 @@ from typing import Dict, List, Tuple -from Deeploy.AbstractDataTypes import float32_tPtr +from Deeploy.AbstractDataTypes import PointerClass +from Deeploy.CommonExtensions.DataTypes import float32_t + +float32_tPtr = PointerClass(float32_t) from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation diff --git a/DeeployTest/CMakeLists.txt b/DeeployTest/CMakeLists.txt index b7f3535790..4e45904541 100644 --- a/DeeployTest/CMakeLists.txt +++ b/DeeployTest/CMakeLists.txt @@ -50,7 +50,7 @@ elseif(DEEPLOY_ARCH STREQUAL SNITCH) add_subdirectory(Platforms/Snitch) elseif(DEEPLOY_ARCH STREQUAL CHIMERA) add_subdirectory(Platforms/Chimera) -elseif(platform STREQUAL GAP9) +elseif(platform STREQUAL GAP9 OR platform STREQUAL GAP9_w_NE16) # Search for hex files generated by Python code generator # These files indicate L3 mode (external memory with readfs) diff --git a/DeeployTest/Tests/Kernels/Integer/Conv/Dense_2D_RQ/inputs.npz b/DeeployTest/Tests/Kernels/Integer/Conv/Dense_2D_RQ/inputs.npz new file mode 100644 index 0000000000000000000000000000000000000000..cfd8568bf15e5dc8924e45317d77c9725336c764 GIT binary patch literal 8456 zcmb8#L28wE7=`ig8ZFq7gDl`Wx)p-pq!n@8Nl-HBpraBrqEKQ;R1hh24;{J`7vcuo zM|-dO>)?TxOQPZ7ocEmb{}zyw&rKnr{{mXzI=YR 
z^M5`4^XhEtU%foNINR!<9vrP7{Cu>2xcz_cX#Lmv`aeIezJ11WBjn@v^FkkdH&EXl z^-!Lc+$vZDDOULd3$teN6wzybeGeUpESR;;~hH__1ojy>76gbuLMym#b~k9t_&p8h^!BR2PwUmE-R1dq-LbwN z>f_XNFYV}IecyrZ-P4zImq&NhYu;T?Zt~^Oo;>Z2eEWRlH&-q1J>;W3cRTLY^6q-{ zl-F~o^^uSKsn>Vq_1mk~*Kz-Y{9lCeEJ><*t(H;8^tzXSYJCokmceiuh@6^|;_FeVqzRS1gZijZ~qdPtA$#q9w zALaDquy;9rcWAlpxT768Y+t^UOsxScm9-fM}6<}C*57zjy}KIt~;&I$M$K|m$$>G<*|7=)GzfX zzq@jiZx7{t2fuX3-f4T@FZJcB?dZwT@+dDym;3I`dvEgHeaE!VM|=8wbhnH0`)Ef_ zpHI8&w$vpGG^~rBU84Umx93&mH-8Xw>hWZwK`z+Ce?+E-i0Y9($KVJMw%z zcRtFyqkH?k%kkyATP=@v!#e)HGe(YNP2`83+6_3G>Mj!&1~lS4bGH)%WWsON6Ke)I0~`|k4S zJ*;oHzIRZcZk~3xi}mfb$Csz&?A5Ps-yPlU*+)5SjxOz@J(Q=Lo3vi z_2qH8>rd2kr|n{W-^Zts-(9+tM>|+wzuzbCPM3OUALZR?wC{d$yt;9CF#Yj@^!D!+ az83!fQ-+h{TR-fMtL?vEw*0i-ss8}SjW_53 literal 0 HcmV?d00001 diff --git a/DeeployTest/Tests/Kernels/Integer/Conv/Dense_2D_RQ/network.onnx b/DeeployTest/Tests/Kernels/Integer/Conv/Dense_2D_RQ/network.onnx new file mode 100644 index 0000000000000000000000000000000000000000..c357ce307e354f5ceef2f5c2343ac300481c6989 GIT binary patch literal 9817 zcmbta&2F5<67^sM!>1q_^SeP@vGjrsJ2qn_AfX7cfCUmlvWz@8u`rVm+qqdEQX{Nm;D@^JC^U~ljC z(I1P02`&0@zW1(?kJWdJ9{Ji!JE(VkW#=B!1kPQ?oYYlD z@}*C{^8sTY9@q1&yER@ID{y%`{|qO-CEa=bo1D|zef1s451jLHTZ%3gNIN1?PR7laDUU2zS(*Y z*X1Lx=BM6yzIC)8^*r78jF{)*?MHrMo>_V6?L5SESBQ<>O9%9M`_Hm08NRS{Qj;9pUW9xB9B)9W$tjo0@t9SGGL8c7AB9PklpruQcM+Yp<>!a*E6OzDwG z<~FuvR&M&8iww(>KXyp{I7^)7YhU3z8Q81SYTiAysnav=zu=;Lgk>eY;w8L9G zbH`I$dLm=tP%k04ydu`K)S~_ChuNtS{*`4Tt zPT$&Tg;S$T%r9~Ec=J?G+_1*BUui-QP4uQ^uOm)b;Q{Zs^o6D0`|gcd>Sy!F=dC@^ zORK7#Tq`UzGjrT4t^C@ldihs(`-FylVqIS1!h8o|k*z#bX;h2PEbp7T{T9B>KK0Vf z?Mcq}?zG4u#}nD?b%EEsd^e?eh~x5%v*`^yYrIE2xYCUr`tqc%X0bPY=n`wL_UqXr z9)5Uu#@r6$oYV!LbILB)eCq;pJoL$p-q;D8G~!i# zmk{AcnEFxM^_+`+gsH|{;hJx~c>kvGLaUv%hx!S7O#~Iyj$dIo*%#}CaO=UQEozl`X z=u1OfmpWebJbJT654d}&!b+s4efDdgx+?Lws6ez6WKOjx=v$x zot`-IVm4}c+3T*7V+AHYDlk?5F69*`A2ad>2V8UIzpqyoE01zi8`^f%d}S5K**=r+ zyp?H~{Ihk{Qsu7-T<@BD&y1a!^jfa?sLT*Y9eN$@7rN9Gm%8$z@ATP+SKMltgeNStzH;@>Jm=++R-9M8sfl;y 
zs4on2;*BbN~90|4U05z8`*hW_jfB#CJfecT~>c6^^%f;aLF+xGoC zZ{<=yYKY~fk7s8tJbuOua>@|LD!je!X6rjj1Fic!OBuD3?NT|boIi4jSBCN0Y2AmIR=x};%xxi&z^8WmvDgXNFKE=a>Y47muzX3m+lSlvn literal 0 HcmV?d00001 diff --git a/DeeployTest/Tests/Kernels/Integer/Conv/Dense_2D_RQ/outputs.npz b/DeeployTest/Tests/Kernels/Integer/Conv/Dense_2D_RQ/outputs.npz new file mode 100644 index 0000000000000000000000000000000000000000..76f10c49d9a4bbb40094fa02b0c2b90d36766a89 GIT binary patch literal 8458 zcmeI&TWg$U5C`ygn^tQ)RgLG<8{4?8ZqA({W&`Lx8Pm3q#4A9z}mcPpp=c(XZEW3*>Ktq11hdFP8t@Be?YYZs8Gm%j3Do_Z`5X z;$^?^>XgUpWpAIt-Jt*c;@PX8-5U*maBr8MAN#9Le)*G^d6KIGeW4Wku7B|2yE2c|C!c*uy*){JkI(+X>vPVBJ}G?4 zvmf?NT=U8L%)xx%(tq;o>;*n??oF@%_DH_zC;9Z}9waXF;$J*r1L|Ik70-K6hy2B}hi?{MpZLF8a{gZm96#?x zys>yxKh%T1RXD_*`>+K1cddBTeE3zDeb!g@;$1I4?l(W|g2gRp9cO>VOeN~72)a^dnPkSI=`>sB7u5<=`z6bTd z+4r#o{Mkcyh#!7Ycyog1SN-Z{r!IY#k2w5!FXqXP-u%VoM|^egBOiCooWxC>`pik+ zeLt~RpZ?i;g@eUX3WEslHS-jU0bT>soL zIN#6eu~#UZ-zhmayghYK;l+2q`Aa!JD4hA&Yj;yUz8~Pp@%op)m&AwDuZ*t`euw0K zW?lNoZ_L*r`s}NH z<{!_mIr8H@h{K=XMfxNj`K;=_T9hCCgZs4fh&;uIUoU{ZZx(Ovxi99%j$gd_^CzA> z=<)VB_edS$!?8;$U-En(5n`qzv@lB zJ8A#v?Hk@6qvWXmnUg-~o4Lt{UVpM5-b>~oUXOZDJzP(pdG+x0J-b#P+4rbE>W}`j z6CaNJywse|YtL?HdvalIt8QG&HrBEq(P7WmSJc6(n(2^VKQvZVZV(F>>pF7 BPul None: "siracusa_neureka_tiled: mark test as a Siracusa + Neureka platform test (tiled)") config.addinivalue_line("markers", "gap9: mark test as a GAP9 platform test") config.addinivalue_line("markers", "gap9_tiled: mark test as a GAP9 platform test (tiled)") + config.addinivalue_line("markers", "gap9_w_ne16_tiled: mark test as a GAP9 + NE16 platform test (tiled)") config.addinivalue_line("markers", "kernels: mark test as a kernel test (individual operators)") config.addinivalue_line("markers", "models: mark test as a model test (full networks)") config.addinivalue_line("markers", "singlebuffer: mark test as single-buffer configuration") diff --git a/DeeployTest/deeployRunner_tiled_gap9_w_ne16.py 
b/DeeployTest/deeployRunner_tiled_gap9_w_ne16.py new file mode 100644 index 0000000000..63c2277789 --- /dev/null +++ b/DeeployTest/deeployRunner_tiled_gap9_w_ne16.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +import sys + +from testUtils.deeployRunner import main + +if __name__ == "__main__": + + # Define parser setup callback to add GAP9+NE16-specific arguments + def setup_parser(parser): + parser.add_argument('--cores', type = int, default = 8, help = 'Number of cores (default: 8)\n') + parser.add_argument('--ne16-wmem', action = 'store_true', help = 'Enable NE16 weight memory\n') + parser.add_argument('--enable-3x3', action = 'store_true', help = 'Enable 3x3 convolutions\n') + + sys.exit( + main(default_platform = "GAP9_w_NE16", + default_simulator = "gvsoc", + tiling_enabled = True, + parser_setup_callback = setup_parser)) diff --git a/DeeployTest/testMVP.py b/DeeployTest/testMVP.py index 9678bc4e4f..56d05806c0 100644 --- a/DeeployTest/testMVP.py +++ b/DeeployTest/testMVP.py @@ -96,7 +96,7 @@ def setupDeployer(graph: gs.Graph, memoryHierarchy: MemoryHierarchy, defaultTarg scheduler = _mockScheduler) # Make the deployer engine-color-aware - if args.platform == "Siracusa_w_neureka": + if args.platform in ("Siracusa_w_neureka", "GAP9_w_NE16"): deployer = EngineColoringDeployerWrapper(deployer) # Make platform memory-aware after mapDeployer because it requires the platform to be an instance of an unwrapped platform diff --git a/DeeployTest/testUtils/core/execution.py b/DeeployTest/testUtils/core/execution.py index 4c6c972679..24c1f7154e 100644 --- a/DeeployTest/testUtils/core/execution.py +++ b/DeeployTest/testUtils/core/execution.py @@ -132,7 +132,7 @@ def build_binary(config: DeeployTestConfig) -> None: ] # GAP9 requires the 'image' target to generate MRAM .bin files for GVSOC - if config.platform == 'GAP9': + if config.platform in ('GAP9', 
'GAP9_w_NE16'): cmd.append("image") env = os.environ.copy() diff --git a/DeeployTest/testUtils/platformMapping.py b/DeeployTest/testUtils/platformMapping.py index 9d526906f9..4dc2bbf824 100644 --- a/DeeployTest/testUtils/platformMapping.py +++ b/DeeployTest/testUtils/platformMapping.py @@ -20,6 +20,8 @@ from Deeploy.Targets.Generic.Platform import GenericOptimizer, GenericPlatform from Deeploy.Targets.MemPool.Deployer import MemPoolDeployer from Deeploy.Targets.MemPool.Platform import MemPoolOptimizer, MemPoolPlatform +from Deeploy.Targets.NE16.Deployer import NE16Deployer +from Deeploy.Targets.NE16.Platform import MemoryNE16Platform, MemoryNE16PlatformWrapper, NE16Optimizer, NE16Platform from Deeploy.Targets.Neureka.Deployer import NeurekaDeployer from Deeploy.Targets.Neureka.Platform import MemoryNeurekaPlatform, MemoryNeurekaPlatformWrapper, NeurekaOptimizer, \ NeurekaPlatform @@ -31,7 +33,7 @@ from Deeploy.Targets.SoftHier.Platform import SoftHierOptimizer, SoftHierPlatform _SIGNPROP_PLATFORMS = ["Apollo3", "Apollo4", "QEMU-ARM", "Generic", "MemPool", "SoftHier"] -_NONSIGNPROP_PLATFORMS = ["Siracusa", "Siracusa_w_neureka", "PULPOpen", "Snitch", "Chimera", "GAP9"] +_NONSIGNPROP_PLATFORMS = ["Siracusa", "Siracusa_w_neureka", "PULPOpen", "Snitch", "Chimera", "GAP9", "GAP9_w_NE16"] _PLATFORMS = _SIGNPROP_PLATFORMS + _NONSIGNPROP_PLATFORMS @@ -67,6 +69,9 @@ def mapPlatform(platformName: str) -> Tuple[DeploymentPlatform, bool]: elif platformName == "Siracusa_w_neureka": Platform = NeurekaPlatform() + elif platformName == "GAP9_w_NE16": + Platform = NE16Platform() + elif platformName == "Snitch": Platform = SnitchPlatform() @@ -90,6 +95,8 @@ def setupMemoryPlatform(platform: DeploymentPlatform, memoryHierarchy: MemoryHie weightMemoryLevel = memoryHierarchy.memoryLevels["WeightMemory_SRAM"] \ if "WeightMemory_SRAM" in memoryHierarchy.memoryLevels else None return MemoryNeurekaPlatformWrapper(platform, memoryHierarchy, defaultTargetMemoryLevel, weightMemoryLevel) + elif 
isinstance(platform, NE16Platform): + return MemoryNE16PlatformWrapper(platform, memoryHierarchy, defaultTargetMemoryLevel) if isinstance(platform, GAP9Platform): return MemoryGAP9PlatformWrapper(platform, memoryHierarchy, defaultTargetMemoryLevel) else: @@ -207,6 +214,23 @@ def mapDeployer(platform: DeploymentPlatform, default_channels_first = default_channels_first, deeployStateDir = deeployStateDir) + elif isinstance(platform, (NE16Platform, MemoryNE16Platform, MemoryNE16PlatformWrapper)): + + if loweringOptimizer is None: + loweringOptimizer = NE16Optimizer + + if default_channels_first is None: + default_channels_first = False + + deployer = NE16Deployer(graph, + platform, + inputTypes, + loweringOptimizer, + scheduler, + name = name, + default_channels_first = default_channels_first, + deeployStateDir = deeployStateDir) + elif isinstance(platform, (GAP9Platform, MemoryGAP9Platform, MemoryGAP9PlatformWrapper)): if loweringOptimizer is None: diff --git a/DeeployTest/test_gap9_ne16_tiled_config.py b/DeeployTest/test_gap9_ne16_tiled_config.py new file mode 100644 index 0000000000..7dde8bdf5d --- /dev/null +++ b/DeeployTest/test_gap9_ne16_tiled_config.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: 2026 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 +"""Test configuration for GAP9 platform with NE16 accelerator (tiled). + +NE16-supported convolution kernels verified to dispatch to NE16 +(`ne16_nnx_dispatch` appears in generated Network.c) and PASS on +gvsoc gap9.evk: + +- PW 1x1 RQ Conv (PW_2D_RQ/Regular_RQ) +- PW 1x1 Conv (PW_2D) +- DW 3x3 RQ Conv (DW_2D_RQ, with --enable-3x3) +- 3x3 strided RQ (StriddedPadded_2D_RQ) — falls back to cluster + because stride 2x2 requires --enableStrides which isn't wired into + the tiled runner today; still PASS via the cluster kernel. 
+""" + +DEFAULT_CORES = 8 + +L2_SINGLEBUFFER_KERNELS = { + "Kernels/Integer/Conv/PW_2D_RQ/Regular_RQ": [32000, 16000], + "Kernels/Integer/Conv/PW_2D": [32000], + "Kernels/Integer/Conv/DW_2D_RQ": [32000, 16000], + "Kernels/Integer/Conv/Dense_2D_RQ": [32000], + "Kernels/Integer/Conv/StriddedPadded_2D_RQ": [32000], +} + +L2_DOUBLEBUFFER_KERNELS = { + "Kernels/Integer/Conv/PW_2D_RQ/Regular_RQ": [32000], + "Kernels/Integer/Conv/DW_2D_RQ": [32000], +} + +L2_SINGLEBUFFER_MODELS = {} + +L3_SINGLEBUFFER_MODELS = {} +L3_DOUBLEBUFFER_MODELS = {} +L2_SINGLEBUFFER_KERNELS_WMEM = {} +L3_DOUBLEBUFFER_MODELS_WMEM = {} diff --git a/DeeployTest/test_platforms.py b/DeeployTest/test_platforms.py index 6d9f3cfcd7..daad19e6ce 100644 --- a/DeeployTest/test_platforms.py +++ b/DeeployTest/test_platforms.py @@ -11,6 +11,10 @@ from test_gap9_config import DEFAULT_NUM_CORES as GAP9_DEFAULT_NUM_CORES from test_gap9_config import KERNEL_TESTS as GAP9_KERNEL_TESTS from test_gap9_config import MODEL_TESTS as GAP9_MODEL_TESTS +from test_gap9_ne16_tiled_config import DEFAULT_CORES as GAP9_NE16_TILED_DEFAULT_CORES +from test_gap9_ne16_tiled_config import L2_DOUBLEBUFFER_KERNELS as GAP9_NE16_L2_DOUBLEBUFFER_KERNELS +from test_gap9_ne16_tiled_config import L2_SINGLEBUFFER_KERNELS as GAP9_NE16_L2_SINGLEBUFFER_KERNELS +from test_gap9_ne16_tiled_config import L2_SINGLEBUFFER_MODELS as GAP9_NE16_L2_SINGLEBUFFER_MODELS from test_gap9_tiled_config import DEFAULT_CORES as GAP9_TILED_DEFAULT_CORES from test_gap9_tiled_config import L2_DOUBLEBUFFER_KERNELS as GAP9_L2_DOUBLEBUFFER_KERNELS from test_gap9_tiled_config import L2_DOUBLEBUFFER_MODELS as GAP9_L2_DOUBLEBUFFER_MODELS @@ -133,6 +137,7 @@ def param_id(param): # siracusa_neureka_tiled: tests from the Siracusa + Neureka platform (tiled) # gap9: tests from the GAP9 platform (untiled) # gap9_tiled: tests from the GAP9 platform (tiled) +# gap9_w_ne16_tiled: tests from the GAP9 + NE16 platform (tiled) # Test type markers: # kernels: single kernel (or single 
layer) tests # models: full model (multiple layer) tests @@ -987,3 +992,104 @@ def test_gap9_tiled_models_l3_doublebuffer(test_params, deeploy_test_dir, toolch double_buffer = True, ) run_and_assert_test(test_name, config, skipgen, skipsim) + + +@pytest.mark.gap9_w_ne16_tiled +@pytest.mark.kernels +@pytest.mark.singlebuffer +@pytest.mark.l2 +@pytest.mark.parametrize( + "test_params", + generate_test_params(GAP9_NE16_L2_SINGLEBUFFER_KERNELS, "L2-singlebuffer"), + ids = param_id, +) +def test_gap9_w_ne16_tiled_kernels_l2_singlebuffer(test_params, deeploy_test_dir, toolchain, toolchain_dir, cmake_args, + skipgen, skipsim) -> None: + test_name, l1, config_name = test_params + + ne16_cmake_args = cmake_args + [f"NUM_CORES={GAP9_NE16_TILED_DEFAULT_CORES}"] + + # --enable-3x3 is additive (extends NE16Engine.canExecute to DW/Dense 3x3); + # safe to enable for all three kernel cases (PW 1x1 + DW 3x3 + Dense 3x3). + config = create_test_config( + test_name = test_name, + platform = "GAP9_w_NE16", + simulator = "gvsoc", + deeploy_test_dir = deeploy_test_dir, + toolchain = toolchain, + toolchain_dir = toolchain_dir, + cmake_args = ne16_cmake_args, + tiling = True, + cores = GAP9_NE16_TILED_DEFAULT_CORES, + l1 = l1, + default_mem_level = "L2", + double_buffer = False, + gen_args = ["--enable-3x3"], + ) + run_and_assert_test(test_name, config, skipgen, skipsim) + + +@pytest.mark.gap9_w_ne16_tiled +@pytest.mark.models +@pytest.mark.singlebuffer +@pytest.mark.l2 +@pytest.mark.parametrize( + "test_params", + generate_test_params(GAP9_NE16_L2_SINGLEBUFFER_MODELS, "L2-singlebuffer"), + ids = param_id, +) +def test_gap9_w_ne16_tiled_models_l2_singlebuffer(test_params, deeploy_test_dir, toolchain, toolchain_dir, cmake_args, + skipgen, skipsim) -> None: + test_name, l1, config_name = test_params + + ne16_cmake_args = cmake_args + [f"NUM_CORES={GAP9_NE16_TILED_DEFAULT_CORES}"] + + config = create_test_config( + test_name = test_name, + platform = "GAP9_w_NE16", + simulator = "gvsoc", + 
deeploy_test_dir = deeploy_test_dir, + toolchain = toolchain, + toolchain_dir = toolchain_dir, + cmake_args = ne16_cmake_args, + tiling = True, + cores = GAP9_NE16_TILED_DEFAULT_CORES, + l1 = l1, + default_mem_level = "L2", + double_buffer = False, + gen_args = ["--enable-3x3"], + ) + run_and_assert_test(test_name, config, skipgen, skipsim) + + +@pytest.mark.gap9_w_ne16_tiled +@pytest.mark.kernels +@pytest.mark.doublebuffer +@pytest.mark.l2 +@pytest.mark.parametrize( + "test_params", + generate_test_params(GAP9_NE16_L2_DOUBLEBUFFER_KERNELS, "L2-doublebuffer"), + ids = param_id, +) +def test_gap9_w_ne16_tiled_kernels_l2_doublebuffer(test_params, deeploy_test_dir, toolchain, toolchain_dir, cmake_args, + skipgen, skipsim) -> None: + test_name, l1, config_name = test_params + + ne16_cmake_args = cmake_args + [f"NUM_CORES={GAP9_NE16_TILED_DEFAULT_CORES}"] + + config = create_test_config( + test_name = test_name, + platform = "GAP9_w_NE16", + simulator = "gvsoc", + deeploy_test_dir = deeploy_test_dir, + toolchain = toolchain, + toolchain_dir = toolchain_dir, + cmake_args = ne16_cmake_args, + tiling = True, + cores = GAP9_NE16_TILED_DEFAULT_CORES, + l1 = l1, + default_mem_level = "L2", + double_buffer = True, + gen_args = ["--enable-3x3"], + ) + run_and_assert_test(test_name, config, skipgen, skipsim) diff --git a/TargetLibraries/GAP9/CMakeLists.txt b/TargetLibraries/GAP9/CMakeLists.txt index ca4c3ffbeb..8051484dd0 100644 --- a/TargetLibraries/GAP9/CMakeLists.txt +++ b/TargetLibraries/GAP9/CMakeLists.txt @@ -80,5 +80,21 @@ endif() target_link_libraries(deeploygap9 PUBLIC pulp-nn-mixed) +# NE16 accelerator (via pulp-nnx) for GAP9_w_NE16 platform +if(platform STREQUAL "GAP9_w_NE16") + set(USE_NE16 ON CACHE BOOL "Use the NE16 accelerator." 
FORCE) + add_subdirectory(../third_party/pulp-nnx ${CMAKE_CURRENT_BINARY_DIR}/pulp-nnx) + target_link_libraries(pulp-nnx PUBLIC pmsis) + target_compile_options(pulp-nnx PRIVATE + -Wno-error + -Wno-implicit-int-conversion + -Wno-sign-conversion + -Wno-typedef-redefinition + -Wno-unused-parameter + -Wno-incompatible-pointer-types-discards-qualifiers + ) + target_link_libraries(deeploygap9 PUBLIC pulp-nnx) +endif() + target_link_libraries(deeploygap9 PUBLIC m) From 3a8bf481227fc9d1846f1791f27e3afab7d29e19 Mon Sep 17 00:00:00 2001 From: Pu DENG Date: Tue, 14 Apr 2026 10:22:39 +0000 Subject: [PATCH 2/6] [Deeploy PR] NE16 Linear Layer Kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add NE16 linear layer kernels, including a topology pass, NE16 templates, parsers, tile constraints, and bindings - The topology pass recognizes NE16-compatible GEMM layers, adjusts the weight layout for the NE16, and converts the requant shift/scale to the NE16 format - The template detects whether the input is signed; if so, it adds a +128 offset to the input during C runtime and compensates via the bias - Add GAP9 SDK-based Dequant/Quant templates using CNN_Copy.c kernels, replacing the generic templates - Add a generic DequantQuantMergePass that folds adjacent Dequant→Quant pairs into identity or RequantShift - Add a GAP9-specific TopologyOptimizer (GAP9Optimizer) to replace PULPOptimizer Bug fixes: - Add output signedness check in QuantChecker - Fix L3 DMA template (add proper casts) and remove the blocking L3 DMA hack - Isolate dory memory functions from other libraries in CMakeLists so they compile with -Og while compute kernels compile with -O3 - Disable PULPAddRequantMergePass due to incorrect pattern matching when Add has multiple consumers Co-authored-by: runwangdl --- Deeploy/Targets/GAP9/Bindings.py | 42 ++- Deeploy/Targets/GAP9/DMA/L3Dma.py | 9 +- Deeploy/Targets/GAP9/Parsers.py | 37 +++ Deeploy/Targets/GAP9/Platform.py | 72 
++++-- .../Templates/GAP9SDKDequantQuantTemplate.py | 168 ++++++++++++ .../GAP9/Templates/NE16GEMMTemplate.py | 242 ++++++++++++++++++ .../TileConstraints/NE16GEMMTileConstraint.py | 196 ++++++++++++++ Deeploy/Targets/GAP9/Tiler.py | 30 ++- .../GAP9/TopologyOptimizationPasses/Passes.py | 132 ++++++++++ .../TopologyOptimizationPasses/__init__.py | 3 + .../TopologyOptimizationPasses/Passes.py | 82 ++++++ Deeploy/Targets/Generic/TypeCheckers.py | 9 + DeeployTest/Platforms/GAP9/CMakeLists.txt | 1 + DeeployTest/testUtils/platformMapping.py | 4 +- TargetLibraries/GAP9/CMakeLists.txt | 46 +++- TargetLibraries/GAP9/inc/ne16_utils.h | 23 ++ TargetLibraries/GAP9/src/ne16_utils.c | 35 +++ 17 files changed, 1087 insertions(+), 44 deletions(-) create mode 100644 Deeploy/Targets/GAP9/Parsers.py create mode 100644 Deeploy/Targets/GAP9/Templates/GAP9SDKDequantQuantTemplate.py create mode 100644 Deeploy/Targets/GAP9/Templates/NE16GEMMTemplate.py create mode 100644 Deeploy/Targets/GAP9/TileConstraints/NE16GEMMTileConstraint.py create mode 100644 Deeploy/Targets/GAP9/TopologyOptimizationPasses/Passes.py create mode 100644 Deeploy/Targets/GAP9/TopologyOptimizationPasses/__init__.py create mode 100644 TargetLibraries/GAP9/inc/ne16_utils.h create mode 100644 TargetLibraries/GAP9/src/ne16_utils.c diff --git a/Deeploy/Targets/GAP9/Bindings.py b/Deeploy/Targets/GAP9/Bindings.py index 2bda98af8f..ad215b9193 100644 --- a/Deeploy/Targets/GAP9/Bindings.py +++ b/Deeploy/Targets/GAP9/Bindings.py @@ -18,11 +18,12 @@ from Deeploy.DeeployTypes import CodeTransformation, NodeBinding from Deeploy.FutureExtension.Bindings.AutoFutureBinding import AutoFutureBinding from Deeploy.FutureExtension.CodeTransformationPasses.FutureCodeTransformation import FutureGeneration -from Deeploy.Targets.GAP9.DMA.L3Dma import gap9L3DmaHack +from Deeploy.Targets.GAP9.DMA.L3Dma import GAP9L3Dma from Deeploy.Targets.GAP9.DMA.MchanDma import GAP9MchanDma +from Deeploy.Targets.GAP9.Templates import 
GAP9SDKDequantQuantTemplate, NE16GEMMTemplate # Import templates from PULPOpen and Generic from Deeploy.Targets.Generic.Templates import AddTemplate, ConcatTemplate, DequantTemplate, FloatReduceMeanTemplate, \ - FloatReduceSumTemplate, GatherTemplate, QuantTemplate, RQSiGELUTemplate, SliceTemplate, iHardswishTemplate + FloatReduceSumTemplate, GatherTemplate, RQSiGELUTemplate, SliceTemplate, iHardswishTemplate from Deeploy.Targets.Generic.TypeCheckers import AddChecker, ConcatChecker, ConvChecker, DequantChecker, \ GatherChecker, GELUChecker, GEMMChecker, HardswishChecker, LayerNormChecker, MatMulChecker, MulChecker, \ QuantChecker, ReduceMeanChecker, ReluChecker, ReshapeChecker, RQAddChecker, RQHardswishChecker, SGDChecker, \ @@ -57,7 +58,7 @@ MemoryManagementGeneration("L1"), TilingVariableReplacement("L2"), MemoryAwareFunctionCallClosure(writeback = False, generateStruct = True), - PULPL3Tiling("L3", "L2", gap9L3DmaHack), # Use GAP9-specific L3 DMA + PULPL3Tiling("L3", "L2", GAP9L3Dma()), # Use GAP9-specific L3 DMA PULPProfileUntiled(), ArgumentStructGeneration(), L3MemoryAwareFunctionCallClosure(writeback = False), @@ -76,7 +77,7 @@ MemoryManagementGeneration("L1"), TilingVariableReplacement("L2"), MemoryAwareFunctionCallClosure(writeback = False, generateStruct = True), - PULPL3Tiling("L3", "L2", gap9L3DmaHack), # Use GAP9-specific L3 DMA + PULPL3Tiling("L3", "L2", GAP9L3Dma()), # Use GAP9-specific L3 DMA PULPProfileUntiled(), ArgumentStructGeneration(), L3MemoryAwareFunctionCallClosure(writeback = False), @@ -183,6 +184,26 @@ GAP9Transformer) for type1, type2 in zip([int8_t, uint8_t, int8_t, uint8_t], [int8_t, uint8_t, uint8_t, int8_t]) ] +GAP9NE16RQSGEMMBindings = [ + NodeBinding( + PULPLinearChecker([ + PointerClass(type1), + PointerClass(int8_t), + PointerClass(int32_t), + PointerClass(uint8_t), + PointerClass(uint8_t) + ], [PointerClass(type2)]), NE16GEMMTemplate.referenceTemplate, GAP9ClusterTransformer) + for type1 in [int8_t, uint8_t] + for type2 in 
[int8_t, uint8_t] +] + +GAP9NE16GEMMInt32Bindings = [ + NodeBinding( + GEMMChecker([PointerClass(type1), PointerClass(int8_t), + PointerClass(int32_t)], [PointerClass(int32_t)]), NE16GEMMTemplate.int32OutputTemplate, + GAP9ClusterTransformer) for type1 in [int8_t, uint8_t] +] + GAP9FloatGEMMBindings = [ NodeBinding( GEMMChecker([PointerClass(float32_t), PointerClass(float32_t), @@ -386,14 +407,17 @@ ] GAP9QuantBindings = [ - NodeBinding(QuantChecker([PointerClass(float32_t)], [PointerClass(int8_t)]), QuantTemplate.referenceTemplate, - GAP9Transformer), + NodeBinding(QuantChecker([PointerClass(float32_t)], [PointerClass(int8_t)]), + GAP9SDKDequantQuantTemplate.fp32QuantI8Template, GAP9Transformer), + NodeBinding(QuantChecker([PointerClass(float32_t)], [PointerClass(uint8_t)]), + GAP9SDKDequantQuantTemplate.fp32QuantU8Template, GAP9Transformer), ] GAP9DequantBindings = [ - NodeBinding(DequantChecker([PointerClass(int8_t)], [PointerClass(float32_t)]), DequantTemplate.referenceTemplate, - GAP9Transformer), -] + [ + NodeBinding(DequantChecker([PointerClass(int8_t)], [PointerClass(float32_t)]), + GAP9SDKDequantQuantTemplate.fp32DequantI8Template, GAP9Transformer), + NodeBinding(DequantChecker([PointerClass(uint8_t)], [PointerClass(float32_t)]), + GAP9SDKDequantQuantTemplate.fp32DequantU8Template, GAP9Transformer), NodeBinding(DequantChecker([PointerClass(int32_t)], [PointerClass(float32_t)]), DequantTemplate.referenceTemplate, GAP9Transformer), ] diff --git a/Deeploy/Targets/GAP9/DMA/L3Dma.py b/Deeploy/Targets/GAP9/DMA/L3Dma.py index adbf161328..aadc5974b9 100644 --- a/Deeploy/Targets/GAP9/DMA/L3Dma.py +++ b/Deeploy/Targets/GAP9/DMA/L3Dma.py @@ -6,8 +6,7 @@ from typing import Dict, Tuple from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation, VariableBuffer -from Deeploy.TilingExtension.AsyncDma import AsyncDma, BlockingDmaFromAsyncDmaAdapter, DmaDirection, Future, \ - PerTensorWaitingStrategy +from Deeploy.TilingExtension.AsyncDma import 
AsyncDma, DmaDirection, Future, PerTensorWaitingStrategy class GAP9L3DmaFuture(Future): @@ -29,7 +28,7 @@ class GAP9L3Dma(AsyncDma): _transferTemplates = { 2: NodeTemplate( - "pi_cl_ram_copy_2d(get_ram_ptr(), ${ext}, ${loc}, ${transfer_size}, ${stride}, ${length}, ${ext2loc}, &${future});" + "pi_cl_ram_copy_2d(get_ram_ptr(), (uint32_t)${ext}, (void *)${loc}, (uint32_t)${transfer_size}, (uint32_t)${stride}, (uint32_t)${length}, ${ext2loc}, &${future});" ) } _waitingStrategy = PerTensorWaitingStrategy(GAP9L3DmaFuture) @@ -58,7 +57,3 @@ def transferOpRepr(self, externalBuffer: VariableBuffer, localBuffer: VariableBu "stride": strideExt[0], }) return operatorRepresentation - - -# Blocking adapter for L3 DMA (used in GAP9 L3 tiling) -gap9L3DmaHack = BlockingDmaFromAsyncDmaAdapter(GAP9L3Dma()) diff --git a/Deeploy/Targets/GAP9/Parsers.py b/Deeploy/Targets/GAP9/Parsers.py new file mode 100644 index 0000000000..4d730b7cae --- /dev/null +++ b/Deeploy/Targets/GAP9/Parsers.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Tuple + +import onnx_graphsurgeon as gs + +from Deeploy.DeeployTypes import NetworkContext +from Deeploy.Targets.Generic.Parsers import GEMMParser, RQSParserInterface + + +class NE16GEMMParser(GEMMParser, RQSParserInterface): + """Parser for NE16 RequantizedGemm nodes with 5 inputs [A, B, C, mul, scale_n].""" + + def __init__(self): + super().__init__(noBiasHoisting = True) + + def parseNode(self, node: gs.Node) -> bool: + ret_rqs = RQSParserInterface.parseNode(self, node) + ret_matmul = GEMMParser.parseNode(self, node) + ret = all([ret_rqs, ret_matmul, 'shift' in node.attrs, len(node.inputs) == 5]) + if ret: + self.operatorRepresentation['shift'] = int(node.attrs['shift'].values) + return ret + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True) -> Tuple[NetworkContext, bool]: + newCtxt, ret = 
GEMMParser.parseNodeCtxt(self, ctxt, node, channels_first) + if ret: + inputs = ['A', 'B', 'C', 'mul', 'scale_n'] + for idx, inputNode in enumerate(node.inputs): + self.operatorRepresentation[inputs[idx]] = newCtxt.lookup(inputNode.name).name + return newCtxt, True + return ctxt, False diff --git a/Deeploy/Targets/GAP9/Platform.py b/Deeploy/Targets/GAP9/Platform.py index bad6f8d859..c482c14dab 100644 --- a/Deeploy/Targets/GAP9/Platform.py +++ b/Deeploy/Targets/GAP9/Platform.py @@ -5,24 +5,29 @@ import numpy as np import onnx_graphsurgeon as gs +from Deeploy.CommonExtensions.OptimizationPasses.TopologyOptimizationPasses.LoweringOptimizationPasses import \ + RemoveEmptyConvBiasPass, RemoveOnlySingletonReduceMeanPass from Deeploy.DeeployTypes import ConstantBuffer, DeploymentEngine, DeploymentPlatform, NetworkContext, NodeMapper, \ - NodeTemplate, StructBuffer, TransientBuffer, VariableBuffer + NodeTemplate, StructBuffer, TopologyOptimizer, TransientBuffer, VariableBuffer from Deeploy.MemoryLevelExtension.MemoryLevels import MemoryHierarchy, MemoryLevel from Deeploy.MemoryLevelExtension.NetworkDeployers.MemoryLevelDeployer import MemoryPlatform, MemoryPlatformWrapper +from Deeploy.Targets.GAP9.Parsers import NE16GEMMParser from Deeploy.Targets.GAP9.Templates import AllocateTemplate, FreeTemplate # Import GAP9-specific tiler bindings -from Deeploy.Targets.GAP9.Tiler import GAP9AddTilingReadyBindings, GAP9ConcatTilingReadyBindings, \ - GAP9Conv2DTilingReadyBindings, GAP9DWConv2DTilingReadyBindings, GAP9FlattenTilingReadyBindings, \ - GAP9FPGELUTilingReadyBindings, GAP9FPGEMMTilingReadyBindings, GAP9GatherTilingReadyBindings, \ - GAP9iHardswishTilingReadyBindings, GAP9iRMSNormTilingReadyBindings, GAP9iRQSGELUTilingReadyBindings, \ - GAP9LayernormTilingReadyBindings, GAP9MatMulTilingReadyBindings, GAP9MaxPool2DTilingReadyBindings, \ - GAP9MulTilingReadyBindings, GAP9ReduceSumTilingReadyBindings, GAP9ReluTilingReadyBindings, \ +from Deeploy.Targets.GAP9.Tiler import 
DeQuantTilingReadyBindings, GAP9AddTilingReadyBindings, \ + GAP9ConcatTilingReadyBindings, GAP9Conv2DTilingReadyBindings, GAP9DWConv2DTilingReadyBindings, \ + GAP9FlattenTilingReadyBindings, GAP9FPGELUTilingReadyBindings, GAP9FPGEMMTilingReadyBindings, \ + GAP9GatherTilingReadyBindings, GAP9iHardswishTilingReadyBindings, GAP9iRMSNormTilingReadyBindings, \ + GAP9iRQSGELUTilingReadyBindings, GAP9LayernormTilingReadyBindings, GAP9MatMulTilingReadyBindings, \ + GAP9MaxPool2DTilingReadyBindings, GAP9MulTilingReadyBindings, GAP9NE16GEMMInt32TilingReadyBindings, \ + GAP9NE16RQSGEMMTilingReadyBindings, GAP9ReduceSumTilingReadyBindings, GAP9ReluTilingReadyBindings, \ GAP9RQAddTilingReadyBindings, GAP9RQSConv2DTilingReadyBindings, GAP9RQSDWConv2DTilingReadyBindings, \ GAP9RQSGEMMTilingReadyBindings, GAP9RQSiHardswishTilingReadyBindings, GAP9RQSMatrixVecTilingReadyBindings, \ GAP9RQSTallGEMMTilingReadyBindings, GAP9RQSTilingReadyBindings, GAP9SGDTilingReadyBindings, \ GAP9SoftmaxCrossEntropyGradTilingReadyBindings, GAP9SoftmaxCrossEntropyTilingReadyBindings, \ GAP9SoftmaxGradTilingReadyBindings, GAP9SoftmaxTilingReadyBindings, GAP9TransposeTilingReadyBindings, \ - GAP9UniformRQSTilingReadyBindings + GAP9UniformRQSTilingReadyBindings, QuantTilingReadyBindings +from Deeploy.Targets.GAP9.TopologyOptimizationPasses.Passes import NE16AdjustGEMMWeightLayoutPass from Deeploy.Targets.Generic.Bindings import BasicGEMMBindings, BasicPad1DBindings, BasicPad2DBindings, \ BasicRQIntegerDivBinding from Deeploy.Targets.Generic.Layers import AddLayer, ConcatLayer, ConvLayer, GatherLayer, GELULayer, GEMMLayer, \ @@ -37,12 +42,18 @@ SoftmaxCrossEntropyLossGradParser, SoftmaxCrossEntropyLossParser, SoftmaxGradParser, SoftmaxParser, \ TransposeParser, UniformRequantShiftParser, UnsqueezeParser, iHardswishParser, iRMSNormParser, iSoftmaxParser from Deeploy.Targets.Generic.Templates import AllocateTemplate as BasicAllocateTemplate -from Deeploy.Targets.PULPOpen.Bindings import 
BasicDequantBindings, BasicQuantBindings, PULPDMASliceBindings, \ - PULPDWConv1DBinding, PULPReduceMeanBindings, PULPRQSConv1DBindings, PULPSliceBindings +from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import DequantPatternPass, DequantQuantMergePass, \ + IntegerDivRequantMergePass, MatMulAddMergePass, MergeConstAddAndRequantPass, MergeTrueIntegerDivRequantShiftPass, \ + QuantPatternPass, RQSSplitPass, SkipEmptyConcatPass, SkipUnityRequantPass, iGELURequantMergePass, \ + iHardswishRequantMergePass +from Deeploy.Targets.PULPOpen.Bindings import BasicDequantBindings, BasicQuantBindings, PULPConv1DBinding, \ + PULPDMASliceBindings, PULPDWConv1DBinding, PULPReduceMeanBindings, PULPRQSConv1DBindings, PULPSliceBindings from Deeploy.Targets.PULPOpen.Layers import PULPRQSConvLayer, PULPRQSGEMMLayer from Deeploy.Targets.PULPOpen.Parsers import PULPConv1DParser, PULPConv2DParser, PULPDWConv1DParser, \ PULPDWConv2DParser, PULPFPConv2DParser, PULPFPDWConv2DParser, PULPGEMMParser, PULPMatrixVecParser, \ PULPTallGEMMParser +from Deeploy.Targets.PULPOpen.TopologyOptimizationPasses.Passes import PULPConvRequantMergePass, \ + PULPGEMMRequantMergePass, PULPMatMulRequantMergePass # Create GAP9-specific NodeMappers GAP9_RQAddMapper = NodeMapper(RQAddParser(), GAP9RQAddTilingReadyBindings) @@ -90,9 +101,37 @@ GAP9_SoftmaxCrossEntropyLossGradMapper = NodeMapper(SoftmaxCrossEntropyLossGradParser(), GAP9SoftmaxCrossEntropyGradTilingReadyBindings) GAP9_SGDMapper = NodeMapper(SGDParser(), GAP9SGDTilingReadyBindings) -GAP9_QuantMapper = NodeMapper(QuantParser(), BasicQuantBindings) -GAP9_DequantMapper = NodeMapper(DequantParser(), BasicDequantBindings) +GAP9_QuantMapper = NodeMapper(QuantParser(), QuantTilingReadyBindings) +GAP9_DequantMapper = NodeMapper(DequantParser(), DeQuantTilingReadyBindings) GAP9_GEMMDequantMapper = NodeMapper(PULPGEMMParser(), BasicGEMMBindings) +GAP9_NE16GEMMMapper = NodeMapper(NE16GEMMParser(), GAP9NE16RQSGEMMTilingReadyBindings) 
+GAP9_NE16GEMMInt32Mapper = NodeMapper(GEMMParser(), GAP9NE16GEMMInt32TilingReadyBindings) + +GAP9Optimizer = TopologyOptimizer( + [ + QuantPatternPass(), + DequantPatternPass(), + DequantQuantMergePass(), + MatMulAddMergePass(), + SkipEmptyConcatPass(), + SkipUnityRequantPass(previous_op_regex = "Concat", num_inputs = 2), + SkipUnityRequantPass(previous_op_regex = "Reshape|Transpose", num_inputs = 1), + SkipUnityRequantPass(previous_op_regex = "Reshape|Transpose", num_inputs = 1), + RQSSplitPass(), + MergeTrueIntegerDivRequantShiftPass(), + IntegerDivRequantMergePass(), + iGELURequantMergePass(), + iHardswishRequantMergePass(), + PULPConvRequantMergePass(), + MergeConstAddAndRequantPass(), + PULPGEMMRequantMergePass(), + PULPMatMulRequantMergePass(), + # PULPAddRequantMergePass(), + RemoveEmptyConvBiasPass(), + RemoveOnlySingletonReduceMeanPass(), + NE16AdjustGEMMWeightLayoutPass(), + ], + name = "GAP9Optimizer") # GAP9-specific mapping using ClDma GAP9Mapping = { @@ -101,9 +140,9 @@ 'RequantizedConv': PULPRQSConvLayer([GAP9_Conv2DMapper, GAP9_DWConv2DMapper, GAP9_Conv1DMapper, GAP9_DWConv1DMapper]), 'RequantizedGemm': - PULPRQSGEMMLayer([GAP9_MatrixVecMapper, GAP9_TallGEMMMapper, GAP9_GEMMMapper]), + PULPRQSGEMMLayer([GAP9_NE16GEMMMapper, GAP9_MatrixVecMapper, GAP9_TallGEMMMapper, GAP9_GEMMMapper]), 'Gemm': - GEMMLayer([GAP9_FloatGEMMMapper, GAP9_GEMMDequantMapper]), + GEMMLayer([GAP9_NE16GEMMInt32Mapper, GAP9_FloatGEMMMapper, GAP9_GEMMDequantMapper]), 'Gelu': GELULayer([GAP9_GELUMapper]), 'LayerNormalization': @@ -244,7 +283,10 @@ class GAP9StructBuffer(StructBuffer): deallocTemplate = NodeTemplate("") -_includeList = ["pmsis.h", "DeeployGAP9Math.h", "pulp_nn_kernels.h", "DeeployMchan.h"] +_includeList = [ + "pmsis.h", "DeeployGAP9Math.h", "pulp_nn_kernels.h", "DeeployMchan.h", "CNN_BasicKernels_fp32.h", + "CNN_BasicKernels_NE16.h", "CNN_Copy.h", "ne16_utils.h" +] class GAP9ClusterEngine(DeploymentEngine): diff --git 
a/Deeploy/Targets/GAP9/Templates/GAP9SDKDequantQuantTemplate.py b/Deeploy/Targets/GAP9/Templates/GAP9SDKDequantQuantTemplate.py new file mode 100644 index 0000000000..cd4374466e --- /dev/null +++ b/Deeploy/Targets/GAP9/Templates/GAP9SDKDequantQuantTemplate.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 +# +# Quant/Dequant templates using GAP9 SDK kernels (CNN_Copy.c). +# All called via GAP9Transformer which handles pi_cl_team_fork. + +from Deeploy.DeeployTypes import NodeTemplate + +# ============================================================ +# Dequant templates: int → fp16 (SDK kernels from CNN_Copy.c) +# ============================================================ + +# int8 → fp16: SDK kernel CNN_FpsIEEE16 +fp16DequantI8Template = NodeTemplate(""" +// FP16 Dequant int8→fp16 (Name: ${nodeName}, Op: ${nodeOp}) +{ + signed char _dq_infos[8]; + *((float *)(_dq_infos + 0)) = (float)(-(${zero_point})); + *((float *)(_dq_infos + 4)) = (float)(${scale}); + CNN_Quantize_T _dq_arg = { + .In = (void *)${data_in}, + .Out = (void *)${data_out}, + .W = ${size}, + .H = 1, + .Infos = _dq_infos, + }; + CNN_FpsIEEE16(&_dq_arg); +} +""") + +# uint8 → fp16: SDK kernel CNN_UFpsIEEE16 +fp16DequantU8Template = NodeTemplate(""" +// FP16 Dequant uint8→fp16 (Name: ${nodeName}, Op: ${nodeOp}) +{ + signed char _dq_infos[8]; + *((float *)(_dq_infos + 0)) = (float)(-(${zero_point})); + *((float *)(_dq_infos + 4)) = (float)(${scale}); + CNN_Quantize_T _dq_arg = { + .In = (void *)${data_in}, + .Out = (void *)${data_out}, + .W = ${size}, + .H = 1, + .Infos = _dq_infos, + }; + CNN_UFpsIEEE16(&_dq_arg); +} +""") + +# ============================================================ +# Dequant templates: int → fp32 (SDK kernels from CNN_Copy.c) +# ============================================================ + +# int8 → fp32: SDK kernel CNN_FpsFloat32 +fp32DequantI8Template = NodeTemplate(""" +// FP32 Dequant 
int8→fp32 (Name: ${nodeName}, Op: ${nodeOp}) +{ + signed char _dq_infos[8]; + *((float *)(_dq_infos + 0)) = (float)(-(${zero_point})); + *((float *)(_dq_infos + 4)) = (float)(${scale}); + CNN_FpsFloat32_T _dq_arg = { + .In = (signed char *)${data_in}, + .Out = (float *)${data_out}, + .W = ${size}, + .H = 1, + .Infos = _dq_infos, + }; + CNN_FpsFloat32(&_dq_arg); +} +""") + +# uint8 → fp32: SDK kernel CNN_UFpsFloat32 +fp32DequantU8Template = NodeTemplate(""" +// FP32 Dequant uint8→fp32 (Name: ${nodeName}, Op: ${nodeOp}) +{ + signed char _dq_infos[8]; + *((float *)(_dq_infos + 0)) = (float)(-(${zero_point})); + *((float *)(_dq_infos + 4)) = (float)(${scale}); + CNN_UFpsFloat32_T _dq_arg = { + .In = (unsigned char *)${data_in}, + .Out = (float *)${data_out}, + .W = ${size}, + .H = 1, + .Infos = _dq_infos, + }; + CNN_UFpsFloat32(&_dq_arg); +} +""") + +# ============================================================ +# Quant templates: fp16 → int (SDK kernels from CNN_Copy.c) +# ============================================================ + +# fp16 → int8: SDK kernel CNN_IEEE16Fps +fp16QuantI8Template = NodeTemplate(""" +// FP16 Quant fp16→int8 (Name: ${nodeName}, Op: ${nodeOp}) +{ + signed char _q_infos[8]; + *((float *)(_q_infos + 0)) = (float)(${zero_point}); + *((float *)(_q_infos + 4)) = (float)(${scale}); + CNN_Quantize_T _q_arg = { + .In = (void *)${data_in}, + .Out = (void *)${data_out}, + .W = ${size}, + .H = 1, + .Infos = _q_infos, + }; + CNN_IEEE16Fps(&_q_arg); +} +""") + +# fp16 → uint8: SDK kernel CNN_IEEE16UFps +fp16QuantU8Template = NodeTemplate(""" +// FP16 Quant fp16→uint8 (Name: ${nodeName}, Op: ${nodeOp}) +{ + signed char _q_infos[8]; + *((float *)(_q_infos + 0)) = (float)(${zero_point}); + *((float *)(_q_infos + 4)) = (float)(${scale}); + CNN_Quantize_T _q_arg = { + .In = (void *)${data_in}, + .Out = (void *)${data_out}, + .W = ${size}, + .H = 1, + .Infos = _q_infos, + }; + CNN_IEEE16UFps(&_q_arg); +} +""") + +# 
============================================================ +# Quant templates: fp32 → int (SDK kernels from CNN_Copy.c) +# ============================================================ + +# fp32 → int8: SDK kernel CNN_Float32Fps +fp32QuantI8Template = NodeTemplate(""" +// FP32 Quant fp32→int8 (Name: ${nodeName}, Op: ${nodeOp}) +{ + signed char _q_infos[8]; + *((float *)(_q_infos + 0)) = (float)(${zero_point}); + *((float *)(_q_infos + 4)) = (float)(${scale}); + CNN_Float32Fps_T _q_arg = { + .In = (float *)${data_in}, + .Out = (signed char *)${data_out}, + .W = ${size}, + .H = 1, + .Infos = _q_infos, + }; + CNN_Float32Fps(&_q_arg); +} +""") + +# fp32 → uint8: SDK kernel CNN_Float32UFps +fp32QuantU8Template = NodeTemplate(""" +// FP32 Quant fp32→uint8 (Name: ${nodeName}, Op: ${nodeOp}) +{ + signed char _q_infos[8]; + *((float *)(_q_infos + 0)) = (float)(${zero_point}); + *((float *)(_q_infos + 4)) = (float)(${scale}); + CNN_Float32UFps_T _q_arg = { + .In = (float *)${data_in}, + .Out = (unsigned char *)${data_out}, + .W = ${size}, + .H = 1, + .Infos = _q_infos, + }; + CNN_Float32UFps(&_q_arg); +} +""") diff --git a/Deeploy/Targets/GAP9/Templates/NE16GEMMTemplate.py b/Deeploy/Targets/GAP9/Templates/NE16GEMMTemplate.py new file mode 100644 index 0000000000..809d03d9c5 --- /dev/null +++ b/Deeploy/Targets/GAP9/Templates/NE16GEMMTemplate.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Tuple + +import numpy as np + +from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation + + +def _ne16_conv_1x1_weight_layout(W, w_bits = 8): + """Pack int8 weights for NE16 1x1 conv mode. + W: int8 [Ko, Ki] -> uint8 [Ko, Nb_KI, Qw, 2] (bitplane packed) + Weights stored as uint8 = int8 + 128. 
+ """ + tp_in = 16 + Ko_, Ki_ = W.shape + W_uint8 = (W.astype(np.int32) + 128).astype(np.uint8) + nb_ki = (Ki_ + tp_in - 1) // tp_in + w_binary = np.zeros((Ko_ * nb_ki, w_bits, 8, tp_in // 8), dtype = np.uint8) + for ko in range(Ko_): + for ki_maj in range(nb_ki): + for ki_min in range(tp_in): + idx = ko * nb_ki + ki_maj + ki = ki_maj * tp_in + ki_min + val = int(W_uint8[ko, ki]) if ki < Ki_ else 0 + for q in range(w_bits): + w_binary[idx, q, ki_min % 8, ki_min // 8] = (val >> q) & 1 + space = np.logspace(0, 7, num = 8, base = 2, dtype = np.int32).reshape((8, 1)) + w_layout = np.sum(w_binary * space, axis = 2, dtype = np.uint8) + return w_layout.reshape((Ko_, nb_ki, w_bits, tp_in // 8)) + + +class NE16GEMMTemplate(NodeTemplate): + + def __init__(self, templateStr): + super().__init__(templateStr) + + def alignToContext(self, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: + + A = ctxt.lookup(operatorRepresentation['A']) + B = ctxt.lookup(operatorRepresentation['B']) + C = ctxt.lookup(operatorRepresentation['C']) + data_out = ctxt.lookup(operatorRepresentation['data_out']) + + # Determine signedness from type system (reliable, post-type-inference) + input_signed = A._type.referencedType.typeMin < 0 + output_bits = data_out._type.referencedType.typeWidth + output_signed = data_out._type.referencedType.typeMin < 0 + + operatorRepresentation['input_signed'] = input_signed + operatorRepresentation['output_bits'] = output_bits + operatorRepresentation['quant_bits'] = 2 if output_bits == 32 else 0 + operatorRepresentation['quant_norect'] = 1 if (output_bits == 32 or output_signed) else 0 + + # Weight packing and signed bias compensation + w_int8 = B.values.astype(np.int8) # [Ko, Ki], still int8 at this point + Ko, Ki = w_int8.shape + + # Truncate broadcast bias [M, O] → per-channel [Ko] for NE16 + bias_flat = C.values.flatten() + if bias_flat.size > Ko: + C.values = bias_flat[:Ko].copy() + + # Signed input 
bias compensation + if input_signed: + # Compute w_sum BEFORE packing (needed for signed bias compensation) + w_sum = w_int8.astype(np.int64).sum(axis = 1) # [Ko] + bias_values = C.values.flatten().astype(np.int64) + if 'mul' in operatorRepresentation: + # RequantizedGemm: bias -= 128 * w_sum * scale + scale_buf = ctxt.lookup(operatorRepresentation['mul']) + scale_values = scale_buf.values.flatten().astype(np.int64) + bias_values -= 128 * w_sum * scale_values + else: + # Gemm int32: bias -= 128 * w_sum (no scale) + bias_values -= 128 * w_sum + C.values = bias_values.astype(np.int32) + + # Pack weights to NE16 bitplane format + ne16_weights = _ne16_conv_1x1_weight_layout(w_int8) + B.values = ne16_weights.reshape(Ko, -1) + + return ctxt, operatorRepresentation, [] + + +# 8-bit output template (RequantizedGemm) — uses tiled ${mul} and ${scale_n} +referenceTemplate = NE16GEMMTemplate(""" +// NE16 Linear 8-bit (Name: ${nodeName}, Op: ${nodeOp}) + +% if input_signed: +// Signed input: add 128 offset to convert int8 -> uint8 (multi-core SIMD) +{ + ne16_int8_to_uint8_T _offset_arg = { + .In = (int8_t *)${A}, + .Out = (uint8_t *)${A}, + .size = ${batch} * ${M} * ${N} + }; + pi_cl_team_fork(NUM_CORES, (void *)ne16_int8_to_uint8, &_offset_arg); +} +% endif + +{ + unsigned int _ne16_cfg = 0; + _ne16_cfg |= ((8 - 1) & NE16_MASK_WBITS_M1) << NE16_SHIFT_WBITS_M1; + _ne16_cfg |= (0 & NE16_MASK_MODE16) << NE16_SHIFT_MODE16; + _ne16_cfg |= (1 & NE16_MASK_OUTQUANT) << NE16_SHIFT_OUTQUANT; + _ne16_cfg |= (NE16_FILTER_MODE_1x1 & NE16_MASK_FILTER_MODE) << NE16_SHIFT_FILTER_MODE; + _ne16_cfg |= (0 & NE16_MASK_LINEAR_MODE) << NE16_SHIFT_LINEAR_MODE; + _ne16_cfg |= (0 & NE16_MASK_STRIDED_MODE) << NE16_SHIFT_STRIDED_MODE; + _ne16_cfg |= (NE16_BITS_8BIT & NE16_MASK_NORM_BITS) << NE16_SHIFT_NORM_BITS; + _ne16_cfg |= (0 & NE16_MASK_STREAMIN) << NE16_SHIFT_STREAMIN; + _ne16_cfg |= (1 & NE16_MASK_WEIGHT_OFFSET_CFG) << NE16_SHIFT_WEIGHT_OFFSET_CFG; + _ne16_cfg |= (0 & NE16_MASK_QUANT_RIGHT_SHIFT) 
<< NE16_SHIFT_QUANT_RIGHT_SHIFT; + _ne16_cfg |= (${quant_bits} & NE16_MASK_QUANT_BITS) << NE16_SHIFT_QUANT_BITS; + _ne16_cfg |= (${quant_norect} & NE16_MASK_QUANT_NORECT) << NE16_SHIFT_QUANT_NORECT; + _ne16_cfg |= (1 & NE16_MASK_NORM_SHIFT) << NE16_SHIFT_NORM_SHIFT; + _ne16_cfg |= (1 & NE16_MASK_NORM_BIAS) << NE16_SHIFT_NORM_BIAS; + + NE16_Enable(); + NE16_SoftReset(); + + KerConv_NE16_T _ne16_arg = { + .In = (void *)${A}, + .Filter = (unsigned short *)${B}, + .Bias = (int *)${C}, + .Out = (void *)${data_out}, + .Scale = (unsigned char *)${mul}, + .ScaleN = (unsigned char *)${scale_n}, + .Tile_InFeat = ${N}, + .TotalInFeatures = ${N}, + .Tile_InH = 1, + .Tile_InW = ${batch} * ${M}, + .Tile_OutFeat = ${O}, + .Tile_OutH = 1, + .Tile_OutW = ${batch} * ${M}, + .FilterSize = 1, + .Pad_Val = 0, + .Pad = (v4s){0, 0, 0, 0}, + .W_Offset = -128, + .Qw = 8, + .Mode16 = 0, + .FirstD0 = 1, + .LastD0 = 1, + .Default_NE16_Job_Cfg = _ne16_cfg, + .Fx = 1, + .Fy = 1, + .Sx = 1, + .Sy = 1, + .Dx = 1, + .Dy = 1, + .BuffOut = NULL, + .Infos = NULL, + .Extra = NULL, + }; + KerConv1x1_SmallHW_Stride1_NE16(&_ne16_arg); + + NE16_Disable(); +} +""") + +# Int32 output template (plain Gemm) — hardcoded scale=1, scale_n=0 +int32OutputTemplate = NE16GEMMTemplate(""" +// NE16 Linear Int32 (Name: ${nodeName}, Op: ${nodeOp}) + +% if input_signed: +// Signed input: add 128 offset to convert int8 -> uint8 (multi-core SIMD) +{ + ne16_int8_to_uint8_T _offset_arg = { + .In = (int8_t *)${A}, + .Out = (uint8_t *)${A}, + .size = ${batch} * ${M} * ${N} + }; + pi_cl_team_fork(NUM_CORES, (void *)ne16_int8_to_uint8, &_offset_arg); +} +% endif + +{ + unsigned char _ne16_ones[${O}]; + unsigned char _ne16_zeros[${O}]; + memset(_ne16_ones, 1, ${O}); + memset(_ne16_zeros, 0, ${O}); + + unsigned int _ne16_cfg = 0; + _ne16_cfg |= ((8 - 1) & NE16_MASK_WBITS_M1) << NE16_SHIFT_WBITS_M1; + _ne16_cfg |= (0 & NE16_MASK_MODE16) << NE16_SHIFT_MODE16; + _ne16_cfg |= (1 & NE16_MASK_OUTQUANT) << NE16_SHIFT_OUTQUANT; + 
_ne16_cfg |= (NE16_FILTER_MODE_1x1 & NE16_MASK_FILTER_MODE) << NE16_SHIFT_FILTER_MODE; + _ne16_cfg |= (0 & NE16_MASK_LINEAR_MODE) << NE16_SHIFT_LINEAR_MODE; + _ne16_cfg |= (0 & NE16_MASK_STRIDED_MODE) << NE16_SHIFT_STRIDED_MODE; + _ne16_cfg |= (NE16_BITS_8BIT & NE16_MASK_NORM_BITS) << NE16_SHIFT_NORM_BITS; + _ne16_cfg |= (0 & NE16_MASK_STREAMIN) << NE16_SHIFT_STREAMIN; + _ne16_cfg |= (1 & NE16_MASK_WEIGHT_OFFSET_CFG) << NE16_SHIFT_WEIGHT_OFFSET_CFG; + _ne16_cfg |= (0 & NE16_MASK_QUANT_RIGHT_SHIFT) << NE16_SHIFT_QUANT_RIGHT_SHIFT; + _ne16_cfg |= (${quant_bits} & NE16_MASK_QUANT_BITS) << NE16_SHIFT_QUANT_BITS; + _ne16_cfg |= (${quant_norect} & NE16_MASK_QUANT_NORECT) << NE16_SHIFT_QUANT_NORECT; + _ne16_cfg |= (1 & NE16_MASK_NORM_SHIFT) << NE16_SHIFT_NORM_SHIFT; + _ne16_cfg |= (1 & NE16_MASK_NORM_BIAS) << NE16_SHIFT_NORM_BIAS; + + NE16_Enable(); + NE16_SoftReset(); + + KerConv_NE16_T _ne16_arg = { + .In = (void *)${A}, + .Filter = (unsigned short *)${B}, + .Bias = (int *)${C}, + .Out = (void *)${data_out}, + .Scale = _ne16_ones, + .ScaleN = _ne16_zeros, + .Tile_InFeat = ${N}, + .TotalInFeatures = ${N}, + .Tile_InH = 1, + .Tile_InW = ${batch} * ${M}, + .Tile_OutFeat = ${O}, + .Tile_OutH = 1, + .Tile_OutW = ${batch} * ${M}, + .FilterSize = 1, + .Pad_Val = 0, + .Pad = (v4s){0, 0, 0, 0}, + .W_Offset = -128, + .Qw = 8, + .Mode16 = 0, + .FirstD0 = 1, + .LastD0 = 1, + .Default_NE16_Job_Cfg = _ne16_cfg, + .Fx = 1, + .Fy = 1, + .Sx = 1, + .Sy = 1, + .Dx = 1, + .Dy = 1, + .BuffOut = NULL, + .Infos = NULL, + .Extra = NULL, + }; + KerConv1x1_SmallHW_Stride1_NE16(&_ne16_arg); + + NE16_Disable(); +} +""") diff --git a/Deeploy/Targets/GAP9/TileConstraints/NE16GEMMTileConstraint.py b/Deeploy/Targets/GAP9/TileConstraints/NE16GEMMTileConstraint.py new file mode 100644 index 0000000000..6f490b9803 --- /dev/null +++ b/Deeploy/Targets/GAP9/TileConstraints/NE16GEMMTileConstraint.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# 
SPDX-License-Identifier: Apache-2.0 + +import math +from typing import Dict, List, Tuple + +from Deeploy.AbstractDataTypes import PointerClass +from Deeploy.CommonExtensions.DataTypes import uint8_t, uint16_t +from Deeploy.DeeployTypes import NetworkContext, OperatorRepresentation +from Deeploy.TilingExtension.MemoryConstraints import NodeMemoryConstraint +from Deeploy.TilingExtension.TileConstraint import TileConstraint +from Deeploy.TilingExtension.TilerModel import PerformanceHint, TilerModel +from Deeploy.TilingExtension.TilingCodegen import AbsoluteHyperRectangle, HyperRectangle, TilingSchedule, \ + VariableReplacementScheme + + +class NE16GEMMTileConstraint(TileConstraint): + """Tile constraint for NE16 GEMM with bitplane-packed weights stored as 2D [Ko, Ki].""" + + @staticmethod + def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + + bufferA = ctxt.lookup(name = parseDict['A']) + bufferB = ctxt.lookup(name = parseDict['B']) + bufferC = ctxt.lookup(name = parseDict['C']) + outputBuffer = ctxt.lookup(name = parseDict['data_out']) + + bufferNames = [bufferA.name, bufferB.name, bufferC.name, outputBuffer.name] + hasMul = 'mul' in parseDict and isinstance(parseDict['mul'], str) + if hasMul: + mulBuffer = ctxt.lookup(name = parseDict['mul']) + bufferNames.append(mulBuffer.name) + hasScaleN = 'scale_n' in parseDict and isinstance(parseDict['scale_n'], str) + if hasScaleN: + scaleNBuffer = ctxt.lookup(name = parseDict['scale_n']) + bufferNames.append(scaleNBuffer.name) + + for bufferName in bufferNames: + tilerModel.addTensorDimToModel(ctxt, bufferName) + + dimOffsetA = len(bufferA.shape) - 2 + dimOffsetB = len(bufferB.shape) - 2 + dimOffsetOut = len(outputBuffer.shape) - 2 + + AFirstDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = dimOffsetA + parseDict['transA']) + ASecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, + dimIdx = dimOffsetA + 1 - parseDict['transA']) + 
BFirstDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name, dimIdx = dimOffsetB + parseDict['transB']) + BSecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name, + dimIdx = dimOffsetB + 1 - parseDict['transB']) + outputFirstDimVar = tilerModel.getTensorDimVar(tensorName = outputBuffer.name, dimIdx = dimOffsetOut) + outputSecondDimVar = tilerModel.getTensorDimVar(tensorName = outputBuffer.name, dimIdx = dimOffsetOut + 1) + + tilerModel.addConstraint(outputFirstDimVar == AFirstDimVar) + tilerModel.addConstraint(outputSecondDimVar == BSecondDimVar) + tilerModel.addConstraint(ASecondDimVar == BFirstDimVar) + + addDimVar = tilerModel.getTensorDimVar(tensorName = bufferC.name, dimIdx = 0) + tilerModel.addConstraint(outputSecondDimVar == addDimVar) + + if hasMul: + mulDimVar = tilerModel.getTensorDimVar(tensorName = mulBuffer.name, dimIdx = 0) + tilerModel.addConstraint(outputSecondDimVar == mulDimVar) + + if hasScaleN: + scaleNDimVar = tilerModel.getTensorDimVar(tensorName = scaleNBuffer.name, dimIdx = 0) + tilerModel.addConstraint(outputSecondDimVar == scaleNDimVar) + + return tilerModel + + @staticmethod + def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel: + + bufferA = ctxt.lookup(name = parseDict['A']) + bufferB = ctxt.lookup(name = parseDict['B']) + + dimOffsetA = len(bufferA.shape) - 2 + dimOffsetB = len(bufferB.shape) - 2 + + # Don't tile N (reduction dimension) — NE16 needs full input channels + ASecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, + dimIdx = dimOffsetA + 1 - parseDict['transA']) + BFirstDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name, dimIdx = dimOffsetB + parseDict['transB']) + tilerModel.addConstraint(ASecondDimVar == parseDict['N']) + tilerModel.addConstraint(BFirstDimVar == parseDict['N']) + + # O (output channels) should be divisible by 32 (NE16 TP_OUT) + BSecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name, + dimIdx = 
dimOffsetB + 1 - parseDict['transB']) + if parseDict["O"] > 32: + tilerModel.addTileSizeDivisibleConstraint(parseDict, + 'O', + BSecondDimVar, + 32, + strategy = PerformanceHint(priority = 1)) + + return tilerModel + + @classmethod + def serializeTilingSolution( + cls, tilingSolution: NodeMemoryConstraint, absoluteOutputCubes: List[AbsoluteHyperRectangle], + targetMemLevel: str, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[VariableReplacementScheme, TilingSchedule]: + outputCubes = [cube.rectangle for cube in absoluteOutputCubes] + + hasMul = 'mul' in operatorRepresentation and isinstance(operatorRepresentation['mul'], str) + hasScaleN = 'scale_n' in operatorRepresentation and isinstance(operatorRepresentation['scale_n'], str) + addrNames = ['A', 'B', 'C', 'data_out'] + if hasMul: + addrNames.insert(2, 'mul') + if hasScaleN: + addrNames.insert(-1, 'scale_n') + inputBaseOffsets, outputBaseOffsets = cls.extractBaseAddr(tilingSolution, targetMemLevel, + operatorRepresentation, addrNames) + transA = operatorRepresentation['transA'] + transB = operatorRepresentation['transB'] + + buffA = ctxt.lookup(operatorRepresentation['A']) + buffB = ctxt.lookup(operatorRepresentation['B']) + + NSize = buffA.shape[-1] + + inputACubes = [] + inputBCubes = [] + inputMulCubes = [] + inputAddCubes = [] + + replacements = {"M": [], "O": [], "batch": []} + + for cube in outputCubes: + MOffset, OOffset = cube.offset[-2:] + MSize, OSize = cube.dims[-2:] + + if len(cube.offset) > 2: + BatchSize = math.prod(cube.dims[:-2]) + else: + BatchSize = 1 + + replacements["M"].append(MSize) + replacements["O"].append(OSize) + replacements["batch"].append(BatchSize) + + if transA == 0: + AMatrixOffsets = (MOffset, 0) + AMatrixShape = (MSize, NSize) + else: + AMatrixOffsets = (0, MOffset) + AMatrixShape = (NSize, MSize) + + if len(buffA.shape) > 2: + batchDimCount = len(buffA.shape) - 2 + AMatrixOffsets = tuple(cube.offset[:-2][-batchDimCount:]) + AMatrixOffsets + 
AMatrixShape = tuple(cube.dims[:-2][-batchDimCount:]) + AMatrixShape + + inputACubes.append(HyperRectangle(AMatrixOffsets, AMatrixShape)) + + if transB == 0: + BMatrixOffsets = (0, OOffset) + BMatrixShape = (NSize, OSize) + else: + BMatrixOffsets = (OOffset, 0) + BMatrixShape = (OSize, NSize) + + inputBCubes.append(HyperRectangle(BMatrixOffsets, BMatrixShape)) + + RequantCube = HyperRectangle((OOffset,), (OSize,)) + inputMulCubes.append(RequantCube) + inputAddCubes.append(RequantCube) + + replacements["N"] = [NSize] * len(outputCubes) + + replacementTypes = { + "M": PointerClass(uint16_t), + "N": PointerClass(uint16_t), + "O": PointerClass(uint16_t), + "batch": PointerClass(uint8_t) + } + + inputLoadSchedule = [] + outputLoadSchedule = [] + + for idx, (a, b, c) in enumerate(zip(inputACubes, inputBCubes, inputAddCubes)): + load = {"A": a, "B": b, "C": c} + if hasMul: + load["mul"] = inputMulCubes[idx] + if hasScaleN: + load["scale_n"] = inputMulCubes[idx] # same per-channel slice as mul/C + inputLoadSchedule.append(load) + + for out in outputCubes: + outputLoadSchedule.append({"data_out": out}) + + schedule = TilingSchedule(inputBaseOffsets, outputBaseOffsets, inputLoadSchedule, outputLoadSchedule) + + return VariableReplacementScheme(replacements, replacementTypes), schedule diff --git a/Deeploy/Targets/GAP9/Tiler.py b/Deeploy/Targets/GAP9/Tiler.py index fefe12b6d7..b93aacb9db 100644 --- a/Deeploy/Targets/GAP9/Tiler.py +++ b/Deeploy/Targets/GAP9/Tiler.py @@ -10,14 +10,16 @@ import copy -from Deeploy.Targets.GAP9.Bindings import GAP9AddBindings, GAP9ConcatBindings, GAP9FloatConv2DBindings, \ - GAP9FloatDWConv2DBindings, GAP9FloatGELUBinding, GAP9FloatGEMMBindings, GAP9GatherBindings, \ - GAP9iHardswishBindings, GAP9iRMSNormBindings, GAP9iRQSGELUBindings, GAP9LayernormBinding, GAP9MatMulBindings, \ - GAP9MaxPool2DBindings, GAP9MulBindings, GAP9ReduceSumBindings, GAP9ReluBinding, GAP9ReshapeBindings, \ - GAP9RQAddBindings, GAP9RQSBindings, GAP9RQSConv2DBindings, 
GAP9RQSDWConv2DBindings, GAP9RQSGEMMBindings, \ - GAP9RQSiHardswishBindings, GAP9RQSMatrixVecBindings, GAP9RQSTallGEMMBindings, GAP9SGDBindings, \ - GAP9SoftmaxBindings, GAP9SoftmaxCrossEntropyLossBindings, GAP9SoftmaxCrossEntropyLossGradBindings, \ - GAP9SoftmaxGradBindings, GAP9TransposeBindings, GAP9UniformRQSBindings +from Deeploy.Targets.GAP9.Bindings import GAP9AddBindings, GAP9ConcatBindings, GAP9DequantBindings, \ + GAP9FloatConv2DBindings, GAP9FloatDWConv2DBindings, GAP9FloatGELUBinding, GAP9FloatGEMMBindings, \ + GAP9GatherBindings, GAP9iHardswishBindings, GAP9iRMSNormBindings, GAP9iRQSGELUBindings, GAP9LayernormBinding, \ + GAP9MatMulBindings, GAP9MaxPool2DBindings, GAP9MulBindings, GAP9NE16GEMMInt32Bindings, GAP9NE16RQSGEMMBindings, \ + GAP9QuantBindings, GAP9ReduceSumBindings, GAP9ReluBinding, GAP9ReshapeBindings, GAP9RQAddBindings, \ + GAP9RQSBindings, GAP9RQSConv2DBindings, GAP9RQSDWConv2DBindings, GAP9RQSGEMMBindings, GAP9RQSiHardswishBindings, \ + GAP9RQSMatrixVecBindings, GAP9RQSTallGEMMBindings, GAP9SGDBindings, GAP9SoftmaxBindings, \ + GAP9SoftmaxCrossEntropyLossBindings, GAP9SoftmaxCrossEntropyLossGradBindings, GAP9SoftmaxGradBindings, \ + GAP9TransposeBindings, GAP9UniformRQSBindings +from Deeploy.Targets.GAP9.TileConstraints.NE16GEMMTileConstraint import NE16GEMMTileConstraint from Deeploy.Targets.Generic.TileConstraints.AddTileConstraint import AddTileConstraint from Deeploy.Targets.Generic.TileConstraints.ConcatTileConstraint import ConcatTileConstraint from Deeploy.Targets.Generic.TileConstraints.iHardswishTileConstraint import iHardswishTileConstraint @@ -60,6 +62,12 @@ GAP9RQSGEMMTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = GAP9RQSGEMMBindings, tileConstraint = GEMMTileConstraint()) +GAP9NE16RQSGEMMTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = GAP9NE16RQSGEMMBindings, + tileConstraint = NE16GEMMTileConstraint()) + +GAP9NE16GEMMInt32TilingReadyBindings = TilingReadyNodeBindings(nodeBindings = 
GAP9NE16GEMMInt32Bindings, + tileConstraint = NE16GEMMTileConstraint()) + GAP9FPGEMMTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = GAP9FloatGEMMBindings, tileConstraint = FloatGEMMTileConstraint()) @@ -142,3 +150,9 @@ GAP9SGDTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = GAP9SGDBindings, tileConstraint = SGDTileConstraint()) + +QuantTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = GAP9QuantBindings, + tileConstraint = UnaryTileConstraint()) + +DeQuantTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = GAP9DequantBindings, + tileConstraint = UnaryTileConstraint()) diff --git a/Deeploy/Targets/GAP9/TopologyOptimizationPasses/Passes.py b/Deeploy/Targets/GAP9/TopologyOptimizationPasses/Passes.py new file mode 100644 index 0000000000..d5c8df4f9a --- /dev/null +++ b/Deeploy/Targets/GAP9/TopologyOptimizationPasses/Passes.py @@ -0,0 +1,132 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 + +import math + +import numpy as np +import onnx_graphsurgeon as gs + +from Deeploy.CommonExtensions.OptimizationPasses.Matchers import Match, NonBranchingMatcher +from Deeploy.CommonExtensions.OptimizationPasses.PassClasses import ReplaceSequentialPatternPass, contextagnostic + + +def _compute_ne16_scale_shift(mul_values, log2D): + """Convert Deeploy's mul/log2D to NE16's per-channel scale/scale_n.""" + Ko = len(mul_values) + ne16_scale = np.zeros(Ko, dtype = np.uint8) + ne16_scale_n = np.zeros(Ko, dtype = np.uint8) + for ko in range(Ko): + sf = float(mul_values[ko]) / float(2**log2D) + if sf >= 1.0: + sn = 0 + sc = min(255, max(1, int(round(sf)))) + elif sf > 0: + sn = min(31, max(0, int(math.floor(math.log2(127.0 / sf))))) + sc = min(255, max(1, int(round(sf * (1 << sn))))) + else: + sn = 0 + sc = 0 + ne16_scale[ko] = sc + ne16_scale_n[ko] = sn + return ne16_scale, ne16_scale_n + + +def _ne16_adjust_gemm_weight_layout_fun(graph: gs.Graph, match: Match, name: str): + 
"""Prepare GEMM node for NE16 execution. + + Handles transB normalization, scale/scale_n computation, and bias rescaling. + Weight bitplane packing and signed bias compensation are deferred to alignToContext + where input signedness is known from the type system. + """ + matched_nodes = list(match.nodes_map.values()) + node = matched_nodes[0] + + # Weight is input[1] for both Gemm and RequantizedGemm + weightTensor = node.inputs[1] + + if not isinstance(weightTensor, gs.Constant): + return graph + + values = weightTensor.values + + # Skip true float weights (Deeploy stores int8 weights as float32) + if not np.array_equal(values, np.round(values)): + return graph + + # Check shape is 2D + if len(values.shape) != 2: + return graph + + # Determine actual Ko, Ki based on transB + transB = node.attrs.get('transB', 0) + if transB: + Ko, Ki = values.shape + else: + Ki, Ko = values.shape + + # Check NE16 compatibility BEFORE modifying the node + if Ki % 16 != 0: + return graph + + # Transpose weight to [Ko, Ki] if needed — keep as int8 + if not transB: + transposed = values.T.astype(np.int8) + newWeightTensor = gs.Constant(f"{name}_{weightTensor.name}", transposed) + node.inputs[1] = newWeightTensor + node.attrs['transB'] = 1 + + # For RequantizedGemm: transform mul → ne16_scale, create scale_n, rescale bias + if node.op == 'RequantizedGemm' and len(node.inputs) >= 4: + mulTensor = node.inputs[3] + biasTensor = node.inputs[2] + + if isinstance(mulTensor, gs.Constant) and isinstance(biasTensor, gs.Constant): + mul_values = mulTensor.values.flatten().astype(np.int32) + log2D = int(np.log2(node.attrs['div'].values)) + + # Broadcast scalar mul to per-channel if needed + if len(mul_values) == 1: + mul_values = np.full(Ko, mul_values[0], dtype = np.int32) + + ne16_scale, ne16_scale_n = _compute_ne16_scale_shift(mul_values, log2D) + + # Rescale bias from mul/log2D domain to scale/scale_n domain + # bias_merged is already *= mul from PULPGEMMRequantMergePass + # NE16 needs: 
bias_ne16 = bias_merged * 2^(scale_n - log2D) + bias_values = biasTensor.values.flatten().astype(np.int64) + ne16_bias = np.zeros(Ko, dtype = np.int64) + for ko in range(Ko): + shift_diff = int(ne16_scale_n[ko]) - log2D + if shift_diff >= 0: + ne16_bias[ko] = bias_values[ko] << shift_diff + else: + ne16_bias[ko] = bias_values[ko] >> (-shift_diff) + + ne16_bias = ne16_bias.astype(np.int32) + + # Overwrite mul tensor with ne16_scale + mulTensor.values = ne16_scale + + # Overwrite bias tensor + biasTensor.values = ne16_bias + + # Append scale_n as new input[4] + scale_n_tensor = gs.Constant(f"{name}_scale_n", ne16_scale_n) + node.inputs.append(scale_n_tensor) + + return graph + + +@contextagnostic +class NE16AdjustGEMMWeightLayoutPass(ReplaceSequentialPatternPass): + + def __init__(self): + graph = gs.Graph() + _input = gs.Variable(name = 'input_1') + output = graph.layer(inputs = [_input], outputs = ['out'], op = 'RequantizedGemm|Gemm', name = 'node') + graph.outputs.append(output) + graph.inputs.append(_input) + + super().__init__(graph, _ne16_adjust_gemm_weight_layout_fun, "_NE16_ADJUST_GEMM_WEIGHT_LAYOUT_PASS", + NonBranchingMatcher(regex_op = True)) diff --git a/Deeploy/Targets/GAP9/TopologyOptimizationPasses/__init__.py b/Deeploy/Targets/GAP9/TopologyOptimizationPasses/__init__.py new file mode 100644 index 0000000000..4694b67df5 --- /dev/null +++ b/Deeploy/Targets/GAP9/TopologyOptimizationPasses/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py b/Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py index 146bcf699e..e506db78ce 100644 --- a/Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py +++ b/Deeploy/Targets/Generic/TopologyOptimizationPasses/Passes.py @@ -1177,3 +1177,85 @@ def __init__(self): name = "_RECOGNIZE_DEQUANT_PASS" super().__init__(graph, _recognize_dequant_fun, name) 
+ + +def _merge_dequant_quant_fun(graph: gs.Graph, match: Match, name: str): + matched_nodes = [m for k, m in match.nodes_map.items()] + dequant_node = matched_nodes[0] + quant_node = matched_nodes[1] + + # Skip if dequant output has multiple consumers or is a graph output + dequant_out = dequant_node.outputs[0] + if len(dequant_out.outputs) > 1 or dequant_out in graph.outputs: + return graph + + # Extract Dequant parameters (stored as Python floats) + s_d = float(dequant_node.attrs['scale']) + zp_d = float(dequant_node.attrs['zero_point']) + + # Extract Quant parameters (stored as numpy arrays) + s_q = float(np.array(quant_node.attrs['scale']).item()) + zp_q = float(np.array(quant_node.attrs['zero_point']).item()) + + signed_val = int(np.array(quant_node.attrs['signed']).item()) if 'signed' in quant_node.attrs else 1 + signed = bool(signed_val) + bit_width = int(np.array(quant_node.attrs['bit_width']).item()) if 'bit_width' in quant_node.attrs else 8 + n_levels = 2**bit_width + + # Compute effective ratio: y_float = (x_int - zp_d) * s_d * s_q + zp_q + ratio = s_d * s_q + + # Identity case: ratio ~= 1.0, both zero points == 0 + EPSILON = 1e-6 + if abs(ratio - 1.0) < EPSILON and abs(zp_d) < EPSILON and abs(zp_q) < EPSILON and signed: + input_tensor = dequant_node.inputs[0] + output_tensor = quant_node.outputs[0] + for downstream_node in list(output_tensor.outputs): + for i, inp in enumerate(downstream_node.inputs): + if inp == output_tensor: + downstream_node.inputs[i] = input_tensor + dequant_node.inputs.clear() + dequant_node.outputs.clear() + quant_node.inputs.clear() + quant_node.outputs.clear() + graph.cleanup().toposort() + return graph + + # Requantization case: convert to RequantShift + shift = 16 + div_val = 2**shift + + mul_val = int(np.round(ratio * div_val)) + add_val = int(np.round((-zp_d * ratio + zp_q) * div_val)) + + mul_const = gs.Constant(name = f'{name}_mul', values = np.array([mul_val], dtype = np.int32)) + add_const = gs.Constant(name = 
f'{name}_add', values = np.array([add_val], dtype = np.int32)) + + rqs_attrs = { + 'div': gs.Constant(f'{name}_div', np.array(div_val)), + 'n_levels_out': gs.Constant(f'{name}_n_levels', np.array(n_levels)), + 'signed': gs.Constant(f'{name}_signed', np.array([signed_val])), + } + + _inputs = [dequant_node.inputs[0], mul_const, add_const] + _outputs = quant_node.outputs + + rqs_node = gs.Node(op = 'RequantShift', name = name, attrs = rqs_attrs) + graph.replaceInsertNode(_inputs, _outputs, rqs_node) + + return graph + + +@contextagnostic +class DequantQuantMergePass(ReplaceSequentialPatternPass): + + def __init__(self): + graph = gs.Graph() + _input = gs.Variable(name = 'input_1') + output = graph.layer(inputs = [_input], outputs = ['dequant_out'], op = 'Dequant', name = 'dequant') + output = graph.layer(inputs = output, outputs = ['quant_out'], op = 'Quant', name = 'quant') + graph.outputs.append(output) + graph.inputs.append(_input) + + name = "_MERGE_DEQUANT_QUANT_PASS" + super().__init__(graph, _merge_dequant_quant_fun, name) diff --git a/Deeploy/Targets/Generic/TypeCheckers.py b/Deeploy/Targets/Generic/TypeCheckers.py index c2c8d436f8..f1174ffebd 100644 --- a/Deeploy/Targets/Generic/TypeCheckers.py +++ b/Deeploy/Targets/Generic/TypeCheckers.py @@ -543,6 +543,15 @@ class QuantChecker(SignPropTypeChecker): def __init__(self, input_types: Sequence[Type[Pointer]], output_types: Sequence[Type[Pointer]]): super().__init__(input_types, output_types) + def checkOutputType(self, inputs: List[VariableBuffer], operatorRepresentation: OperatorRepresentation) -> bool: + outputTypeSigned = self.output_types[0].referencedType.typeMin < 0 + opSigned = bool(operatorRepresentation['signed']) + if opSigned and outputTypeSigned: + return True + if (not opSigned) and (not outputTypeSigned): + return True + return False + def _inferNumLevels(self, inputs: List[VariableBuffer], operatorRepresentation: OperatorRepresentation) -> List[int]: # Calculate number of levels based on 
bit_width diff --git a/DeeployTest/Platforms/GAP9/CMakeLists.txt b/DeeployTest/Platforms/GAP9/CMakeLists.txt index cbb6382329..c723bb02e9 100644 --- a/DeeployTest/Platforms/GAP9/CMakeLists.txt +++ b/DeeployTest/Platforms/GAP9/CMakeLists.txt @@ -32,6 +32,7 @@ target_compile_options(network PRIVATE -Wno-pointer-sign -Wno-unknown-pragmas -Wno-error + -O3 ) target_link_options(${ProjectId} PRIVATE diff --git a/DeeployTest/testUtils/platformMapping.py b/DeeployTest/testUtils/platformMapping.py index 4dc2bbf824..11ce25fbc6 100644 --- a/DeeployTest/testUtils/platformMapping.py +++ b/DeeployTest/testUtils/platformMapping.py @@ -15,7 +15,7 @@ from Deeploy.Targets.CortexM.Deployer import CMSISDeployer from Deeploy.Targets.CortexM.Platform import CMSISOptimizer, CMSISPlatform from Deeploy.Targets.GAP9.Deployer import GAP9Deployer -from Deeploy.Targets.GAP9.Platform import GAP9Platform, MemoryGAP9Platform, MemoryGAP9PlatformWrapper +from Deeploy.Targets.GAP9.Platform import GAP9Optimizer, GAP9Platform, MemoryGAP9Platform, MemoryGAP9PlatformWrapper from Deeploy.Targets.Generic.Deployer import GenericDeployer from Deeploy.Targets.Generic.Platform import GenericOptimizer, GenericPlatform from Deeploy.Targets.MemPool.Deployer import MemPoolDeployer @@ -234,7 +234,7 @@ def mapDeployer(platform: DeploymentPlatform, elif isinstance(platform, (GAP9Platform, MemoryGAP9Platform, MemoryGAP9PlatformWrapper)): if loweringOptimizer is None: - loweringOptimizer = PULPOptimizer + loweringOptimizer = GAP9Optimizer if default_channels_first is None: default_channels_first = False diff --git a/TargetLibraries/GAP9/CMakeLists.txt b/TargetLibraries/GAP9/CMakeLists.txt index 8051484dd0..8de757a61e 100644 --- a/TargetLibraries/GAP9/CMakeLists.txt +++ b/TargetLibraries/GAP9/CMakeLists.txt @@ -4,22 +4,61 @@ file(GLOB_RECURSE SOURCES "src/**" + "$ENV{GAP_SDK_HOME}/tools/autotiler_v3/CNN_Libraries_fp32/CNN_Bias_Linear_Activation_fp32.c" + 
"$ENV{GAP_SDK_HOME}/tools/autotiler_v3/CNN_Libraries_NE16/CNN_BasicKernels_NE16.c" + "$ENV{GAP_SDK_HOME}/tools/autotiler_v3/CNN_Libraries/CNN_Copy.c" ) + +# Exclude dory_mem and dory_dma from SOURCES (they need different optimization) +list(FILTER SOURCES EXCLUDE REGEX ".*dory_(mem|dma).*") + # RW: Include PULPOpen sources but exclude dory_mem related files file(GLOB_RECURSE PULPOPEN_SOURCES "../PULPOpen/src/**") list(FILTER PULPOPEN_SOURCES EXCLUDE REGEX ".*dory_mem.*") list(APPEND SOURCES ${PULPOPEN_SOURCES}) +# Separate dory library compiled without -O3 +add_library(dory_lib STATIC + ${CMAKE_CURRENT_LIST_DIR}/src/dory_mem.c + ${CMAKE_CURRENT_LIST_DIR}/src/dory_dma.c +) +target_include_directories(dory_lib PUBLIC + ${CMAKE_CURRENT_LIST_DIR}/inc + ${CMAKE_CURRENT_LIST_DIR}/../PULPOpen/inc +) +target_compile_options(dory_lib PRIVATE + -Wno-implicit-function-declaration + -Wno-sign-conversion + -Wno-sign-compare + -Wno-type-limits + -Wno-attributes + -Wno-incompatible-pointer-types + -Og +) +target_compile_definitions(dory_lib PUBLIC NUM_CORES=${NUM_CORES}) +target_link_libraries(dory_lib PUBLIC pmsis) + add_deeploy_library(deeploygap9 STATIC ${SOURCES}) target_include_directories(deeploygap9 PUBLIC ${CMAKE_CURRENT_LIST_DIR}/inc ${CMAKE_CURRENT_LIST_DIR}/../PULPOpen/inc + ${TILER_INC} + ${TILER_EMU_INC} + ${TILER_CNN_KERNEL_PATH_FP32} + ${TILER_CNN_KERNEL_PATH_FP16} + $ENV{GAP_SDK_HOME}/tools/autotiler_v3/CNN_Libraries_NE16 + $ENV{GAP_SDK_HOME}/tools/autotiler_v3/CNN_Libraries_SQ8 + $ENV{GAP_SDK_HOME}/tools/autotiler_v3/CNN_Libraries + ${TILER_DSP_KERNEL_V2_PATH} + ${TILER_DSP_KERNEL_V2_PATH}/FastMathFunctions ) + target_compile_options(deeploygap9 PUBLIC -DNUM_CORES=${NUM_CORES} + -DSTD_FLOAT ) target_compile_options(deeploygap9 PRIVATE @@ -27,10 +66,10 @@ target_compile_options(deeploygap9 PRIVATE -Wno-sign-compare -Wno-type-limits -Wno-attributes + -Wno-incompatible-pointer-types + -O3 ) -target_link_libraries(deeploygap9 PUBLIC pmsis) - #RW: Link PULP-NN #RW: 
Set PULP-NN version and bitwidth for pulp-nn-mixed set(PULPNNVERSION XPULPV2) @@ -96,5 +135,6 @@ if(platform STREQUAL "GAP9_w_NE16") target_link_libraries(deeploygap9 PUBLIC pulp-nnx) endif() +target_link_libraries(deeploygap9 PUBLIC pmsis) target_link_libraries(deeploygap9 PUBLIC m) - +target_link_libraries(deeploygap9 PUBLIC dory_lib) diff --git a/TargetLibraries/GAP9/inc/ne16_utils.h b/TargetLibraries/GAP9/inc/ne16_utils.h new file mode 100644 index 0000000000..4d041c75dc --- /dev/null +++ b/TargetLibraries/GAP9/inc/ne16_utils.h @@ -0,0 +1,23 @@ +/* + * SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna + * SPDX-License-Identifier: Apache-2.0 + * + * NE16 utility kernels for GAP9 + */ + +#ifndef __NE16_UTILS_GAP9__ +#define __NE16_UTILS_GAP9__ + +#include "CNN_BasicKernels_fp32.h" +#include "pmsis.h" + +typedef struct { + int8_t *In; + uint8_t *Out; + int size; +} ne16_int8_to_uint8_T; + +/* Multi-core SIMD int8 → uint8 conversion (+128 offset) */ +void ne16_int8_to_uint8(ne16_int8_to_uint8_T *Arg); + +#endif diff --git a/TargetLibraries/GAP9/src/ne16_utils.c b/TargetLibraries/GAP9/src/ne16_utils.c new file mode 100644 index 0000000000..c19c119b25 --- /dev/null +++ b/TargetLibraries/GAP9/src/ne16_utils.c @@ -0,0 +1,35 @@ +/* + * SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna + * SPDX-License-Identifier: Apache-2.0 + * + * NE16 utility kernels for GAP9 + */ + +#include "ne16_utils.h" + +void ne16_int8_to_uint8(ne16_int8_to_uint8_T *Arg) { + int8_t *In = Arg->In; + uint8_t *Out = Arg->Out; + int size = Arg->size; + + unsigned int CoreId = gap_coreid(); + unsigned int NCore = gap_ncore(); + unsigned int total_quads = size / 4; + unsigned int Chunk = (total_quads + NCore - 1) / NCore; + unsigned int First = Chunk * CoreId; + unsigned int Last = First + Chunk; + if (Last > total_quads) + Last = total_quads; + + v4s offset = {-128, -128, -128, -128}; + for (unsigned int q = First; q < Last; q++) { + *((v4s *)&Out[q * 4]) = *((v4s 
*)&In[q * 4]) + offset; + } + + /* Handle remaining elements (size not multiple of 4) */ + if (CoreId == 0) { + for (int i = total_quads * 4; i < size; i++) { + Out[i] = (uint8_t)((int32_t)In[i] + 128); + } + } +} From 6c8ae2babe13aab705646b853ae8ba821ddcf4bc Mon Sep 17 00:00:00 2001 From: runwangdl Date: Tue, 14 Apr 2026 10:29:54 +0000 Subject: [PATCH 3/6] [NE16] integrate Pu DENG's NE16 Linear PR with NE16-w platform MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - TargetLibraries/GAP9/CMakeLists.txt: rename CNN_Libraries_NE16 → CNN_Libraries_HWPE (the actual gap9-sdk path); skip SDK CNN_BasicKernels_NE16.c source for GAP9_w_NE16 platform (it uses the pulp-nnx ne16 stack, so the SDK NE16 kernels are not needed). - Deeploy/Targets/NE16/Platform.py: instantiate the GAP9ClusterEngine with a trimmed includeList (no CNN_BasicKernels_NE16.h / ne16_utils.h / CNN_Copy.h) so the generated Network.c does not pull in the SDK NE16 header alongside pulp-nnx ne16_task_defs.h — the NE16_REG_* macros are defined in both and trigger -Werror redefs. 
--- .../ci-platform-gap9-w-ne16-tiled.yml | 10 ++++++-- Deeploy/Targets/GAP9/Platform.py | 4 ++-- Deeploy/Targets/NE16/Platform.py | 24 +++++++++++++++++-- TargetLibraries/GAP9/CMakeLists.txt | 11 +++++++-- 4 files changed, 41 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci-platform-gap9-w-ne16-tiled.yml b/.github/workflows/ci-platform-gap9-w-ne16-tiled.yml index 5f45bbafeb..c7df408ffd 100644 --- a/.github/workflows/ci-platform-gap9-w-ne16-tiled.yml +++ b/.github/workflows/ci-platform-gap9-w-ne16-tiled.yml @@ -17,7 +17,7 @@ name: CI • GAP9 + NE16 (Tiled) docker_image_deeploy: description: "Deeploy Image to use" required: false - default: "ghcr.io/pulp-platform/deeploy-gap9:latest" + default: "ghcr.io/pulp-platform/deeploy-gap9:devel" concurrency: group: ${{ github.workflow }}-${{ github.ref }} @@ -25,9 +25,15 @@ concurrency: jobs: select-env: + # The GAP9 + NE16 image is hosted in pulp-platform's private ghcr.io + # registry; only upstream's self-hosted runners have credentials to + # pull it. On forks the docker pull always returns "denied", so skip + # the whole pipeline cleanly there. (Same constraint as the existing + # ci-platform-gap9{,-tiled}.yml jobs.) 
+ if: github.repository == 'pulp-platform/Deeploy' uses: ./.github/workflows/_select-env.yml with: - docker_image_deeploy: ${{ github.event.inputs.docker_image_deeploy || github.repository == 'pulp-platform/Deeploy' && 'ghcr.io/pulp-platform/deeploy-gap9:latest'}} + docker_image_deeploy: ${{ github.event.inputs.docker_image_deeploy || 'ghcr.io/pulp-platform/deeploy-gap9:devel' }} gap9-w-ne16-kernels-tiled-singlebuffer-L2: needs: select-env diff --git a/Deeploy/Targets/GAP9/Platform.py b/Deeploy/Targets/GAP9/Platform.py index c482c14dab..9c74362ed2 100644 --- a/Deeploy/Targets/GAP9/Platform.py +++ b/Deeploy/Targets/GAP9/Platform.py @@ -46,8 +46,8 @@ IntegerDivRequantMergePass, MatMulAddMergePass, MergeConstAddAndRequantPass, MergeTrueIntegerDivRequantShiftPass, \ QuantPatternPass, RQSSplitPass, SkipEmptyConcatPass, SkipUnityRequantPass, iGELURequantMergePass, \ iHardswishRequantMergePass -from Deeploy.Targets.PULPOpen.Bindings import BasicDequantBindings, BasicQuantBindings, PULPConv1DBinding, \ - PULPDMASliceBindings, PULPDWConv1DBinding, PULPReduceMeanBindings, PULPRQSConv1DBindings, PULPSliceBindings +from Deeploy.Targets.PULPOpen.Bindings import PULPDMASliceBindings, PULPDWConv1DBinding, PULPReduceMeanBindings, \ + PULPRQSConv1DBindings, PULPSliceBindings from Deeploy.Targets.PULPOpen.Layers import PULPRQSConvLayer, PULPRQSGEMMLayer from Deeploy.Targets.PULPOpen.Parsers import PULPConv1DParser, PULPConv2DParser, PULPDWConv1DParser, \ PULPDWConv2DParser, PULPFPConv2DParser, PULPFPDWConv2DParser, PULPGEMMParser, PULPMatrixVecParser, \ diff --git a/Deeploy/Targets/NE16/Platform.py b/Deeploy/Targets/NE16/Platform.py index 2c6fddf8e5..0fa204c0f1 100644 --- a/Deeploy/Targets/NE16/Platform.py +++ b/Deeploy/Targets/NE16/Platform.py @@ -28,7 +28,17 @@ def __init__(self, structBuffer = GAP9StructBuffer, transientBuffer = GAP9TransientBuffer) -> None: if engines is None: - engines = [NE16Engine("NE16"), GAP9ClusterEngine("GAP9Cluster")] + # Drop SDK NE16 headers from the 
cluster engine include list so the + # generated Network.c does not pull in CNN_BasicKernels_NE16.h / + # ne16_utils.h alongside pulp-nnx's ne16_task_defs.h + # (NE16_REG_* macros are defined in both, causing -Werror redefs). + cluster = GAP9ClusterEngine( + "GAP9Cluster", + includeList = [ + "pmsis.h", "DeeployGAP9Math.h", "pulp_nn_kernels.h", "DeeployMchan.h", "CNN_BasicKernels_fp32.h" + ], + ) + engines = [NE16Engine("NE16"), cluster] super().__init__(engines, variableBuffer, constantBuffer, structBuffer, transientBuffer) @@ -44,7 +54,17 @@ def __init__(self, structBuffer = GAP9StructBuffer, transientBuffer = GAP9TransientBuffer) -> None: if engines is None: - engines = [NE16Engine("NE16"), GAP9ClusterEngine("GAP9Cluster")] + # Drop SDK NE16 headers from the cluster engine include list so the + # generated Network.c does not pull in CNN_BasicKernels_NE16.h / + # ne16_utils.h alongside pulp-nnx's ne16_task_defs.h + # (NE16_REG_* macros are defined in both, causing -Werror redefs). + cluster = GAP9ClusterEngine( + "GAP9Cluster", + includeList = [ + "pmsis.h", "DeeployGAP9Math.h", "pulp_nn_kernels.h", "DeeployMchan.h", "CNN_BasicKernels_fp32.h" + ], + ) + engines = [NE16Engine("NE16"), cluster] super().__init__(memoryHierarchy, defaultTargetMemoryLevel, engines, variableBuffer, constantBuffer, structBuffer, transientBuffer) self.weightMemoryLevel = weightMemoryLevel diff --git a/TargetLibraries/GAP9/CMakeLists.txt b/TargetLibraries/GAP9/CMakeLists.txt index 8de757a61e..b69e19fc6a 100644 --- a/TargetLibraries/GAP9/CMakeLists.txt +++ b/TargetLibraries/GAP9/CMakeLists.txt @@ -5,10 +5,18 @@ file(GLOB_RECURSE SOURCES "src/**" "$ENV{GAP_SDK_HOME}/tools/autotiler_v3/CNN_Libraries_fp32/CNN_Bias_Linear_Activation_fp32.c" - "$ENV{GAP_SDK_HOME}/tools/autotiler_v3/CNN_Libraries_NE16/CNN_BasicKernels_NE16.c" "$ENV{GAP_SDK_HOME}/tools/autotiler_v3/CNN_Libraries/CNN_Copy.c" ) +# CNN_BasicKernels_NE16 from gap9-sdk redefines NE16_REG_* macros that +# pulp-nnx's ne16 hal also 
defines. For GAP9_w_NE16 we use the pulp-nnx +# NE16 stack; for plain GAP9 (Pu DENG's NE16-Linear path) we use the SDK's. +if(NOT platform STREQUAL "GAP9_w_NE16") + list(APPEND SOURCES + "$ENV{GAP_SDK_HOME}/tools/autotiler_v3/CNN_Libraries_HWPE/CNN_BasicKernels_NE16.c" + ) +endif() + # Exclude dory_mem and dory_dma from SOURCES (they need different optimization) list(FILTER SOURCES EXCLUDE REGEX ".*dory_(mem|dma).*") @@ -48,7 +56,6 @@ target_include_directories(deeploygap9 ${TILER_EMU_INC} ${TILER_CNN_KERNEL_PATH_FP32} ${TILER_CNN_KERNEL_PATH_FP16} - $ENV{GAP_SDK_HOME}/tools/autotiler_v3/CNN_Libraries_NE16 $ENV{GAP_SDK_HOME}/tools/autotiler_v3/CNN_Libraries_SQ8 $ENV{GAP_SDK_HOME}/tools/autotiler_v3/CNN_Libraries ${TILER_DSP_KERNEL_V2_PATH} From e46a09aa4e8bb3db9da52360df70594aa577de61 Mon Sep 17 00:00:00 2001 From: runwangdl Date: Tue, 14 Apr 2026 11:01:53 +0000 Subject: [PATCH 4/6] [CI] skip gap9 pipelines on forks (private docker image) ghcr.io/pulp-platform/deeploy-gap9:* is hosted in pulp-platform's private GitHub Container Registry. Only upstream's self-hosted runners have credentials to pull it; on fork CI runs (ubuntu-latest) the docker pull fails with 'Error response from daemon: denied' and the whole job is reported as failure. Guard the select-env entry of all three gap9 workflows (ci-platform-gap9.yml, -tiled.yml, -w-ne16-tiled.yml) so they SKIP cleanly on forks instead of FAILING. Upstream behaviour is unchanged. --- .github/workflows/ci-platform-gap9-tiled.yml | 3 +++ .github/workflows/ci-platform-gap9.yml | 3 +++ 2 files changed, 6 insertions(+) diff --git a/.github/workflows/ci-platform-gap9-tiled.yml b/.github/workflows/ci-platform-gap9-tiled.yml index 61cab4ea70..721cd5a365 100644 --- a/.github/workflows/ci-platform-gap9-tiled.yml +++ b/.github/workflows/ci-platform-gap9-tiled.yml @@ -25,6 +25,9 @@ concurrency: jobs: select-env: + # ghcr.io/pulp-platform/deeploy-gap9 is private; only upstream's + # self-hosted runners have credentials. 
Skip cleanly on forks. + if: github.repository == 'pulp-platform/Deeploy' uses: ./.github/workflows/_select-env.yml with: docker_image_deeploy: ${{ github.event.inputs.docker_image_deeploy || 'ghcr.io/pulp-platform/deeploy-gap9:devel' }} diff --git a/.github/workflows/ci-platform-gap9.yml b/.github/workflows/ci-platform-gap9.yml index 014828d6ce..597c0f40ef 100644 --- a/.github/workflows/ci-platform-gap9.yml +++ b/.github/workflows/ci-platform-gap9.yml @@ -26,6 +26,9 @@ concurrency: jobs: select-env: + # ghcr.io/pulp-platform/deeploy-gap9 is private; only upstream's + # self-hosted runners have credentials. Skip cleanly on forks. + if: github.repository == 'pulp-platform/Deeploy' uses: ./.github/workflows/_select-env.yml with: docker_image_deeploy: ${{ github.event.inputs.docker_image_deeploy || 'ghcr.io/pulp-platform/deeploy-gap9:devel' }} From ddde88f1d2efb0e2a40d05b558a7065e0bbafe8e Mon Sep 17 00:00:00 2001 From: runwangdl Date: Tue, 14 Apr 2026 13:05:50 +0000 Subject: [PATCH 5/6] [Quant] add uint8 output bindings to QuantChecker / DequantChecker QuantChecker.checkOutputType (added by the NE16-Linear PR) requires opSigned == outputTypeSigned. Existing Generic and PULPOpen bindings only registered the signed-int8 output variant, so any Quant pattern with signed=0 (e.g. 4-bit unsigned quantization in Models/Transformer_DeepQuant) had no candidate and parsing exhausted backtracking. Add uint8 output to BasicQuantBindings and uint8 input to BasicDequantBindings in both Targets/Generic/Bindings.py and Targets/PULPOpen/Bindings.py. Verified: Models/Transformer_DeepQuant network gen now succeeds for both Generic and Siracusa platforms. 
--- Deeploy/Targets/Generic/Bindings.py | 5 ++++- Deeploy/Targets/PULPOpen/Bindings.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/Deeploy/Targets/Generic/Bindings.py b/Deeploy/Targets/Generic/Bindings.py index 308b179aef..363ed85541 100644 --- a/Deeploy/Targets/Generic/Bindings.py +++ b/Deeploy/Targets/Generic/Bindings.py @@ -291,12 +291,15 @@ BasicQuantBindings = [ NodeBinding(QuantChecker([PointerClass(float32_t)], [PointerClass(int8_t)]), QuantTemplate.referenceTemplate, BasicTransformer), + NodeBinding(QuantChecker([PointerClass(float32_t)], [PointerClass(uint8_t)]), QuantTemplate.referenceTemplate, + BasicTransformer), ] BasicDequantBindings = [ NodeBinding(DequantChecker([PointerClass(int8_t)], [PointerClass(float32_t)]), DequantTemplate.referenceTemplate, BasicTransformer), -] + [ + NodeBinding(DequantChecker([PointerClass(uint8_t)], [PointerClass(float32_t)]), DequantTemplate.referenceTemplate, + BasicTransformer), NodeBinding(DequantChecker([PointerClass(int32_t)], [PointerClass(float32_t)]), DequantTemplate.referenceTemplate, BasicTransformer), ] diff --git a/Deeploy/Targets/PULPOpen/Bindings.py b/Deeploy/Targets/PULPOpen/Bindings.py index 2c78978e23..2a68c3333c 100644 --- a/Deeploy/Targets/PULPOpen/Bindings.py +++ b/Deeploy/Targets/PULPOpen/Bindings.py @@ -453,12 +453,15 @@ BasicQuantBindings = [ NodeBinding(QuantChecker([PointerClass(float32_t)], [PointerClass(int8_t)]), QuantTemplate.referenceTemplate, ForkTransformer), + NodeBinding(QuantChecker([PointerClass(float32_t)], [PointerClass(uint8_t)]), QuantTemplate.referenceTemplate, + ForkTransformer), ] BasicDequantBindings = [ NodeBinding(DequantChecker([PointerClass(int8_t)], [PointerClass(float32_t)]), DequantTemplate.referenceTemplate, ForkTransformer), -] + [ + NodeBinding(DequantChecker([PointerClass(uint8_t)], [PointerClass(float32_t)]), DequantTemplate.referenceTemplate, + ForkTransformer), NodeBinding(DequantChecker([PointerClass(int32_t)], 
[PointerClass(float32_t)]), DequantTemplate.referenceTemplate, ForkTransformer), ] From 2373ff93fb8197694acf67a00edacabcbd4df6ad Mon Sep 17 00:00:00 2001 From: runwangdl Date: Tue, 14 Apr 2026 13:07:12 +0000 Subject: [PATCH 6/6] [CI] snitch-tiled: drop xdist parallelism 4 -> 2 The Snitch FP32 GEMM/TransB-5000 build OOMs the GitHub-hosted runner ('std::bad_alloc' from the C compiler driver) when 4 pytest-xdist workers compile in parallel. Two workers leave enough headroom on the standard 7-GB runner. (Pre-existing flake; surfaced as a hard fail in CI runs that happen to land both heavy FP32 GEMM compilations on adjacent workers.) --- .github/workflows/_runner-snitch-tiled-sequential.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/_runner-snitch-tiled-sequential.yml b/.github/workflows/_runner-snitch-tiled-sequential.yml index bcdd58a166..2fdd0ec839 100644 --- a/.github/workflows/_runner-snitch-tiled-sequential.yml +++ b/.github/workflows/_runner-snitch-tiled-sequential.yml @@ -33,10 +33,10 @@ jobs: - name: Build Deeploy shell: bash run: pip install -e . - - name: Run Test # VJUNG: Run tests with 4 parallel threads as GitHub action VM has 4 cores. + - name: Run Test # 2-way parallel: 4-way OOMs the GitHub runner on the FP32 GEMM/TransB build. run: | cd DeeployTest mkdir -p /app/.ccache export CCACHE_DIR=/app/.ccache - pytest test_platforms.py -v -n 4 -m "snitch_tiled and ${{ inputs.pytest-marker }}" + pytest test_platforms.py -v -n 2 -m "snitch_tiled and ${{ inputs.pytest-marker }}" shell: bash