|
| 1 | +# SPDX-FileCopyrightText: 2024 ETH Zurich and University of Bologna |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +from typing import List |
| 6 | + |
| 7 | +import onnx_graphsurgeon as gs |
| 8 | + |
| 9 | +from Deeploy.DeeployTypes import DeploymentEngine, NodeMapper |
| 10 | +from Deeploy.Targets.Generic.Layers import ConvLayer |
| 11 | +from Deeploy.Targets.NE16.Parsers import NE16DenseConv2DParser, NE16DWConv2DParser, NE16PWConv2DParser, \ |
| 12 | + NE16RQSDenseConv2DParser, NE16RQSDWConv2DParser, NE16RQSPWConv2DParser |
| 13 | +from Deeploy.Targets.NE16.Tiler import NE16DenseConv2DTilingReadyBindings, NE16DWConv2DTilingReadyBindings, \ |
| 14 | + NE16PWConv2DTilingReadyBindings, NE16RQSDenseConv2DTilingReadyBindings, NE16RQSDWConv2DTilingReadyBindings, \ |
| 15 | + NE16RQSPWConv2DTilingReadyBindings |
| 16 | +from Deeploy.Targets.PULPOpen.Layers import PULPRQSConvLayer |
| 17 | + |
| 18 | +NE16RqntPWConv2DMapper = NodeMapper(NE16RQSPWConv2DParser(), NE16RQSPWConv2DTilingReadyBindings) |
| 19 | +NE16PWConv2DMapper = NodeMapper(NE16PWConv2DParser(), NE16PWConv2DTilingReadyBindings) |
| 20 | + |
| 21 | +NE16RqntDWConv2DMapper = NodeMapper(NE16RQSDWConv2DParser(), NE16RQSDWConv2DTilingReadyBindings) |
| 22 | +NE16DWConv2DMapper = NodeMapper(NE16DWConv2DParser(), NE16DWConv2DTilingReadyBindings) |
| 23 | + |
| 24 | +NE16RqntDenseConv2DMapper = NodeMapper(NE16RQSDenseConv2DParser(), NE16RQSDenseConv2DTilingReadyBindings) |
| 25 | +NE16DenseConv2DMapper = NodeMapper(NE16DenseConv2DParser(), NE16DenseConv2DTilingReadyBindings) |
| 26 | + |
| 27 | +NE16Mapping = { |
| 28 | + 'RequantizedConv': PULPRQSConvLayer([NE16RqntPWConv2DMapper, NE16RqntDWConv2DMapper, NE16RqntDenseConv2DMapper]), |
| 29 | + 'Conv': ConvLayer([NE16PWConv2DMapper, NE16DWConv2DMapper, NE16DenseConv2DMapper]), |
| 30 | +} |
| 31 | + |
| 32 | +_includeList = ["pulp_nnx_ne16.h", "pulp_nnx_util.h", "ne16_pulp_bsp.h", "ne16.h", "ne16_task.h"] |
| 33 | + |
| 34 | +_ne16InitCode = r""" |
| 35 | +ne16_pulp_conf_t conf = {.max_stall = 8}; |
| 36 | +ne16_nnx_init(ne16_pulp_get_dev(), &conf); |
| 37 | +""" |
| 38 | + |
| 39 | + |
| 40 | +class NE16Engine(DeploymentEngine): |
| 41 | + |
| 42 | + def __init__(self, |
| 43 | + name: str, |
| 44 | + Mapping = NE16Mapping, |
| 45 | + initCode: str = _ne16InitCode, |
| 46 | + includeList: List[str] = _includeList, |
| 47 | + enable3x3: bool = False, |
| 48 | + enableStrides: bool = False) -> None: |
| 49 | + super().__init__(name, Mapping, initCode, includeList) |
| 50 | + |
| 51 | + self.enable3x3 = enable3x3 |
| 52 | + self.enableStrides = enableStrides |
| 53 | + |
| 54 | + def isDenseConv(self, node) -> bool: |
| 55 | + return node.op in ["Conv", "RequantizedConv"] and \ |
| 56 | + isinstance(node.inputs[1], gs.Constant) and \ |
| 57 | + node.attrs['kernel_shape'] == [3, 3] and \ |
| 58 | + node.attrs['dilations'] == [1, 1] and \ |
| 59 | + node.attrs['group'] == 1 and \ |
| 60 | + (node.attrs['strides'] == [1, 1] or self.enableStrides) |
| 61 | + |
| 62 | + def isPWConv(self, node) -> bool: |
| 63 | + return node.op in ["Conv", "RequantizedConv"] and \ |
| 64 | + isinstance(node.inputs[1], gs.Constant) and \ |
| 65 | + node.attrs['kernel_shape'] == [1, 1] and \ |
| 66 | + node.attrs['dilations'] == [1, 1] and \ |
| 67 | + (node.attrs['strides'] == [1, 1] or self.enableStrides) |
| 68 | + |
| 69 | + def isDWConv(self, node) -> bool: |
| 70 | + return node.op in ["Conv", "RequantizedConv"] and \ |
| 71 | + isinstance(node.inputs[1], gs.Constant) and \ |
| 72 | + node.attrs['kernel_shape'] == [3, 3] and \ |
| 73 | + node.attrs['dilations'] == [1, 1] and \ |
| 74 | + node.attrs['group'] != 1 and \ |
| 75 | + (node.attrs['strides'] == [1, 1] or self.enableStrides) |
| 76 | + |
| 77 | + def canExecute(self, node: gs.Node) -> bool: |
| 78 | + if self.enable3x3: |
| 79 | + return self.isPWConv(node) or self.isDWConv(node) or self.isDenseConv(node) |
| 80 | + else: |
| 81 | + return self.isPWConv(node) |
0 commit comments