Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
# Explicit files to ignore (only matches one).
#==============================================================================#
# Various tag programs
tags
/tags
/TAGS
/GPATH
Expand Down
44 changes: 36 additions & 8 deletions mlir/examples/dsp/SimpleBlocks/include/toy/AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ class ExprAST {
enum ExprASTKind {
Expr_VarDecl,
Expr_Return,
Expr_Num,
// test for int and double
//Expr_Num,
Expr_Int,
Expr_Double,
Expr_Literal,
Expr_Var,
Expr_BinOp,
Expand All @@ -61,18 +64,43 @@ class ExprAST {
/// A block-list of expressions.
using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;

// test for int and double
/// Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
double val;
//class NumberExprAST : public ExprAST {
//double val;
//
//public:
//NumberExprAST(Location loc, double val)
//: ExprAST(Expr_Num, std::move(loc)), val(val) {}
//
//double getValue() { return val; }
//
///// LLVM style RTTI
//static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
//};

class IntExprAST : public ExprAST {
int val;

public:
NumberExprAST(Location loc, double val)
: ExprAST(Expr_Num, std::move(loc)), val(val) {}
IntExprAST(Location loc, int val)
: ExprAST(Expr_Int, std::move(loc)), val(val) {}

int getInt() { return val; }

double getValue() { return val; }
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Int; }
};

/// LLVM style RTTI
static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
class DoubleExprAST : public ExprAST {
double val;

public:
DoubleExprAST(Location loc, double val)
: ExprAST(Expr_Double, std::move(loc)), val(val) {}

double getDouble() { return val; }

static bool classof(const ExprAST *c) { return c->getKind() == Expr_Double; }
};

/// Expression class for a literal value.
Expand Down
69 changes: 55 additions & 14 deletions mlir/examples/dsp/SimpleBlocks/include/toy/Lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ enum Token : int {

// primary
tok_identifier = -5,
tok_number = -6,

// test for int and double
// tok_number = -6,
tok_int = -6,
tok_double = -7,
};

/// The Lexer is an abstract base class providing all the facilities that the
Expand Down Expand Up @@ -83,10 +87,21 @@ class Lexer {
return identifierStr;
}

// test for int and double
/// Return the current number (prereq: getCurToken() == tok_number)
double getValue() {
assert(curTok == tok_number);
return numVal;
//double getValue() {
//assert(curTok == tok_number);
//return numVal;
//}

int getIntValue() {
assert(curTok == tok_int);
return numInt;
}

double getDoubleValue() {
assert(curTok == tok_double);
return numDouble;
}

/// Return the location for the beginning of the current token.
Expand Down Expand Up @@ -148,16 +163,39 @@ class Lexer {
return tok_identifier;
}

// test for int and double
// Number: [0-9.]+
if (isdigit(lastChar) || lastChar == '.') {
std::string numStr;
do {
numStr += lastChar;
lastChar = Token(getNextChar());
} while (isdigit(lastChar) || lastChar == '.');

numVal = strtod(numStr.c_str(), nullptr);
return tok_number;
//if (isdigit(lastChar) || lastChar == '.') {
//std::string numStr;
//do {
//numStr += lastChar;
//lastChar = Token(getNextChar());
//} while (isdigit(lastChar) || lastChar == '.');
//
//numVal = strtod(numStr.c_str(), nullptr);
//return tok_number;
//}

if(isdigit(lastChar)) {
std::string numStr;
bool isDouble = false;

do {
if(lastChar == '.') isDouble = true;

numStr += lastChar;
lastChar = Token(getNextChar());
} while(isdigit(lastChar) || lastChar == '.');

if(isDouble) {
numDouble = strtod(numStr.c_str(), nullptr);
return tok_double;
}
else {
char ** p_end;
numInt = strtol(numStr.c_str(), p_end, 10);
return tok_int;
}
}

if (lastChar == '#') {
Expand Down Expand Up @@ -189,8 +227,11 @@ class Lexer {
/// If the current Token is an identifier, this string contains the value.
std::string identifierStr;

// test for int and double
/// If the current Token is a number, this contains the value.
double numVal = 0;
//double numVal = 0;
int numInt = 0;
double numDouble = 0;

/// The last value returned by getNextChar(). We need to keep it around as we
/// always need to read ahead one character to decide when to end a token and
Expand Down
39 changes: 36 additions & 3 deletions mlir/examples/dsp/SimpleBlocks/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,42 @@ def ConstantOp : Dsp_Op<"constant", [Pure]> {

// Build a constant with a given constant floating-point value.
OpBuilder<(ins "double":$value)>,
];

// Build a constant with a given constant floating-point value.
// OpBuilder<(ins "int":$value)>
// Indicate that additional verification for this operation is necessary.
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// IntegerConstantOp
//===----------------------------------------------------------------------===//

def IntegerConstantOp : Dsp_Op<"integer_constant", [Pure]> {
let summary = "integer constant";
let description = [{
Integer Constant operation turns a literal into an SSA value. The data is attached
to the operation as an attribute. For example:

```mlir
%0 = dsp.integer_constant dense<[[1, 2, 30], [4, 5, 6]]>
: tensor<2x3xi64>
```
}];

// expect an integer constant tensor value of type I64
let arguments = (ins I64ElementsAttr:$value);

let results = (outs I64Tensor);

let hasCustomAssemblyFormat = 1;

let builders = [
OpBuilder<(ins "DenseIntElementsAttr":$value), [{
build($_builder, $_state, value.getType(), value);
}]>,

// Build a constant with a given constant int64 value.
OpBuilder<(ins "int":$value)>,
];

// Indicate that additional verification for this operation is necessary.
Expand Down Expand Up @@ -299,7 +332,7 @@ def PrintOp : Dsp_Op<"print"> {

// The print operation takes an input tensor to print.
// We also allow a F64MemRef to enable interop during partial lowering.
let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
let arguments = (ins AnyTypeOf<[F64, I64, F64Tensor, F64MemRef, I64Tensor, I64MemRef]>:$input);

let assemblyFormat = "$input attr-dict `:` type($input)";
}
Expand Down
71 changes: 56 additions & 15 deletions mlir/examples/dsp/SimpleBlocks/include/toy/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,32 @@ class Parser {
return std::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
}

// test for int and double
/// Parse a literal number.
/// numberexpr ::= number
std::unique_ptr<ExprAST> parseNumberExpr() {
auto loc = lexer.getLastLocation();
auto result =
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
lexer.consume(tok_number);
return std::move(result);
//std::unique_ptr<ExprAST> parseNumberExpr() {
//auto loc = lexer.getLastLocation();
//auto result =
//std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
//lexer.consume(tok_number);
//return std::move(result);
//}


std::unique_ptr<ExprAST> parseIntExpr() {
auto loc = lexer.getLastLocation();
auto result =
std::make_unique<IntExprAST>(std::move(loc), lexer.getIntValue());
lexer.consume(tok_int);
return std::move(result);
}

std::unique_ptr<ExprAST> parseDoubleExpr() {
auto loc = lexer.getLastLocation();
auto result =
std::make_unique<DoubleExprAST>(std::move(loc), lexer.getDoubleValue());
lexer.consume(tok_double);
return std::move(result);
}

/// Parse a literal array expression.
Expand All @@ -103,9 +121,17 @@ class Parser {
if (!values.back())
return nullptr; // parse error in the nested array.
} else {
if (lexer.getCurToken() != tok_number)
return parseError<ExprAST>("<num> or [", "in literal expression");
values.push_back(parseNumberExpr());
// test for int and double
//if (lexer.getCurToken() != tok_number)
//return parseError<ExprAST>("<num> or [", "in literal expression");
//values.push_back(parseNumberExpr());

if(lexer.getCurToken() != tok_int && lexer.getCurToken() != tok_double) {
return parseError<ExprAST>("<num> or [", "in literal expression");
}

if(lexer.getCurToken() == tok_int) values.push_back(parseIntExpr());
else if(lexer.getCurToken() == tok_double) values.push_back(parseDoubleExpr());
}

// End of this list on ']'
Expand Down Expand Up @@ -150,6 +176,7 @@ class Parser {
"inside literal expression");
}
}

return std::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
std::move(dims));
}
Expand Down Expand Up @@ -224,8 +251,13 @@ class Parser {
return nullptr;
case tok_identifier:
return parseIdentifierExpr();
case tok_number:
return parseNumberExpr();
/* test for int and double */
//case tok_number:
//return parseNumberExpr();
case tok_int:
return parseIntExpr();
case tok_double:
return parseDoubleExpr();
case '(':
return parseParenExpr();
case '[':
Expand Down Expand Up @@ -295,11 +327,20 @@ class Parser {

auto type = std::make_unique<VarType>();

while (lexer.getCurToken() == tok_number) {
type->shape.push_back(lexer.getValue());
lexer.getNextToken();
if (lexer.getCurToken() == ',')
// test for int and double
//while (lexer.getCurToken() == tok_number) {
//type->shape.push_back(lexer.getValue());
//lexer.getNextToken();
//if (lexer.getCurToken() == ',')
//lexer.getNextToken();
//}

while(lexer.getCurToken() == tok_int || lexer.getCurToken() == tok_double) {
if(lexer.getCurToken() == tok_int) type->shape.push_back(lexer.getIntValue());
else if(lexer.getCurToken() == tok_double) type->shape.push_back(lexer.getDoubleValue());
lexer.getNextToken();

if(lexer.getCurToken() == ',') lexer.getNextToken();
}

if (lexer.getCurToken() != '>')
Expand Down
53 changes: 53 additions & 0 deletions mlir/examples/dsp/SimpleBlocks/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,59 @@ mlir::LogicalResult ConstantOp::verify() {
return mlir::success();
}

//===----------------------------------------------------------------------===//
// Integer ConstantOp
//===----------------------------------------------------------------------===//

void IntegerConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
int value) {
auto dataType = RankedTensorType::get({}, builder.getI64Type());
auto dataAttribute = DenseIntElementsAttr::get(dataType, value);
IntegerConstantOp::build(builder, state, dataType, dataAttribute);
}

mlir::ParseResult IntegerConstantOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
mlir::DenseIntElementsAttr value;
printf("Parse Integer constant success for MLIRgen.\n");
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseAttribute(value, "value", result.attributes))
return failure();

result.addTypes(value.getType());
return success();
}

void IntegerConstantOp::print(mlir::OpAsmPrinter &printer) {
printer << " ";
printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"});
printer << getValue();
}

mlir::LogicalResult IntegerConstantOp::verify() {
auto resultType = llvm::dyn_cast<mlir::RankedTensorType>(getResult().getType());
if (!resultType)
return success();

auto attrType = llvm::cast<mlir::RankedTensorType>(getValue().getType());
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
<< attrType.getRank() << " != " << resultType.getRank();
}

// Check that each of the dimensions match between the two types.
for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
return emitOpError(
"return type shape mismatches its attribute at dimension ")
<< dim << ": " << attrType.getShape()[dim]
<< " != " << resultType.getShape()[dim];
}
}
return mlir::success();
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
Expand Down
Loading