Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,13 @@ uv run ruff format gpu_test/

- **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety
- **Operations**: All take stack as input and produce stack as output (except `forth.stack`)
- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `{ a b -- }` (local variables in word definitions), `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Float Literals**: Numbers containing `.` or `e`/`E` are parsed as f64 (e.g. `3.14`, `-2.0`, `1.0e-5`, `1e3`). Stored on the stack as i64 bit patterns; F-prefixed words perform bitcast before/after operations.
- **Kernel Parameters**: Declared in the `\!` header. `\! kernel <name>` is required and must appear first. `\! param <name> i64[<N>]` becomes a `memref<Nxi64>` argument; `\! param <name> i64` becomes an `i64` argument. `\! param <name> f64[<N>]` becomes a `memref<Nxf64>` argument; `\! param <name> f64` becomes an `f64` argument (bitcast to i64 when pushed to stack). Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value).
- **Shared Memory**: `\! shared <name> i64[<N>]` or `\! shared <name> f64[<N>]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution. Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for i64 or `SF@`/`SF!` for f64 shared accesses. Cannot be referenced inside word definitions.
- **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer
- **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion
- **Local Variables**: `{ a b c -- }` at the start of a word definition binds read-only locals. Pops values from the stack in reverse name order (c, b, a) using `forth.pop`, stores SSA values. Referencing a local emits `forth.push_value`. SSA values from the entry block dominate all control flow, so locals work across IF/ELSE/THEN, loops, etc. On GPU, locals map directly to registers.
- **User-defined Words**: Modeled as `func.func` with signature `(!forth.stack) -> !forth.stack`, called via `func.call`

## Conventions
Expand Down
91 changes: 91 additions & 0 deletions lib/Translation/ForthToMLIR/ForthToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,16 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack,
Location loc) {
Type stackType = forth::StackType::get(context);

// Check if word is a local variable (only valid inside word definitions)
if (inWordDefinition) {
auto it = localVars.find(word);
if (it != localVars.end()) {
return builder
.create<forth::PushValueOp>(loc, stackType, inputStack, it->second)
.getOutputStack();
}
}

// Check if word is a param name (only valid outside word definitions)
if (!inWordDefinition) {
for (const auto &param : paramDecls) {
Expand Down Expand Up @@ -1014,6 +1024,82 @@ LogicalResult ForthParser::parseBody(Value &stack) {
// Word definition and top-level parsing.
//===----------------------------------------------------------------------===//

LogicalResult ForthParser::parseLocals(Value &stack) {
// If current token is not '{', no locals to parse
if (currentToken.kind != Token::Kind::Word || currentToken.text != "{")
return success();

Location loc = getLoc();
consume(); // consume '{'

// Collect local names until '--' or '}'
SmallVector<std::string> names;
while (currentToken.kind != Token::Kind::EndOfFile) {
if (currentToken.kind == Token::Kind::Word && currentToken.text == "--")
break;
if (currentToken.kind == Token::Kind::Word && currentToken.text == "}")
break;

if (currentToken.kind != Token::Kind::Word)
return emitError("expected local variable name in { ... }");

std::string name = currentToken.text; // already uppercased by lexer

// Check for duplicate local names
for (const auto &existing : names) {
if (existing == name)
return emitError("duplicate local variable name: " + name);
}

// Check for conflicts with param names
for (const auto &param : paramDecls) {
if (param.name == name)
return emitError("local variable name '" + name +
"' conflicts with parameter name");
}

// Check for conflicts with shared names
for (const auto &shared : sharedDecls) {
if (shared.name == name)
return emitError("local variable name '" + name +
"' conflicts with shared memory name");
}

names.push_back(name);
consume();
}

// Skip '--' and output names until '}'
if (currentToken.kind == Token::Kind::Word && currentToken.text == "--") {
consume(); // consume '--'
while (currentToken.kind != Token::Kind::EndOfFile) {
if (currentToken.kind == Token::Kind::Word && currentToken.text == "}")
break;
consume(); // skip output names (ignored)
}
}

if (currentToken.kind != Token::Kind::Word || currentToken.text != "}")
return emitError("expected '}' to close local variable declaration");

consume(); // consume '}'

if (names.empty())
return success();

// Pop values in reverse order: { a b c -- } with stack ( 1 2 3 )
// pops 3->c, 2->b, 1->a
Type i64Type = builder.getI64Type();
Type stackType = forth::StackType::get(context);
for (int i = names.size() - 1; i >= 0; --i) {
auto popOp = builder.create<forth::PopOp>(loc, stackType, i64Type, stack);
stack = popOp.getOutputStack();
localVars[names[i]] = popOp.getValue();
}

return success();
}

LogicalResult ForthParser::parseWordDefinition() {
Location loc = getLoc();
auto savedInsertionPoint = builder.saveInsertionPoint();
Expand All @@ -1039,6 +1125,10 @@ LogicalResult ForthParser::parseWordDefinition() {
Value resultStack = entryBlock->getArgument(0);
builder.setInsertionPointToStart(entryBlock);

// Parse local variable declarations (if any)
if (failed(parseLocals(resultStack)))
return failure();

// Parse word body until ';'
if (failed(parseBody(resultStack)))
return failure();
Expand All @@ -1057,6 +1147,7 @@ LogicalResult ForthParser::parseWordDefinition() {
consume(); // consume ';'

inWordDefinition = false;
localVars.clear();

// Restore insertion point
builder.restoreInsertionPoint(savedInsertionPoint);
Expand Down
4 changes: 4 additions & 0 deletions lib/Translation/ForthToMLIR/ForthToMLIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class ForthParser {
std::vector<ParamDecl> paramDecls;
std::vector<SharedDecl> sharedDecls;
llvm::StringMap<Value> sharedAllocs;
llvm::StringMap<Value> localVars;
std::string kernelName;
const char *headerEndPtr = nullptr;
bool inWordDefinition = false;
Expand Down Expand Up @@ -159,6 +160,9 @@ class ForthParser {
void emitLoopEnd(Location loc, const LoopContext &ctx, Value step,
Value &stack);

/// Parse local variable declarations: { a b c -- }
LogicalResult parseLocals(Value &stack);

/// Parse a user-defined word definition.
LogicalResult parseWordDefinition();
};
Expand Down
10 changes: 10 additions & 0 deletions test/Pipeline/local-variables.forth
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s

\ Verify that local variables compile through the full pipeline to gpu.binary.
\ CHECK: gpu.binary @warpforth_module

\! kernel main
\! param DATA i64[256]
: ADD3 { a b c -- } a b + c + ;
1 2 3 ADD3
GLOBAL-ID CELLS DATA + !
19 changes: 19 additions & 0 deletions test/Translation/Forth/local-variables-control-flow.forth
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s

\ Test that locals work across IF/ELSE/THEN control flow.
\ SSA values defined in the entry block dominate all subsequent blocks.

\ CHECK: func.func private @CLAMP(%arg0: !forth.stack) -> !forth.stack {
\ CHECK: forth.pop
\ CHECK: forth.pop
\ CHECK: forth.pop
\ CHECK: forth.push_value
\ CHECK: forth.push_value
\ CHECK: forth.push_value

\! kernel main
: CLAMP { val lo hi -- }
val lo < IF lo ELSE
val hi > IF hi ELSE
val THEN THEN ;
0 10 5 CLAMP
5 changes: 5 additions & 0 deletions test/Translation/Forth/local-variables-error.forth
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
\ RUN: %not %warpforth-translate --forth-to-mlir %s 2>&1 | %FileCheck %s
\ CHECK: duplicate local variable name: X
\! kernel main
: BAD { x y x -- } x ;
BAD
27 changes: 27 additions & 0 deletions test/Translation/Forth/local-variables.forth
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s

\ Test basic local variable binding and reference.

\ CHECK: func.func private @ADD3(%arg0: !forth.stack) -> !forth.stack {
\ CHECK: forth.pop %arg0 : !forth.stack -> !forth.stack, i64
\ CHECK: forth.pop %{{.*}} : !forth.stack -> !forth.stack, i64
\ CHECK: forth.pop %{{.*}} : !forth.stack -> !forth.stack, i64
\ CHECK: forth.push_value %{{.*}}, %{{.*}} : !forth.stack, i64 -> !forth.stack
\ CHECK: forth.push_value %{{.*}}, %{{.*}} : !forth.stack, i64 -> !forth.stack
\ CHECK: forth.addi
\ CHECK: forth.push_value %{{.*}}, %{{.*}} : !forth.stack, i64 -> !forth.stack
\ CHECK: forth.addi
\ CHECK: return

\ CHECK: func.func private @SWAP2(%arg0: !forth.stack) -> !forth.stack {
\ CHECK: forth.pop %arg0 : !forth.stack -> !forth.stack, i64
\ CHECK: forth.pop %{{.*}} : !forth.stack -> !forth.stack, i64
\ CHECK: forth.push_value %{{.*}}, %{{.*}} : !forth.stack, i64 -> !forth.stack
\ CHECK: forth.push_value %{{.*}}, %{{.*}} : !forth.stack, i64 -> !forth.stack
\ CHECK: return

\! kernel main
: ADD3 { a b c -- } a b + c + ;
: SWAP2 { x y -- } y x ;
1 2 3 ADD3
10 20 SWAP2