2 changes: 1 addition & 1 deletion python/src/array.cpp
@@ -296,7 +296,7 @@ void init_array(nb::module_& m) {
           nb::is_weak_referenceable())
       .def(
           "__init__",
-          [](mx::array* aptr, ArrayInitType v, std::optional<mx::Dtype> t) {
+          [](mx::array* aptr, nb::object v, std::optional<mx::Dtype> t) {
             new (aptr) mx::array(create_array(v, t));
           },
           "val"_a,
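A quick sketch of what this signature change means at the Python layer (a sketch, assuming the ml_dtypes package that this PR adds to the dev dependencies): the constructor now receives any Python object and defers type dispatch to create_array, so values the old std::variant could not represent reach the binding.

    import numpy as np
    import ml_dtypes  # numpy-compatible bfloat16 scalar type
    import mlx.core as mx

    # Previously rejected at the binding layer; with nb::object the value
    # reaches create_array, which handles the bfloat16 case explicitly.
    x = np.array([1.0, 2.0], dtype=ml_dtypes.bfloat16)
    a = mx.array(x)
    assert a.dtype == mx.bfloat16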
72 changes: 49 additions & 23 deletions python/src/convert.cpp
@@ -454,7 +454,7 @@ mx::array array_from_list_impl(T pl, std::optional<mx::Dtype> dtype) {
   // `pl` contains mlx arrays
   std::vector<mx::array> arrays;
   for (auto l : pl) {
-    arrays.push_back(create_array(nb::cast<ArrayInitType>(l), dtype));
+    arrays.push_back(create_array(nb::cast<nb::object>(l), dtype));
   }
   return mx::stack(arrays);
 }
@@ -467,38 +467,64 @@ mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype) {
   return array_from_list_impl(pl, dtype);
 }
 
-mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t) {
-  if (auto pv = std::get_if<nb::bool_>(&v); pv) {
-    return mx::array(nb::cast<bool>(*pv), t.value_or(mx::bool_));
-  } else if (auto pv = std::get_if<nb::int_>(&v); pv) {
-    auto val = nb::cast<int64_t>(*pv);
+mx::array create_array(nb::object v, std::optional<mx::Dtype> t) {
+  if (nb::hasattr(v, "dtype")) {
+    nb::object dtype_obj = v.attr("dtype");
+    if (nb::str(dtype_obj).equal(nb::str("bfloat16"))) {
+      nb::object module_obj = v.attr("__class__").attr("__module__");
+      auto type_mod = nb::str(module_obj);
+      if (type_mod.equal(nb::str("numpy")) ||
+          type_mod.equal(nb::str("ml_dtypes"))) {
+        auto uint16_view = v.attr("view")("uint16");
+        using ContigArray =
+            nb::ndarray<uint16_t, nb::ro, nb::c_contig, nb::device::cpu>;
+        auto nd_arr = nb::cast<ContigArray>(uint16_view);
+        auto shape = nb::cast<mx::Shape>(v.attr("shape"));
+        const mx::bfloat16_t* typed_ptr =
+            reinterpret_cast<const mx::bfloat16_t*>(nd_arr.data());
+        auto res = (shape.empty()) ? mx::array(*typed_ptr, mx::bfloat16)
+                                   : mx::array(typed_ptr, shape, mx::bfloat16);
+        if (t.has_value())
+          res = mx::astype(res, *t);
+        return res;
+      }
+    }
+  }
+  if (nb::isinstance<nb::bool_>(v)) {
+    return mx::array(nb::cast<bool>(v), t.value_or(mx::bool_));
+  } else if (nb::isinstance<nb::int_>(v)) {
+    auto val = nb::cast<int64_t>(v);
     auto default_type = (val > std::numeric_limits<int>::max() ||
                          val < std::numeric_limits<int>::min())
         ? mx::int64
         : mx::int32;
     return mx::array(val, t.value_or(default_type));
-  } else if (auto pv = std::get_if<nb::float_>(&v); pv) {
+  } else if (nb::isinstance<nb::float_>(v)) {
     auto out_type = t.value_or(mx::float32);
     if (out_type == mx::float64) {
-      return mx::array(nb::cast<double>(*pv), out_type);
+      return mx::array(nb::cast<double>(v), out_type);
     } else {
-      return mx::array(nb::cast<float>(*pv), out_type);
+      return mx::array(nb::cast<float>(v), out_type);
     }
-  } else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
+  } else if (PyComplex_Check(v.ptr())) {
     return mx::array(
-        static_cast<mx::complex64_t>(*pv), t.value_or(mx::complex64));
-  } else if (auto pv = std::get_if<nb::list>(&v); pv) {
-    return array_from_list(*pv, t);
-  } else if (auto pv = std::get_if<nb::tuple>(&v); pv) {
-    return array_from_list(*pv, t);
-  } else if (auto pv = std::get_if<
-                 nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>>(&v);
-             pv) {
-    return nd_array_to_mlx(*pv, t);
-  } else if (auto pv = std::get_if<mx::array>(&v); pv) {
-    return mx::astype(*pv, t.value_or((*pv).dtype()));
-  } else {
-    auto arr = to_array_with_accessor(std::get<ArrayLike>(v).obj);
-    return mx::astype(arr, t.value_or(arr.dtype()));
+        static_cast<mx::complex64_t>(nb::cast<std::complex<float>>(v)),
+        t.value_or(mx::complex64));
+  } else if (nb::isinstance<nb::list>(v)) {
+    return array_from_list(nb::cast<nb::list>(v), t);
+  } else if (nb::isinstance<nb::tuple>(v)) {
+    return array_from_list(nb::cast<nb::tuple>(v), t);
+  } else if (nb::isinstance<mx::array>(v)) {
+    auto arr = nb::cast<mx::array>(v);
+    return mx::astype(arr, t.value_or(arr.dtype()));
+  } else {
+    try {
+      using ContigArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
+      auto nd = nb::cast<ContigArray>(v);
+      return nd_array_to_mlx(nd, t);
+    } catch (const nb::cast_error&) {
+      auto arr = to_array_with_accessor(v);
+      return mx::astype(arr, t.value_or(arr.dtype()));
+    }
   }
 }
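A minimal sketch of the zero-copy trick the new bfloat16 branch relies on: numpy itself has no native bfloat16, so the object's buffer is viewed as uint16 (the same 16-bit storage) and create_array reinterprets the data pointer as mx::bfloat16_t on the C++ side.

    import numpy as np
    import ml_dtypes

    # bfloat16 and uint16 are both 16 bits wide, so .view() reinterprets the
    # same buffer with no copy; the C++ side then casts the data pointer to
    # const mx::bfloat16_t*.
    x = np.array([1.5, 2.5, 3.5], dtype=ml_dtypes.bfloat16)
    bits = x.view(np.uint16)
    print(bits.tolist())  # [16320, 16416, 16480], i.e. 0x3FC0, 0x4020, 0x4060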
15 changes: 1 addition & 14 deletions python/src/convert.h
@@ -21,19 +21,6 @@ struct ArrayLike {
   nb::object obj;
 };
 
-using ArrayInitType = std::variant<
-    nb::bool_,
-    nb::int_,
-    nb::float_,
-    // Must be above ndarray
-    mx::array,
-    // Must be above complex
-    nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
-    std::complex<float>,
-    nb::list,
-    nb::tuple,
-    ArrayLike>;
-
 mx::array nd_array_to_mlx(
     nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu> nd_array,
     std::optional<mx::Dtype> dtype);
@@ -45,6 +32,6 @@ nb::object to_scalar(mx::array& a);
 
 nb::object tolist(mx::array& a);
 
-mx::array create_array(ArrayInitType v, std::optional<mx::Dtype> t);
+mx::array create_array(nb::object v, std::optional<mx::Dtype> t);
 mx::array array_from_list(nb::list pl, std::optional<mx::Dtype> dtype);
 mx::array array_from_list(nb::tuple pl, std::optional<mx::Dtype> dtype);
2 changes: 1 addition & 1 deletion python/src/ops.cpp
@@ -1719,7 +1719,7 @@ void init_ops(nb::module_& m) {
   )pbdoc");
   m.def(
       "asarray",
-      [](const ArrayInitType& a, std::optional<mx::Dtype> dtype) {
+      [](const nb::object& a, std::optional<mx::Dtype> dtype) {
        return create_array(a, dtype);
      },
      nb::arg(),
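Since asarray now funnels through the same create_array, the generic ndarray path becomes a cast-and-fallback rather than a variant alternative. A hedged usage sketch mirroring the new tests below (assuming asarray's optional dtype keyword):

    import numpy as np
    import ml_dtypes
    import mlx.core as mx

    v = np.array([1.5, 2.5, 3.5], dtype=ml_dtypes.bfloat16)
    out = mx.asarray(v)  # same dispatch as mx.array(v)
    assert out.dtype == mx.bfloat16 and out.shape == (3,)

    # An explicit dtype casts after the bfloat16 bits are read:
    assert mx.asarray(v, dtype=mx.float32).dtype == mx.float32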
32 changes: 32 additions & 0 deletions python/tests/test_bf16.py
@@ -15,6 +15,13 @@
 except ImportError as e:
     has_torch = False
 
+try:
+    import ml_dtypes
+
+    has_ml_dtypes = True
+except ImportError:
+    has_ml_dtypes = False
+
 
 class TestBF16(mlx_tests.MLXTestCase):
     def __test_ops(
@@ -191,6 +198,31 @@ def test_conversion(self):
         self.assertEqual(a_mx.dtype, mx.bfloat16)
         self.assertTrue(mx.array_equal(a_mx, expected))
 
+    @unittest.skipIf(not has_ml_dtypes, "requires ml_dtypes")
+    def test_conversion_ml_dtypes(self):
+        x_scalar = np.array(1.5, dtype=ml_dtypes.bfloat16)
+        a_scalar = mx.array(x_scalar)
+        self.assertEqual(a_scalar.dtype, mx.bfloat16)
+        self.assertEqual(a_scalar.shape, ())
+        self.assertEqual(a_scalar.item(), 1.5)
+
+        data = [1.5, 2.5, 3.5]
+        x_vector = np.array(data, dtype=ml_dtypes.bfloat16)
+        a_vector = mx.array(x_vector)
+        expected = mx.array(data, dtype=mx.bfloat16)
+        self.assertEqual(a_vector.dtype, mx.bfloat16)
+        self.assertEqual(a_vector.shape, (3,))
+        self.assertTrue(mx.array_equal(a_vector, expected))
+
+        a_cast = mx.array(x_scalar, dtype=mx.float32)
+        self.assertEqual(a_cast.dtype, mx.float32)
+        self.assertEqual(a_cast.item(), 1.5)
+
+        a_asarray = mx.asarray(x_vector)
+        self.assertEqual(a_asarray.dtype, mx.bfloat16)
+        self.assertEqual(a_asarray.shape, (3,))
+        self.assertTrue(mx.array_equal(a_asarray, expected))
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
1 change: 1 addition & 0 deletions setup.py
@@ -234,6 +234,7 @@ def get_tag(self) -> tuple[str, str, str]:
         "psutil",
         "torch>=2.9",
         "typing_extensions",
+        "ml_dtypes",
     ],
 }
 entry_points = {