Skip to content

Commit 76f8c1a

Browse files
authored
Merge pull request coredac#241 from shiyunyao/main
[Refactor] Introduce NeuraAttributes.h to manage attribute constants
2 parents 3e44105 + 344e97c commit 76f8c1a

17 files changed

Lines changed: 206 additions & 101 deletions

include/NeuraDialect/Mapping/HeuristicMapping/HeuristicMapping.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "NeuraDialect/Mapping/Mapping.h"
55
#include "NeuraDialect/Mapping/MappingState.h"
6+
#include "NeuraDialect/NeuraAttributes.h"
67
#include <climits>
78
#include <set>
89

@@ -21,14 +22,14 @@ class HeuristicMapping : public Mapping {
2122

2223
std::string getName() const override {
2324
if (max_location_to_try == 1 && max_backtrack_depth == 1) {
24-
return "simple";
25+
return attr::val::kSimple.str();
2526
} else if (max_location_to_try == INT_MAX && max_backtrack_depth == 1) {
26-
return "greedy";
27+
return attr::val::kGreedy.str();
2728
} else if (max_location_to_try == INT_MAX &&
2829
max_backtrack_depth == INT_MAX) {
29-
return "exhaustive";
30+
return attr::val::kExhaustive.str();
3031
} else {
31-
return "customized";
32+
return attr::val::kCustomized.str();
3233
}
3334
}
3435

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#pragma once
2+
3+
#include "llvm/ADT/StringRef.h"
4+
5+
namespace mlir {
6+
namespace neura {
7+
8+
namespace attr {
9+
10+
// Attribute Keys
11+
12+
// Specifies the dataflow representation mode, as opposed to control-flow.
13+
constexpr llvm::StringLiteral kDataflowMode = "dataflow_mode";
14+
15+
// Specifies the mapping strategy mode, can be either 'spatial-only' or
16+
// 'spatial-temporal'.
17+
constexpr llvm::StringLiteral kMappingMode = "mapping_mode";
18+
19+
constexpr llvm::StringLiteral kMappingStrategy = "mapping_strategy";
20+
constexpr llvm::StringLiteral kBacktrackConfig = "backtrack_config";
21+
constexpr llvm::StringLiteral kDumpMappingTable = "dump_mapping_table";
22+
23+
// Identification & Results
24+
constexpr llvm::StringLiteral kDfgId = "dfg_id";
25+
constexpr llvm::StringLiteral kMappingInfo = "mapping_info";
26+
constexpr llvm::StringLiteral kXTiles = "x_tiles";
27+
constexpr llvm::StringLiteral kYTiles = "y_tiles";
28+
constexpr llvm::StringLiteral kCompiledII = "compiled_ii";
29+
constexpr llvm::StringLiteral kRecMII = "rec_mii";
30+
constexpr llvm::StringLiteral kResMII = "res_mii";
31+
32+
// Values & Constants Keys
33+
constexpr llvm::StringLiteral kValue = "value";
34+
constexpr llvm::StringLiteral kConstantValue = "constant_value";
35+
constexpr llvm::StringLiteral kRhsValue = "rhs_value";
36+
constexpr llvm::StringLiteral kLhsValue = "lhs_value";
37+
38+
// Attribute Values & Constants
39+
namespace val {
40+
// Strategy & Mode
41+
constexpr llvm::StringLiteral kSpatialOnly = "spatial-only";
42+
constexpr llvm::StringLiteral kSpatialTemporal = "spatial-temporal";
43+
constexpr llvm::StringLiteral kTemporal = "temporal";
44+
constexpr llvm::StringLiteral kHeuristic = "heuristic";
45+
constexpr llvm::StringLiteral kCustomized = "customized";
46+
constexpr llvm::StringLiteral kSimple = "simple";
47+
constexpr llvm::StringLiteral kGreedy = "greedy";
48+
constexpr llvm::StringLiteral kExhaustive = "exhaustive";
49+
50+
// Identifiers
51+
constexpr llvm::StringLiteral kModeSteering = "steering";
52+
constexpr llvm::StringLiteral kModePredicate = "predicate";
53+
54+
// Operation Logic
55+
constexpr llvm::StringLiteral kOpFused = "fused_op";
56+
constexpr llvm::StringLiteral kNeuraFusedOp = "neura.fused_op";
57+
58+
} // namespace val
59+
60+
} // namespace attr
61+
} // namespace neura
62+
} // namespace mlir

lib/NeuraDialect/Transforms/CanonicalizeCastPass.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "Common/AcceleratorAttrs.h"
12
#include "NeuraDialect/NeuraOps.h"
23
#include "mlir/Dialect/Func/IR/FuncOps.h"
34
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -113,14 +114,16 @@ struct CanonicalizeCastPass
113114
module_op.walk([&](Operation *op) {
114115
Region *region = nullptr;
115116
if (auto func_op = dyn_cast<func::FuncOp>(op)) {
116-
auto accel_attr = func_op->getAttrOfType<StringAttr>("accelerator");
117-
if (!accel_attr || accel_attr.getValue() != "neura") {
117+
auto accel_attr =
118+
func_op->getAttrOfType<StringAttr>(accel::kAcceleratorAttr);
119+
if (!accel_attr || accel_attr.getValue() != accel::kNeuraTarget) {
118120
return;
119121
}
120122
region = &func_op.getBody();
121123
} else if (auto llvm_func = dyn_cast<LLVM::LLVMFuncOp>(op)) {
122-
auto accel_attr = llvm_func->getAttrOfType<StringAttr>("accelerator");
123-
if (!accel_attr || accel_attr.getValue() != "neura") {
124+
auto accel_attr =
125+
llvm_func->getAttrOfType<StringAttr>(accel::kAcceleratorAttr);
126+
if (!accel_attr || accel_attr.getValue() != accel::kNeuraTarget) {
124127
return;
125128
}
126129
region = &llvm_func.getBody();

lib/NeuraDialect/Transforms/CanonicalizeLiveInPass.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "Common/AcceleratorAttrs.h"
12
#include "NeuraDialect/NeuraDialect.h"
23
#include "NeuraDialect/NeuraOps.h"
34
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -796,14 +797,16 @@ struct CanonicalizeLiveInPass
796797
module_op.walk([&](Operation *op) {
797798
Region *region = nullptr;
798799
if (auto func_op = dyn_cast<func::FuncOp>(op)) {
799-
auto accel_attr = func_op->getAttrOfType<StringAttr>("accelerator");
800-
if (!accel_attr || accel_attr.getValue() != "neura") {
800+
auto accel_attr =
801+
func_op->getAttrOfType<StringAttr>(accel::kAcceleratorAttr);
802+
if (!accel_attr || accel_attr.getValue() != accel::kNeuraTarget) {
801803
return;
802804
}
803805
region = &func_op.getBody();
804806
} else if (auto llvm_func = dyn_cast<LLVM::LLVMFuncOp>(op)) {
805-
auto accel_attr = llvm_func->getAttrOfType<StringAttr>("accelerator");
806-
if (!accel_attr || accel_attr.getValue() != "neura") {
807+
auto accel_attr =
808+
llvm_func->getAttrOfType<StringAttr>(accel::kAcceleratorAttr);
809+
if (!accel_attr || accel_attr.getValue() != accel::kNeuraTarget) {
807810
return;
808811
}
809812
region = &llvm_func.getBody();

lib/NeuraDialect/Transforms/CanonicalizeReturnPass.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "Common/AcceleratorAttrs.h"
12
#include "NeuraDialect/NeuraDialect.h"
23
#include "NeuraDialect/NeuraOps.h"
34
#include "NeuraDialect/NeuraPasses.h"
@@ -191,7 +192,8 @@ struct CanonicalizeReturnPass
191192
void runOnOperation() override {
192193
func::FuncOp func_op = getOperation();
193194
// Checks for neura accelerator attribute.
194-
auto accel_attr = func_op->getAttrOfType<StringAttr>("accelerator");
195+
auto accel_attr =
196+
func_op->getAttrOfType<StringAttr>(accel::kAcceleratorAttr);
195197
if (!accel_attr) {
196198
return;
197199
}

lib/NeuraDialect/Transforms/GenerateCodePass.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "Common/AcceleratorAttrs.h"
12
#include "mlir/Pass/Pass.h"
23
#include "mlir/IR/BuiltinOps.h"
34
#include "mlir/IR/Operation.h"
@@ -23,6 +24,7 @@
2324

2425
#include "NeuraDialect/Architecture/Architecture.h"
2526
#include "NeuraDialect/NeuraOps.h"
27+
#include "NeuraDialect/NeuraAttributes.h"
2628

2729
using namespace mlir;
2830
using namespace neura;
@@ -204,21 +206,21 @@ static std::string extractConstantLiteralFromAttr(Attribute attr) {
204206
// Literals for CONSTANT operations, e.g. "#10" / "#0" / "#3.0".
205207
static std::string getConstantLiteral(Operation *op) {
206208
if (isConstant(op)) {
207-
if (auto value_attr = op->getAttr("value")) {
209+
if (auto value_attr = op->getAttr(attr::kValue)) {
208210
std::string result = extractConstantLiteralFromAttr(value_attr);
209211
if (!result.empty()) return result;
210212
}
211213
return "#0";
212214
}
213215

214216
// Checks for constant_value attribute in non-CONSTANT operations.
215-
if (auto constant_value_attr = op->getAttr("constant_value")) {
217+
if (auto constant_value_attr = op->getAttr(attr::kConstantValue)) {
216218
std::string result = extractConstantLiteralFromAttr(constant_value_attr);
217219
if (!result.empty()) return result;
218220
}
219221

220222
// Checks for rhs_value attribute (for binary operations with constant RHS).
221-
if (auto rhs_value_attr = op->getAttr("rhs_value")) {
223+
if (auto rhs_value_attr = op->getAttr(attr::kRhsValue)) {
222224
std::string result = extractConstantLiteralFromAttr(rhs_value_attr);
223225
if (!result.empty()) return result;
224226
}
@@ -410,16 +412,16 @@ struct GenerateCodePass
410412

411413
std::pair<int, int> getArrayDimensions(func::FuncOp function) {
412414
int columns = 4, rows = 4; // default 4x4 CGRA.
413-
if (auto mapping_info = function->getAttrOfType<DictionaryAttr>("mapping_info")) {
414-
if (auto x_tiles = dyn_cast_or_null<IntegerAttr>(mapping_info.get("x_tiles"))) columns = x_tiles.getInt();
415-
if (auto y_tiles = dyn_cast_or_null<IntegerAttr>(mapping_info.get("y_tiles"))) rows = y_tiles.getInt();
415+
if (auto mapping_info = function->getAttrOfType<DictionaryAttr>(attr::kMappingInfo)) {
416+
if (auto x_tiles = dyn_cast_or_null<IntegerAttr>(mapping_info.get(attr::kXTiles))) columns = x_tiles.getInt();
417+
if (auto y_tiles = dyn_cast_or_null<IntegerAttr>(mapping_info.get(attr::kYTiles))) rows = y_tiles.getInt();
416418
}
417419
return {columns, rows};
418420
}
419421

420422
int getCompiledII(func::FuncOp function) {
421-
if (auto mapping_info = function->getAttrOfType<DictionaryAttr>("mapping_info")) {
422-
if (auto compiled_ii = dyn_cast_or_null<IntegerAttr>(mapping_info.get("compiled_ii"))) {
423+
if (auto mapping_info = function->getAttrOfType<DictionaryAttr>(attr::kMappingInfo)) {
424+
if (auto compiled_ii = dyn_cast_or_null<IntegerAttr>(mapping_info.get(attr::kCompiledII))) {
423425
return compiled_ii.getInt();
424426
}
425427
}
@@ -510,7 +512,7 @@ struct GenerateCodePass
510512

511513
if (isConstant(op)) {
512514
inst.src_operands.emplace_back(getConstantLiteral(op), "RED");
513-
} else if (op->getAttr("constant_value")) {
515+
} else if (op->getAttr(attr::kConstantValue)) {
514516
// Checks if operation has constant_value attribute (for non-CONSTANT operations).
515517
inst.src_operands.emplace_back(getConstantLiteral(op), "RED");
516518
} else {
@@ -524,7 +526,7 @@ struct GenerateCodePass
524526
}
525527

526528
// Handles cases where binary operations have the RHS constant stored as an attribute.
527-
if (auto rhs_value_attr = op->getAttr("rhs_value")) {
529+
if (auto rhs_value_attr = op->getAttr(attr::kRhsValue)) {
528530
std::string rhs_literal = extractConstantLiteralFromAttr(rhs_value_attr);
529531
if (!rhs_literal.empty()) {
530532
inst.src_operands.emplace_back(rhs_literal, "RED");
@@ -934,7 +936,7 @@ struct GenerateCodePass
934936

935937
// Helper to extract dfg_id from operation.
936938
static int getDfgId(Operation *op) {
937-
if (auto id_attr = op->getAttrOfType<IntegerAttr>("dfg_id")) {
939+
if (auto id_attr = op->getAttrOfType<IntegerAttr>(attr::kDfgId)) {
938940
return id_attr.getInt();
939941
}
940942
return -1;
@@ -1669,8 +1671,8 @@ struct GenerateCodePass
16691671
ModuleOp module = getOperation();
16701672

16711673
for (auto func : module.getOps<func::FuncOp>()) {
1672-
auto accel = func->getAttrOfType<StringAttr>("accelerator");
1673-
if (!accel || accel.getValue() != "neura") continue;
1674+
auto accel = func->getAttrOfType<StringAttr>(accel::kAcceleratorAttr);
1675+
if (!accel || accel.getValue() != accel::kNeuraTarget) continue;
16741676

16751677
auto [columns, rows] = getArrayDimensions(func);
16761678
Topology topo = getTopologyFromArchitecture(columns, rows);

lib/NeuraDialect/Transforms/GraphMining/GraMi.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include "Common/AcceleratorAttrs.h"
2+
#include "NeuraDialect/NeuraAttributes.h"
13
#include "NeuraDialect/Transforms/GraphMining/GraMi.h"
24
#include "NeuraDialect/Mapping/mapping_util.h"
35
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -204,13 +206,13 @@ std::vector<PatternWithSelectedInstances> GraMi::mineFrequentSubgraphs() {
204206
auto derive_label = [](mlir::Operation* op, const std::string& fallback_label) -> std::string {
205207
if (!op) return fallback_label;
206208
auto name = op->getName().getStringRef();
207-
if (name.ends_with("fused_op") || name.contains("neura.fused_op")) {
209+
if (name.ends_with(attr::val::kOpFused) || name.contains(attr::val::kNeuraFusedOp)) {
208210
if (auto attr = op->getAttr("pattern_name")) {
209211
if (auto str_attr = mlir::dyn_cast<mlir::StringAttr>(attr)) {
210212
return std::string("fused_op:") + str_attr.getValue().str();
211213
}
212214
}
213-
return std::string("fused_op");
215+
return std::string(attr::val::kOpFused);
214216
}
215217
return fallback_label;
216218
};

lib/NeuraDialect/Transforms/InsertCtrlMovPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "Common/AcceleratorAttrs.h"
12
#include "NeuraDialect/NeuraDialect.h"
23
#include "NeuraDialect/NeuraOps.h"
34
#include "NeuraDialect/NeuraPasses.h"
@@ -20,7 +21,7 @@ struct InsertCtrlMovForNeuraOps : public RewritePattern {
2021

2122
LogicalResult matchAndRewrite(Operation *op,
2223
PatternRewriter &rewriter) const override {
23-
if (op->getDialect()->getNamespace() != "neura" ||
24+
if (op->getDialect()->getNamespace() != accel::kNeuraTarget ||
2425
isa<neura::CtrlMovOp>(op)) {
2526
return failure();
2627
}

lib/NeuraDialect/Transforms/InsertDataMovPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "Common/AcceleratorAttrs.h"
12
#include "NeuraDialect/NeuraDialect.h"
23
#include "NeuraDialect/NeuraOps.h"
34
#include "NeuraDialect/NeuraPasses.h"
@@ -22,7 +23,7 @@ struct InsertDataMovForNeuraOps : public RewritePattern {
2223

2324
LogicalResult matchAndRewrite(Operation *op,
2425
PatternRewriter &rewriter) const override {
25-
if (op->getDialect()->getNamespace() != "neura" ||
26+
if (op->getDialect()->getNamespace() != accel::kNeuraTarget ||
2627
isa<neura::DataMovOp>(op)) {
2728
return failure();
2829
}

lib/NeuraDialect/Transforms/LeveragePredicatedValuePass.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "Common/AcceleratorAttrs.h"
12
#include "NeuraDialect/NeuraDialect.h"
23
#include "NeuraDialect/NeuraOps.h"
34
#include "NeuraDialect/NeuraPasses.h"
@@ -33,8 +34,9 @@ struct LeveragePredicatedValuePass
3334

3435
// Processes each function.
3536
module.walk([&](FunctionOpInterface func) {
36-
auto accel_attr = func->getAttrOfType<StringAttr>("accelerator");
37-
if (!accel_attr || accel_attr.getValue() != "neura") {
37+
auto accel_attr =
38+
func->getAttrOfType<StringAttr>(accel::kAcceleratorAttr);
39+
if (!accel_attr || accel_attr.getValue() != accel::kNeuraTarget) {
3840
return;
3941
}
4042
// Converts block argument types to predicated values.
@@ -107,7 +109,7 @@ struct LeveragePredicatedValuePass
107109
// Converts a single operation to use predicated values.
108110
LogicalResult applyPredicatedDataType(Operation *op) {
109111
// Skips if not a Neura op.
110-
if (op->getDialect()->getNamespace() != "neura") {
112+
if (op->getDialect()->getNamespace() != accel::kNeuraTarget) {
111113
return success();
112114
}
113115

0 commit comments

Comments
 (0)