1+ #include " Common/AcceleratorAttrs.h"
12#include " mlir/Pass/Pass.h"
23#include " mlir/IR/BuiltinOps.h"
34#include " mlir/IR/Operation.h"
2324
2425#include " NeuraDialect/Architecture/Architecture.h"
2526#include " NeuraDialect/NeuraOps.h"
27+ #include " NeuraDialect/NeuraAttributes.h"
2628
2729using namespace mlir ;
2830using namespace neura ;
@@ -204,21 +206,21 @@ static std::string extractConstantLiteralFromAttr(Attribute attr) {
204206// Literals for CONSTANT operations, e.g. "#10" / "#0" / "#3.0".
205207static std::string getConstantLiteral (Operation *op) {
206208 if (isConstant (op)) {
207- if (auto value_attr = op->getAttr (" value " )) {
209+ if (auto value_attr = op->getAttr (attr:: kValue )) {
208210 std::string result = extractConstantLiteralFromAttr (value_attr);
209211 if (!result.empty ()) return result;
210212 }
211213 return " #0" ;
212214 }
213215
214216 // Checks for constant_value attribute in non-CONSTANT operations.
215- if (auto constant_value_attr = op->getAttr (" constant_value " )) {
217+ if (auto constant_value_attr = op->getAttr (attr:: kConstantValue )) {
216218 std::string result = extractConstantLiteralFromAttr (constant_value_attr);
217219 if (!result.empty ()) return result;
218220 }
219221
220222 // Checks for rhs_value attribute (for binary operations with constant RHS).
221- if (auto rhs_value_attr = op->getAttr (" rhs_value " )) {
223+ if (auto rhs_value_attr = op->getAttr (attr:: kRhsValue )) {
222224 std::string result = extractConstantLiteralFromAttr (rhs_value_attr);
223225 if (!result.empty ()) return result;
224226 }
@@ -410,16 +412,16 @@ struct GenerateCodePass
410412
411413 std::pair<int , int > getArrayDimensions (func::FuncOp function) {
412414 int columns = 4 , rows = 4 ; // default 4x4 CGRA.
413- if (auto mapping_info = function->getAttrOfType <DictionaryAttr>(" mapping_info " )) {
414- if (auto x_tiles = dyn_cast_or_null<IntegerAttr>(mapping_info.get (" x_tiles " ))) columns = x_tiles.getInt ();
415- if (auto y_tiles = dyn_cast_or_null<IntegerAttr>(mapping_info.get (" y_tiles " ))) rows = y_tiles.getInt ();
415+ if (auto mapping_info = function->getAttrOfType <DictionaryAttr>(attr:: kMappingInfo )) {
416+ if (auto x_tiles = dyn_cast_or_null<IntegerAttr>(mapping_info.get (attr:: kXTiles ))) columns = x_tiles.getInt ();
417+ if (auto y_tiles = dyn_cast_or_null<IntegerAttr>(mapping_info.get (attr:: kYTiles ))) rows = y_tiles.getInt ();
416418 }
417419 return {columns, rows};
418420 }
419421
420422 int getCompiledII (func::FuncOp function) {
421- if (auto mapping_info = function->getAttrOfType <DictionaryAttr>(" mapping_info " )) {
422- if (auto compiled_ii = dyn_cast_or_null<IntegerAttr>(mapping_info.get (" compiled_ii " ))) {
423+ if (auto mapping_info = function->getAttrOfType <DictionaryAttr>(attr:: kMappingInfo )) {
424+ if (auto compiled_ii = dyn_cast_or_null<IntegerAttr>(mapping_info.get (attr:: kCompiledII ))) {
423425 return compiled_ii.getInt ();
424426 }
425427 }
@@ -510,7 +512,7 @@ struct GenerateCodePass
510512
511513 if (isConstant (op)) {
512514 inst.src_operands .emplace_back (getConstantLiteral (op), " RED" );
513- } else if (op->getAttr (" constant_value " )) {
515+ } else if (op->getAttr (attr:: kConstantValue )) {
514516 // Checks if operation has constant_value attribute (for non-CONSTANT operations).
515517 inst.src_operands .emplace_back (getConstantLiteral (op), " RED" );
516518 } else {
@@ -524,7 +526,7 @@ struct GenerateCodePass
524526 }
525527
526528 // Handles cases where binary operations have the RHS constant stored as an attribute.
527- if (auto rhs_value_attr = op->getAttr (" rhs_value " )) {
529+ if (auto rhs_value_attr = op->getAttr (attr:: kRhsValue )) {
528530 std::string rhs_literal = extractConstantLiteralFromAttr (rhs_value_attr);
529531 if (!rhs_literal.empty ()) {
530532 inst.src_operands .emplace_back (rhs_literal, " RED" );
@@ -934,7 +936,7 @@ struct GenerateCodePass
934936
935937 // Helper to extract dfg_id from operation.
936938 static int getDfgId (Operation *op) {
937- if (auto id_attr = op->getAttrOfType <IntegerAttr>(" dfg_id " )) {
939+ if (auto id_attr = op->getAttrOfType <IntegerAttr>(attr:: kDfgId )) {
938940 return id_attr.getInt ();
939941 }
940942 return -1 ;
@@ -1669,8 +1671,8 @@ struct GenerateCodePass
16691671 ModuleOp module = getOperation ();
16701672
16711673 for (auto func : module .getOps <func::FuncOp>()) {
1672- auto accel = func->getAttrOfType <StringAttr>(" accelerator " );
1673- if (!accel || accel.getValue () != " neura " ) continue ;
1674+ auto accel = func->getAttrOfType <StringAttr>(accel:: kAcceleratorAttr );
1675+ if (!accel || accel.getValue () != accel:: kNeuraTarget ) continue ;
16741676
16751677 auto [columns, rows] = getArrayDimensions (func);
16761678 Topology topo = getTopologyFromArchitecture (columns, rows);
0 commit comments