Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ uv run ruff format gpu_test/

- **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety
- **Operations**: All take stack as input and produce stack as output (except `forth.stack`)
- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `{ a b -- }` (local variables in word definitions), `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Supported Words**: literals (integer `42` and float `3.14`), `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `F+ F- F* F/` (float arithmetic), `FEXP FSQRT FLOG FABS FNEG` (float math intrinsics), `FMAX FMIN` (float min/max), `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `F= F< F> F<> F<= F>=` (float comparison), `S>F F>S` (int/float conversion), `@ !` (global memory), `F@ F!` (float global memory), `S@ S!` (shared memory), `SF@ SF!` (float shared memory), `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `{ a b -- }` (local variables in word definitions), `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Float Literals**: Numbers containing `.` or `e`/`E` are parsed as f64 (e.g. `3.14`, `-2.0`, `1.0e-5`, `1e3`). Stored on the stack as i64 bit patterns; F-prefixed words perform bitcast before/after operations.
- **Kernel Parameters**: Declared in the `\!` header. `\! kernel <name>` is required and must appear first. `\! param <name> i64[<N>]` becomes a `memref<Nxi64>` argument; `\! param <name> i64` becomes an `i64` argument. `\! param <name> f64[<N>]` becomes a `memref<Nxf64>` argument; `\! param <name> f64` becomes an `f64` argument (bitcast to i64 when pushed to stack). Using a param name in code emits `forth.param_ref` (arrays push address; scalars push value).
- **Shared Memory**: `\! shared <name> i64[<N>]` or `\! shared <name> f64[<N>]` declares GPU shared (workgroup) memory. Emits a tagged `memref.alloca` at kernel entry; ForthToGPU converts it to a `gpu.func` workgroup attribution. Using the shared name in code pushes its base address onto the stack. Use `S@`/`S!` for i64 or `SF@`/`SF!` for f64 shared accesses. Cannot be referenced inside word definitions.
Expand Down
3 changes: 2 additions & 1 deletion include/warpforth/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def ConvertForthToMemRef
let dependentDialects = ["mlir::memref::MemRefDialect",
"mlir::arith::ArithDialect",
"mlir::LLVM::LLVMDialect",
"mlir::cf::ControlFlowDialect"];
"mlir::cf::ControlFlowDialect",
"mlir::math::MathDialect"];
}

def ConvertForthToGPU : Pass<"convert-forth-to-gpu", "mlir::ModuleOp"> {
Expand Down
65 changes: 65 additions & 0 deletions include/warpforth/Dialect/Forth/ForthOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,71 @@ def Forth_DivFOp : Forth_StackOpBase<"divf"> {
}];
}

//===----------------------------------------------------------------------===//
// Float math intrinsic operations.
//===----------------------------------------------------------------------===//

def Forth_ExpFOp : Forth_StackOpBase<"expf"> {
let summary = "Exponential of top stack element (float)";
let description = [{
Pops an i64 (f64 bit pattern), bitcasts to f64, computes e^x,
bitcasts result back to i64.
Forth semantics: ( f -- exp(f) )
}];
}

def Forth_SqrtFOp : Forth_StackOpBase<"sqrtf"> {
let summary = "Square root of top stack element (float)";
let description = [{
Pops an i64 (f64 bit pattern), bitcasts to f64, computes sqrt(x),
bitcasts result back to i64.
Forth semantics: ( f -- sqrt(f) )
}];
}

def Forth_LogFOp : Forth_StackOpBase<"logf"> {
let summary = "Natural logarithm of top stack element (float)";
let description = [{
Pops an i64 (f64 bit pattern), bitcasts to f64, computes ln(x),
bitcasts result back to i64.
Forth semantics: ( f -- log(f) )
}];
}

def Forth_AbsFOp : Forth_StackOpBase<"absf"> {
let summary = "Absolute value of top stack element (float)";
let description = [{
Pops an i64 (f64 bit pattern), bitcasts to f64, computes |x|,
bitcasts result back to i64.
Forth semantics: ( f -- |f| )
}];
}

def Forth_NegFOp : Forth_StackOpBase<"negf"> {
let summary = "Negate top stack element (float)";
let description = [{
Pops an i64 (f64 bit pattern), bitcasts to f64, negates,
bitcasts result back to i64.
Forth semantics: ( f -- -f )
}];
}

def Forth_MaxFOp : Forth_StackOpBase<"maxf"> {
let summary = "Maximum of top two stack elements (float)";
let description = [{
Pops two i64 values, bitcasts to f64, computes max, bitcasts result back to i64.
Forth semantics: ( f1 f2 -- max(f1,f2) )
}];
}

def Forth_MinFOp : Forth_StackOpBase<"minf"> {
let summary = "Minimum of top two stack elements (float)";
let description = [{
Pops two i64 values, bitcasts to f64, computes min, bitcasts result back to i64.
Forth semantics: ( f1 f2 -- min(f1,f2) )
}];
}

//===----------------------------------------------------------------------===//
// Bitwise operations.
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_library(MLIRConversionPasses
MLIRGPUDialect
MLIRGPUToNVVMTransforms
MLIRGPUTransforms
MLIRMathToLLVM
MLIRReconcileUnrealizedCasts
MLIRTransforms
)
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/ForthToMemRef/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRForthToMemRefConversion
MLIRLLVMDialect
MLIRFuncDialect
MLIRControlFlowDialect
MLIRMathDialect
MLIRForth
)
61 changes: 60 additions & 1 deletion lib/Conversion/ForthToMemRef/ForthToMemRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -514,6 +515,60 @@ using MulFOpConversion =
using DivFOpConversion =
BinaryArithOpConversion<forth::DivFOp, arith::DivFOp, true>;

// Float binary intrinsics (max/min): FMAX and FMIN reuse the binary arith
// conversion template with arith::MaximumFOp / arith::MinimumFOp.
// NOTE(review): these are presumably the NaN-propagating arith variants (as
// opposed to maxnumf/minnumf) — confirm this is the intended FMAX/FMIN
// behavior.
using MaxFOpConversion =
BinaryArithOpConversion<forth::MaxFOp, arith::MaximumFOp, true>;
using MinFOpConversion =
BinaryArithOpConversion<forth::MinFOp, arith::MinimumFOp, true>;

/// Conversion pattern shared by every single-operand float word: (f -- result).
/// The cell addressed by the stack pointer is loaded as an i64 bit pattern,
/// reinterpreted as f64, passed through MathOp, and the f64 result is written
/// back to the same cell as an i64 bit pattern.
template <typename ForthOp, typename MathOp>
struct UnaryFloatOpConversion : public OpConversionPattern<ForthOp> {
  using OneToNOpAdaptor =
      typename OpConversionPattern<ForthOp>::OneToNOpAdaptor;

  UnaryFloatOpConversion(const TypeConverter &typeConverter,
                         MLIRContext *context)
      : OpConversionPattern<ForthOp>(typeConverter, context) {}

  LogicalResult
  matchAndRewrite(ForthOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // The converted stack operand arrives as two values: the backing buffer
    // and the current stack pointer.
    ValueRange stackParts = adaptor.getOperands()[0];
    Value buffer = stackParts[0];
    Value top = stackParts[1];

    // Read the raw cell at the stack pointer and view its bits as f64.
    Value rawBits = rewriter.create<memref::LoadOp>(loc, buffer, top);
    Value operand =
        rewriter.create<arith::BitcastOp>(loc, rewriter.getF64Type(), rawBits);

    // Run the wrapped unary op, then repack the f64 result as i64 bits.
    Value computed = rewriter.create<MathOp>(loc, operand);
    Value packed = rewriter.create<arith::BitcastOp>(
        loc, rewriter.getI64Type(), computed);

    // Overwrite the operand in place; a unary op leaves the stack pointer
    // untouched, so the same (buffer, pointer) pair is the result stack.
    rewriter.create<memref::StoreOp>(loc, packed, buffer, top);
    rewriter.replaceOpWithMultiple(op, {{buffer, top}});
    return success();
  }
};

// Float unary intrinsics. Exp, sqrt, log and abs are instantiated with math
// dialect ops; negation is instantiated with arith::NegFOp instead.
using ExpFOpConversion = UnaryFloatOpConversion<forth::ExpFOp, math::ExpOp>;
using SqrtFOpConversion = UnaryFloatOpConversion<forth::SqrtFOp, math::SqrtOp>;
using LogFOpConversion = UnaryFloatOpConversion<forth::LogFOp, math::LogOp>;
using AbsFOpConversion = UnaryFloatOpConversion<forth::AbsFOp, math::AbsFOp>;
using NegFOpConversion = UnaryFloatOpConversion<forth::NegFOp, arith::NegFOp>;

/// Base template for binary comparison operations.
/// Pops two values, compares, pushes -1 (true) or 0 (false): (a b -- flag)
/// When IsFloat=true, bitcasts i64->f64 before comparing.
Expand Down Expand Up @@ -1153,7 +1208,8 @@ struct ConvertForthToMemRefPass

// Mark MemRef, Arith, LLVM, CF, and Math dialects as legal
target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
LLVM::LLVMDialect, cf::ControlFlowDialect>();
LLVM::LLVMDialect, cf::ControlFlowDialect,
math::MathDialect>();

// Mark IntrinsicOp and BarrierOp as legal (to be lowered later)
target.addLegalOp<forth::IntrinsicOp>();
Expand Down Expand Up @@ -1205,6 +1261,9 @@ struct ConvertForthToMemRefPass
ModOpConversion,
// Float arithmetic
AddFOpConversion, SubFOpConversion, MulFOpConversion, DivFOpConversion,
// Float math intrinsics
ExpFOpConversion, SqrtFOpConversion, LogFOpConversion, AbsFOpConversion,
NegFOpConversion, MaxFOpConversion, MinFOpConversion,
// Bitwise
AndOpConversion, OrOpConversion, XorOpConversion, NotOpConversion,
LshiftOpConversion, RshiftOpConversion,
Expand Down
8 changes: 6 additions & 2 deletions lib/Conversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "warpforth/Conversion/Passes.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -40,10 +41,13 @@ void buildWarpForthPipeline(OpPassManager &pm) {
pm.addNestedPass<gpu::GPUModuleOp>(
createConvertGpuOpsToNVVMOps(gpuToNVVMOptions));

// Stage 6: Lower NVVM to LLVM
// Stage 6: Lower math ops to LLVM intrinsics inside GPU module
pm.addNestedPass<gpu::GPUModuleOp>(createConvertMathToLLVMPass());

// Stage 7: Lower NVVM to LLVM
pm.addPass(createConvertNVVMToLLVMPass());

// Stage 7: Reconcile type conversions
// Stage 8: Reconcile type conversions
pm.addPass(createReconcileUnrealizedCastsPass());

// Stage 9: Compile GPU module to PTX binary
Expand Down
21 changes: 21 additions & 0 deletions lib/Translation/ForthToMLIR/ForthToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,27 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack,
} else if (word == "F/") {
return builder.create<forth::DivFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FEXP") {
return builder.create<forth::ExpFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FSQRT") {
return builder.create<forth::SqrtFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FLOG") {
return builder.create<forth::LogFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FABS") {
return builder.create<forth::AbsFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FNEG") {
return builder.create<forth::NegFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FMAX") {
return builder.create<forth::MaxFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "FMIN") {
return builder.create<forth::MinFOp>(loc, stackType, inputStack)
.getResult();
} else if (word == "MOD") {
return builder.create<forth::ModOp>(loc, stackType, inputStack).getResult();
} else if (word == "AND") {
Expand Down
70 changes: 70 additions & 0 deletions test/Conversion/ForthToMemRef/float-math-intrinsics.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s

// CHECK-LABEL: func.func private @main

// expf: load, bitcast i64->f64, math.exp, bitcast f64->i64, store (SP unchanged)
// CHECK: memref.load
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: math.exp %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
// CHECK: memref.store

// sqrtf: load, bitcast, math.sqrt, bitcast, store
// CHECK: memref.load
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: math.sqrt %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
// CHECK: memref.store

// logf: load, bitcast, math.log, bitcast, store
// CHECK: memref.load
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: math.log %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
// CHECK: memref.store

// absf: load, bitcast, math.absf, bitcast, store
// CHECK: memref.load
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: math.absf %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
// CHECK: memref.store

// negf: load, bitcast, arith.negf, bitcast, store
// CHECK: memref.load
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: arith.negf %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
// CHECK: memref.store

// maxf: binary — pop two, bitcast, arith.maximumf, bitcast, store
// CHECK: memref.load
// CHECK: arith.subi
// CHECK: memref.load
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: arith.maximumf %{{.*}}, %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64
// CHECK: memref.store

// minf: binary — pop two, bitcast, arith.minimumf, bitcast, store
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: arith.bitcast %{{.*}} : i64 to f64
// CHECK: arith.minimumf %{{.*}}, %{{.*}} : f64
// CHECK: arith.bitcast %{{.*}} : f64 to i64

module {
func.func private @main() {
%0 = forth.stack !forth.stack
%1 = forth.constant %0(1.000000e+00 : f64) : !forth.stack -> !forth.stack
%2 = forth.expf %1 : !forth.stack -> !forth.stack
%3 = forth.sqrtf %2 : !forth.stack -> !forth.stack
%4 = forth.logf %3 : !forth.stack -> !forth.stack
%5 = forth.absf %4 : !forth.stack -> !forth.stack
%6 = forth.negf %5 : !forth.stack -> !forth.stack
%7 = forth.constant %6(2.000000e+00 : f64) : !forth.stack -> !forth.stack
%8 = forth.maxf %7 : !forth.stack -> !forth.stack
%9 = forth.minf %8 : !forth.stack -> !forth.stack
return
}
}
12 changes: 12 additions & 0 deletions test/Pipeline/float-math-intrinsics.forth
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s

\ Verify float math intrinsics lower through the full pipeline to gpu.binary
\ CHECK: gpu.binary @warpforth_module

\! kernel main
\! param data f64[256]
GLOBAL-ID CELLS data + F@
FABS FEXP FSQRT FLOG FNEG
GLOBAL-ID CELLS data + F@
FMAX FMIN
GLOBAL-ID CELLS data + F!
21 changes: 21 additions & 0 deletions test/Translation/Forth/float-math-intrinsics.forth
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s

\ Verify float math intrinsic ops parse correctly

\ Unary ops
\ CHECK: %[[S0:.*]] = forth.stack
\ CHECK: %[[S1:.*]] = forth.constant %[[S0]]
\ CHECK: %[[S2:.*]] = forth.expf %[[S1]]
\ CHECK: %[[S3:.*]] = forth.sqrtf %[[S2]]
\ CHECK: %[[S4:.*]] = forth.logf %[[S3]]
\ CHECK: %[[S5:.*]] = forth.absf %[[S4]]
\ CHECK: %[[S6:.*]] = forth.negf %[[S5]]

\ Binary ops
\ CHECK: %[[S7:.*]] = forth.constant %[[S6]]
\ CHECK: %[[S8:.*]] = forth.maxf %[[S7]]
\ CHECK: %[[S9:.*]] = forth.minf %[[S8]]

\! kernel main
1.0 FEXP FSQRT FLOG FABS FNEG
2.0 FMAX FMIN