@@ -440,6 +440,20 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
440440 attrDescriptors = [AttrDesc ("axis" , IntUnpack , default = 0 )],
441441)
442442
443+ # Opset <= 11
444+ unsqueezeDesc = OperatorDescriptor (
445+ inputDescriptor = IoDesc ("data_in" ),
446+ outputDescriptor = IoDesc ("data_out" ),
447+ attrDescriptors = [AttrDesc ("axes" , IntTupleUnpack )],
448+ )
449+
450+ # Opset <= 11
451+ squeezeDesc = OperatorDescriptor (
452+ inputDescriptor = IoDesc ("data_in" ),
453+ outputDescriptor = IoDesc ("data_out" ),
454+ attrDescriptors = [AttrDesc ("axes" , IntTupleUnpack )],
455+ )
456+
443457defaultOperatorDescriptors : Dict [str , OperatorDescriptor ] = {
444458 "Add" : addDesc ,
445459 "Concat" : concatDesc ,
@@ -465,7 +479,9 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
465479 "Slice" : sliceDesc ,
466480 "Softmax" : softmaxDesc ,
467481 "SoftmaxGrad" : softmaxGradDesc ,
482+ "Squeeze" : squeezeDesc ,
468483 "Transpose" : transposeDesc ,
484+ "Unsqueeze" : unsqueezeDesc ,
469485 "iHardswish" : iHardswishDesc ,
470486 "iLayerNorm" : iLayerNormDesc ,
471487 "iNoNorm" : iNoNormDesc ,
0 commit comments