1+ # ----------------------------------------------------------------------
2+ #
3+ # File: BasicParsers.py
4+ #
5+ # Last edited: 15.12.2021
6+ #
7+ # Copyright (C) 2021, ETH Zurich and University of Bologna.
8+ #
9+ # Authors:
10+ # - Moritz Scherer, ETH Zurich
11+ # - Victor Jung, ETH Zurich
12+ #
13+ # ----------------------------------------------------------------------
14+ # SPDX-License-Identifier: Apache-2.0
15+ #
16+ # Licensed under the Apache License, Version 2.0 (the License); you may
17+ # not use this file except in compliance with the License.
18+ # You may obtain a copy of the License at
19+ #
20+ # www.apache.org/licenses/LICENSE-2.0
21+ #
22+ # Unless required by applicable law or agreed to in writing, software
23+ # distributed under the License is distributed on an AS IS BASIS, WITHOUT
24+ # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
25+ # See the License for the specific language governing permissions and
26+ # limitations under the License.
27+
28+ import math
29+ from typing import Tuple
30+
31+ import numpy as np
32+ import onnx_graphsurgeon as gs
33+
34+ from Deeploy .DeeployTypes import NetworkContext , NodeParser
35+ from Deeploy .Targets .Generic .Parsers import MatMulParser
36+
37+ class GEMMRedmuleParser (MatMulParser ):
38+
39+ def __init__ (self , noBiasHoisting = True ):
40+ self .noBiasHoisting = noBiasHoisting
41+ super ().__init__ ()
42+
43+ def parseNode (self , node : gs .Node ) -> (bool ):
44+
45+ ret = all ([
46+ len (node .inputs ) >= 2 ,
47+ len (node .outputs ) == 1 ,
48+ node .attrs ['alpha' ] == 1
49+ ])
50+
51+ if ret :
52+ if 'transA' in node .attrs :
53+ self .operatorRepresentation ['transA' ] = node .attrs ['transA' ]
54+ else :
55+ self .operatorRepresentation ['transA' ] = 0
56+
57+ if 'transB' in node .attrs :
58+ self .operatorRepresentation ['transB' ] = node .attrs ['transB' ]
59+ else :
60+ self .operatorRepresentation ['transB' ] = 0
61+ if 'alpha' in node .attrs :
62+ self .operatorRepresentation ['alpha' ] = node .attrs ['alpha' ]
63+ else :
64+ self .operatorRepresentation ['alpha' ] = 1
65+ if 'beta' in node .attrs :
66+ self .operatorRepresentation ['beta' ] = node .attrs ['beta' ]
67+ else :
68+ self .operatorRepresentation ['beta' ] = 1
69+
70+ return ret
71+
72+ def parseNodeCtxt (self ,
73+ ctxt : NetworkContext ,
74+ node : gs .Node ,
75+ channels_first : bool = True ) -> Tuple [NetworkContext , bool ]:
76+
77+ newCtxt , ret = super ().parseNodeCtxt (ctxt , node , channels_first )
78+
79+ if ret :
80+ inputs = ['A' , 'B' ]
81+ outputs = ['data_out' ]
82+
83+ for idx , inputNode in enumerate (node .inputs ):
84+ if idx < len (inputs ):
85+ self .operatorRepresentation [inputs [idx ]] = newCtxt .lookup (inputNode .name ).name
86+ for idx , outputNode in enumerate (node .outputs ):
87+ self .operatorRepresentation [outputs [idx ]] = newCtxt .lookup (outputNode .name ).name
88+
89+ if len (node .inputs ) == 3 :
90+ self .operatorRepresentation ['C' ] = newCtxt .lookup (node .inputs [2 ].name ).name
91+ elif not self .noBiasHoisting :
92+ values = np .zeros ((1 ))
93+ zeroTensor = gs .Constant (f'{ node .name } _C_Tensor' , values = values )
94+ newCtxt .hoistConstant (zeroTensor )
95+ self .operatorRepresentation ['C' ] = f'{ node .name } _C_Tensor'
96+
97+ self .operatorRepresentation ['size' ] = np .prod (newCtxt .lookup (node .inputs [0 ].name ).shape )
98+
99+ return newCtxt , ret
0 commit comments