Skip to content

Commit 8944273

Browse files
committed
inline complex isnan
1 parent df9a499 commit 8944273

2 files changed

Lines changed: 9 additions & 8 deletions

File tree

mlx/backend/metal/kernels/complex.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,6 @@ constexpr bool operator<(complex64_t a, complex64_t b) {
9797
return operator>(b, a);
9898
}
9999

100-
constexpr bool isnan(complex64_t x) {
101-
return isnan(x.real) || isnan(x.imag);
102-
}
103-
104100
constexpr bool operator==(complex64_t a, complex64_t b) {
105101
return a.real == b.real && a.imag == b.imag;
106102
}

mlx/backend/metal/kernels/sort.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,15 @@ template <typename T>
4040
struct LessThan {
4141
static constexpr constant T init = Init<T>::v;
4242
METAL_FUNC bool operator()(T a, T b) const {
43-
if constexpr (
44-
metal::is_floating_point_v<T> || metal::is_same_v<T, complex64_t>) {
45-
bool an = isnan(a);
46-
bool bn = isnan(b);
43+
if constexpr (metal::is_floating_point_v<T>) {
44+
bool an = metal::isnan(a);
45+
bool bn = metal::isnan(b);
46+
if (an | bn) {
47+
return (!an) & bn;
48+
}
49+
} else if constexpr (metal::is_same_v<T, complex64_t>) {
50+
bool an = metal::isnan(a.real) || metal::isnan(a.imag);
51+
bool bn = metal::isnan(b.real) || metal::isnan(b.imag);
4752
if (an | bn) {
4853
return (!an) & bn;
4954
}

0 commit comments

Comments
 (0)