-
-
Notifications
You must be signed in to change notification settings - Fork 1
Gonative #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Quafadas
wants to merge
26
commits into
main
Choose a base branch
from
gonative
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Gonative #33
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
f803edb
that... works?
Quafadas cf1f3e3
kinda sorta
Quafadas ae60030
intersting
Quafadas bfd008d
repeating fix
Quafadas f9af0a8
This appears to be a vialble scaffold for BLIS
Quafadas 404b00d
formatting
Quafadas 89f20e6
.
Quafadas 57d6604
Merge remote-tracking branch 'origin/main' into gonative
Quafadas 1cfbfe9
.
Quafadas d7391ba
add experimental float support
Quafadas e60e784
naive tensors
Quafadas d979606
.
Quafadas 980dd1a
.
Quafadas 7440ab5
.
Quafadas 960c46f
.
Quafadas 70075f4
.
Quafadas 49ff03a
.
Quafadas 381b573
mlx madness
Quafadas 65fbb87
Merge branch 'main' into gonative
Quafadas 1f88d2b
fix mlx
Quafadas 1118c0a
We successfully call some MLX methods
Quafadas e65baa0
Interesting ... this got somehwhere.
Quafadas ceb3fcb
.
Quafadas 16b109b
.
Quafadas 3b4921e
Merge branch 'main' into gonative
Quafadas 8ae7e14
.
Quafadas File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ metals.sbt | |
| .vscode/settings.json | ||
| sbt-launch.jar | ||
| .scala-build | ||
| .DS_Store | ||
|
|
||
| # npm | ||
| node_modules/ | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,26 +1,65 @@ | ||
| package build.experiments | ||
|
|
||
| import mill.*, scalalib.*, publish.* | ||
| import contrib.jmh.JmhModule | ||
|
|
||
| // mill benchmark.runJmh vecxt.benchmark.AndBooleanBenchmark -jvmArgs --add-modules=jdk.incubator.vector -rf json | ||
|
|
||
| object `package` extends ScalaModule: | ||
| // def enableBsp = false | ||
| def scalaVersion = build.vecxt.jvm.scalaVersion | ||
| override def compileResources = Task { | ||
| super.compileResources() ++ resources() | ||
| } | ||
| def scalacOptions = Seq("-Xmax-inlines:10000") | ||
| override def forkArgs = super.forkArgs() ++ build.vecIncubatorFlag | ||
| // override def mainClass = Some("mnist") | ||
|
|
||
|
|
||
|
|
||
| override def moduleDeps = Seq(build.vecxt.jvm) | ||
| override def mvnDeps = super.mvnDeps() ++ Seq( | ||
| mvn"com.lihaoyi::os-lib::0.10.4", | ||
| mvn"io.github.quafadas::scautable::0.0.28", | ||
| mvn"io.github.quafadas::dedav4s::0.10.0-RC2" | ||
| ) | ||
| end `package` | ||
| object `package` extends ScalaModule { | ||
| // "-Djava.library.path=/opt/homebrew/Cellar/blis/2.0/lib" | ||
|
|
||
| /** | ||
| * Path to the BLIS library. This assumes that you used | ||
| * On mac: brew install blis | ||
| * On ubuntu: apt install libblis3 libblis-dev [Not sure how general these settibngs are] | ||
| * Windows support is not implemented, not sure if BLIS is available on Windows. | ||
| */ | ||
|
|
||
| def pathToBlis = Task { | ||
| import scala.util.Properties | ||
|
|
||
| val osName = Properties.osName.toLowerCase | ||
| if (osName.contains("linux")) { | ||
| """/usr/lib/x86_64-linux-gnu/""" | ||
| } else if (osName.contains("mac")) { | ||
| """/opt/homebrew/Cellar/blis/2.0/lib""" | ||
| } else if (osName.contains("windows")) { | ||
| ??? | ||
| } else { | ||
| throw new Exception(s"Unsupported OS: $osName") | ||
| } | ||
| } | ||
|
|
||
| def scalaVersion = build.vecxt.jvm.scalaVersion | ||
| override def compileResources = Task { | ||
| super.compileResources() ++ resources() | ||
| } | ||
| def scalacOptions: T[Seq[String]] = Seq("-Xmax-inlines:10000") | ||
| override def forkArgs: T[Seq[String]] = super.forkArgs() ++ build.vecIncubatorFlag ++ Seq( | ||
| s"-Djava.library.path=${pathToBlis()}", | ||
| "-Djava.library.path=/Users/simon/Code/mlx-c/build", | ||
| "--enable-native-access=ALL-UNNAMED" | ||
| ) | ||
| override def mainClass = Some("vecxt.experiments.mlxDemo") | ||
| override def moduleDeps: Seq[JavaModule] = Seq(build.vecxt.jvm, build.vecxtensions.jvm, build.generated) | ||
| override def mvnDeps = super.mvnDeps() ++ Seq( | ||
| mvn"com.lihaoyi::os-lib::0.10.4", | ||
| mvn"io.github.quafadas::scautable::0.0.28", | ||
| mvn"io.github.quafadas::dedav4s::0.10.0-RC2" | ||
|
|
||
| ) | ||
|
|
||
| object test extends ScalaTests with TestModule.Munit { | ||
| def scalaVersion = build.vecxt.jvm.scalaVersion | ||
| override def mvnDeps= super.mvnDeps() ++ Seq( | ||
| mvn"org.scalameta::munit::${build.V.munitVersion}" | ||
| ) | ||
|
|
||
| override def moduleDeps: Seq[JavaModule] = Seq(build.experiments) | ||
|
|
||
| override def forkArgs: T[Seq[String]] = super.forkArgs() ++ build.vecIncubatorFlag ++ Seq( | ||
| s"-Djava.library.path=${pathToBlis()}" | ||
| ) | ||
|
|
||
| } | ||
|
|
||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| package vecxt.experiments | ||
|
|
||
| import java.lang.foreign.MemorySegment | ||
| import java.lang.foreign.MemoryLayout | ||
| import java.lang.foreign.Arena | ||
| import java.lang.foreign.ValueLayout | ||
| import blis_typed.blis_h | ||
| import scala.collection.mutable | ||
|
|
||
| /** This _should_ be unecessary for blis objects initialised with `blis_obj_create_with_attached_buffer`. From the docs | ||
| * of BLIS blis_obj_create_with_attached_buffer: | ||
| * | ||
| * > Objects initialized via this function should generally not be passed to bli_obj_free(), unless the user wishes to | ||
| * pass p into free(). | ||
| * | ||
| * @param underlying | ||
| */ | ||
| class BlisArena(private val underlying: Arena) extends Arena: | ||
| private val blisObjects = mutable.ListBuffer[MemorySegment]() | ||
|
|
||
| inline def allocate(byteSize: Long, byteAlignment: Long): MemorySegment = underlying.allocate(byteSize, byteAlignment) | ||
|
|
||
| inline def scope = underlying.scope | ||
|
|
||
| def registerBlisObject(obj: MemorySegment): Unit = | ||
| blisObjects += obj | ||
|
|
||
| override def close(): Unit = | ||
| // Free all BLIS objects first | ||
| blisObjects.foreach(blis_h.bli_obj_free) | ||
| blisObjects.clear() | ||
|
|
||
| // Then close the underlying arena | ||
| underlying.close() | ||
| end close | ||
| end BlisArena |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| package vecxt.experiments | ||
|
|
||
| import java.lang.foreign.MemorySegment | ||
| import java.lang.foreign.Arena | ||
| import java.lang.foreign.ValueLayout | ||
| import blis_typed.blis_h | ||
|
|
||
| opaque type DoubleVector = MemorySegment | ||
|
|
||
| object DoubleVector: | ||
|
|
||
| extension (v: DoubleVector) | ||
|
|
||
| inline def raw: MemorySegment = v | ||
|
|
||
| inline def apply(index: Long): Double = | ||
| v.getAtIndex(ValueLayout.JAVA_DOUBLE, index) | ||
|
|
||
| inline def update(index: Long, value: Double): Unit = | ||
| v.setAtIndex(ValueLayout.JAVA_DOUBLE, index, value) | ||
|
|
||
| inline def length: Long = | ||
| v.byteSize() / ValueLayout.JAVA_DOUBLE.byteSize() | ||
|
|
||
| inline def copy(using arena: Arena): DoubleVector = | ||
| val newM = arena.allocate(ValueLayout.JAVA_DOUBLE, v.length) | ||
| MemorySegment.copy(v, 0L, newM, 0L, v.byteSize()) | ||
| newM | ||
| end copy | ||
|
|
||
| inline def toSeq: Seq[Double] = | ||
| (0L until v.length).map(i => v.getAtIndex(ValueLayout.JAVA_DOUBLE, i)) | ||
|
|
||
| /** This will allocate an object that will be freed after the arena is closed. It does _not_ cleanup this memory | ||
| * segment itself on free. | ||
| * | ||
| * https://github.com/flame/blis/blob/master/docs/BLISObjectAPI.md#object-management > Objects initialized via this | ||
| * function should generally not be passed to bli_obj_free(), unless the user wishes to pass p into free(). | ||
| * @param arena | ||
| * @return | ||
| */ | ||
| inline def blis_obj_t(using arena: Arena) = | ||
| val objSegment = arena.allocate(512L) | ||
|
|
||
| blis_h.bli_obj_create_with_attached_buffer( | ||
| blis_h.BLIS_DOUBLE(), // dt: BLIS_DOUBLE for double precision | ||
| 1L, // m: 1 row (row vector) | ||
| v.length, // n: length columns | ||
| v.raw, // p: pointer to the actual data | ||
| v.length, // rs: row stride = length (distance between rows) | ||
| 1L, // cs: column stride = 1 (contiguous elements) | ||
| objSegment // obj: output object | ||
| ) | ||
| objSegment | ||
| end blis_obj_t | ||
|
|
||
| def +=(vec2: DoubleVector)(using arena: Arena): Unit = | ||
| blis_h.bli_addv(vec2.blis_obj_t, v.blis_obj_t) | ||
|
|
||
| def +(vec2: DoubleVector)(using arena: Arena): DoubleVector = | ||
| val result = v.copy | ||
| result += vec2 | ||
| result | ||
| end + | ||
| end extension | ||
|
|
||
| // Static methods for creating DoubleVector instances | ||
| inline def ofSize(size: Long)(using arena: Arena): DoubleVector = | ||
| arena.allocate(ValueLayout.JAVA_DOUBLE, size) | ||
|
|
||
| inline def apply(size: Long)(using arena: Arena): DoubleVector = | ||
| arena.allocate(ValueLayout.JAVA_DOUBLE, size) | ||
|
|
||
| inline def apply(data: Seq[Double])(using arena: Arena): DoubleVector = | ||
| arena.allocateFrom( | ||
| ValueLayout.JAVA_DOUBLE, | ||
| data* | ||
| ) | ||
| end DoubleVector |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| package vecxt.experiments | ||
|
|
||
| import java.lang.foreign.MemorySegment | ||
| import java.lang.foreign.Arena | ||
| import java.lang.foreign.ValueLayout | ||
|
|
||
| import blis_typed.blis_h | ||
|
|
||
| opaque type FloatVector = MemorySegment | ||
|
|
||
| object FloatVector: | ||
|
|
||
| extension (v: FloatVector) | ||
|
|
||
| inline def raw: MemorySegment = v | ||
|
|
||
| inline def apply(index: Long): Float = | ||
| v.getAtIndex(ValueLayout.JAVA_FLOAT, index) | ||
|
|
||
| inline def update(index: Long, value: Float): Unit = | ||
| v.setAtIndex(ValueLayout.JAVA_FLOAT, index, value) | ||
|
|
||
| inline def length: Long = | ||
| v.byteSize() / ValueLayout.JAVA_FLOAT.byteSize() | ||
|
|
||
| inline def copy(using arena: Arena): FloatVector = | ||
| val newM = arena.allocate(ValueLayout.JAVA_FLOAT, v.length) | ||
| MemorySegment.copy(v, 0L, newM, 0L, v.byteSize()) | ||
| newM | ||
| end copy | ||
|
|
||
| inline def toSeq: Seq[Float] = | ||
| (0L until v.length).map(i => v.getAtIndex(ValueLayout.JAVA_FLOAT, i)) | ||
|
|
||
| /** This will allocate an object that will be freed after the arena is closed. It does _not_ cleanup this memory | ||
| * segment itself on free. | ||
| * | ||
| * https://github.com/flame/blis/blob/master/docs/BLISObjectAPI.md#object-management > Objects initialized via this | ||
| * function should generally not be passed to bli_obj_free(), unless the user wishes to pass p Floato free(). | ||
| * @param arena | ||
| * @return | ||
| */ | ||
| inline def blis_obj_t(using arena: Arena) = | ||
| val objSegment = arena.allocate(512L) | ||
|
|
||
| blis_h.bli_obj_create_with_attached_buffer( | ||
| blis_h.BLIS_FLOAT(), // dt: BLIS_DOUBLE for double precision | ||
| 1L, // m: 1 row (row vector) | ||
| v.length, // n: length columns | ||
| v.raw, // p: poFloater to the actual data | ||
| v.length, // rs: row stride = length (distance between rows) | ||
| 1L, // cs: column stride = 1 (contiguous elements) | ||
| objSegment // obj: output object | ||
| ) | ||
| objSegment | ||
| end blis_obj_t | ||
|
|
||
| def +=(vec2: FloatVector)(using Arena): Unit = | ||
| blis_h.bli_addv(vec2.blis_obj_t, v.blis_obj_t) | ||
|
|
||
| def +(vec2: FloatVector)(using arena: Arena): FloatVector = | ||
| val result = v.copy | ||
| result += vec2 | ||
| result | ||
| end + | ||
| end extension | ||
|
|
||
| // Static methods for FloatVector creation | ||
|
|
||
| inline def ofSize(size: Long)(using arena: Arena): FloatVector = | ||
| arena.allocate(ValueLayout.JAVA_FLOAT, size) | ||
|
|
||
| inline def apply(size: Long)(using arena: Arena): FloatVector = | ||
| arena.allocate(ValueLayout.JAVA_FLOAT, size) | ||
|
|
||
| inline def apply(data: Seq[Float])(using arena: Arena): FloatVector = | ||
| arena.allocateFrom( | ||
| ValueLayout.JAVA_FLOAT, | ||
| data* | ||
| ) | ||
| end FloatVector | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a typo in the comment: 'poFloater' should be 'pointer'.