Skip to content

Commit 8f7ce77

Browse files
barronalexAlex Barrondavidkoski
authored
Update mlx-c to 0.1.2 (#204)
* update to mlx-c 0.1.2 * add linalg + kron + flatten * format * fix cmake build * Add David's updates + better float64 support. Co-authored-by: DavidKoski <46639364+davidkoski@users.noreply.github.com> * remove unwraps --------- Co-authored-by: Alex Barron <abarron22@apple.com> Co-authored-by: DavidKoski <46639364+davidkoski@users.noreply.github.com>
1 parent 5de5405 commit 8f7ce77

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+2232
-1661
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ endif()
1111
FetchContent_Declare(
1212
mlx-c
1313
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
14-
GIT_TAG "v0.1.0")
14+
GIT_TAG "v0.1.2")
1515
FetchContent_MakeAvailable(mlx-c)
1616

1717
# swift-numerics

Package.swift

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@ let package = Package(
5151
"fmt/src/fmt.cc",
5252

5353
// these are selected conditionally
54-
// via mlx-conditional/compiled_conditional.cpp
55-
"mlx/mlx/backend/common/compiled_nocpu.cpp",
56-
"mlx/mlx/backend/common/compiled_cpu.cpp",
54+
"mlx/mlx/backend/no_cpu/compiled.cpp",
55+
"mlx/mlx/backend/cpu/compiled.cpp",
5756

5857
// mlx files that are not part of the build
5958
"mlx/ACKNOWLEDGMENTS.md",
@@ -79,10 +78,6 @@ let package = Package(
7978

8079
"mlx/mlx/backend/common/default_primitives.cpp",
8180

82-
// this uses neon code and will not build on x86 (e.g. via Release).
83-
// see mlx-conditional/accelerate-softmax.cpp
84-
"mlx/mlx/backend/accelerate/softmax.cpp",
85-
8681
// build variants (we are opting _out_ of these)
8782
"mlx/mlx/io/no_safetensors.cpp",
8883
"mlx/mlx/io/gguf.cpp",
@@ -93,9 +88,13 @@ let package = Package(
9388
"mlx/mlx/backend/metal/nojit_kernels.cpp",
9489

9590
// do not build distributed support (yet)
96-
"mlx/mlx/distributed/mpi",
91+
"mlx/mlx/distributed/mpi/mpi.cpp",
9792
"mlx/mlx/distributed/ops.cpp",
9893
"mlx/mlx/distributed/primitives.cpp",
94+
"mlx/mlx/distributed/ring/ring.cpp",
95+
96+
"mlx/mlx/backend/cpu/gemms/no_bf16.cpp",
97+
"mlx/mlx/backend/cpu/gemms/no_fp16.cpp",
9998
],
10099

101100
cSettings: [
@@ -110,6 +109,7 @@ let package = Package(
110109
.headerSearchPath("json/single_include/nlohmann"),
111110
.headerSearchPath("fmt/include"),
112111

112+
.define("MLX_USE_ACCELERATE"),
113113
.define("ACCELERATE_NEW_LAPACK"),
114114
.define("_METAL_"),
115115
.define("SWIFTPM_BUNDLE", to: "\"mlx-swift_Cmlx\""),

Source/Cmlx/include/mlx/c/array.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ typedef enum mlx_dtype_ {
4646
MLX_INT64,
4747
MLX_FLOAT16,
4848
MLX_FLOAT32,
49+
MLX_FLOAT64,
4950
MLX_BFLOAT16,
5051
MLX_COMPLEX64,
5152
} mlx_dtype;
@@ -78,10 +79,24 @@ mlx_array mlx_array_new_bool(bool val);
7879
* New array from a int scalar.
7980
*/
8081
mlx_array mlx_array_new_int(int val);
82+
/**
83+
* New array from a float32 scalar.
84+
*/
85+
mlx_array mlx_array_new_float32(float val);
8186
/**
8287
* New array from a float scalar.
88+
* Same as float32.
8389
*/
8490
mlx_array mlx_array_new_float(float val);
91+
/**
92+
* New array from a float64 scalar.
93+
*/
94+
mlx_array mlx_array_new_float64(double val);
95+
/**
96+
* New array from a double scalar.
97+
* Same as float64.
98+
*/
99+
mlx_array mlx_array_new_double(double val);
85100
/**
86101
* New array from a complex scalar.
87102
*/
@@ -110,10 +125,22 @@ int mlx_array_set_bool(mlx_array* arr, bool val);
110125
* Set array to a int scalar.
111126
*/
112127
int mlx_array_set_int(mlx_array* arr, int val);
128+
/**
129+
* Set array to a float32 scalar.
130+
*/
131+
int mlx_array_set_float32(mlx_array* arr, float val);
113132
/**
114133
* Set array to a float scalar.
115134
*/
116135
int mlx_array_set_float(mlx_array* arr, float val);
136+
/**
137+
* Set array to a float64 scalar.
138+
*/
139+
int mlx_array_set_float64(mlx_array* arr, double val);
140+
/**
141+
* Set array to a double scalar.
142+
*/
143+
int mlx_array_set_double(mlx_array* arr, double val);
117144
/**
118145
* Set array to a complex scalar.
119146
*/
@@ -167,6 +194,7 @@ int mlx_array_dim(const mlx_array arr, int dim);
167194
* The array element type.
168195
*/
169196
mlx_dtype mlx_array_dtype(const mlx_array arr);
197+
170198
/**
171199
* Evaluate the array.
172200
*/
@@ -212,6 +240,10 @@ int mlx_array_item_int64(int64_t* res, const mlx_array arr);
212240
* Access the value of a scalar array.
213241
*/
214242
int mlx_array_item_float32(float* res, const mlx_array arr);
243+
/**
244+
* Access the value of a scalar array.
245+
*/
246+
int mlx_array_item_float64(double* res, const mlx_array arr);
215247
/**
216248
* Access the value of a scalar array.
217249
*/
@@ -281,6 +313,11 @@ const int64_t* mlx_array_data_int64(const mlx_array arr);
281313
* Array must be evaluated, otherwise returns NULL.
282314
*/
283315
const float* mlx_array_data_float32(const mlx_array arr);
316+
/**
317+
* Returns a pointer to the array data, cast to `float64*`.
318+
* Array must be evaluated, otherwise returns NULL.
319+
*/
320+
const double* mlx_array_data_float64(const mlx_array arr);
284321
/**
285322
* Returns a pointer to the array data, cast to `_Complex*`.
286323
* Array must be evaluated, otherwise returns NULL.
@@ -302,6 +339,37 @@ const float16_t* mlx_array_data_float16(const mlx_array arr);
302339
*/
303340
const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr);
304341
#endif
342+
343+
/**
344+
* Check if the array is available.
345+
* Internal function: use at your own risk.
346+
*/
347+
int _mlx_array_is_available(bool* res, const mlx_array arr);
348+
349+
/**
350+
* Wait on the array to be available. After this `_mlx_array_is_available`
351+
* returns `true`. Internal function: use at your own risk.
352+
*/
353+
int _mlx_array_wait(const mlx_array arr);
354+
355+
/**
356+
* Whether the array is contiguous in memory.
357+
* Internal function: use at your own risk.
358+
*/
359+
int _mlx_array_is_contiguous(bool* res, const mlx_array arr);
360+
361+
/**
362+
* Whether the array's rows are contiguous in memory.
363+
* Internal function: use at your own risk.
364+
*/
365+
int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr);
366+
367+
/**
368+
* Whether the array's columns are contiguous in memory.
369+
* Internal function: use at your own risk.
370+
*/
371+
int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr);
372+
305373
/**@}*/
306374

307375
#ifdef __cplusplus

Source/Cmlx/include/mlx/c/compile.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ typedef enum mlx_compile_mode_ {
3131
MLX_COMPILE_MODE_NO_FUSE,
3232
MLX_COMPILE_MODE_ENABLED
3333
} mlx_compile_mode;
34+
int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless);
3435
int mlx_detail_compile(
3536
mlx_closure* res,
3637
const mlx_closure fun,

Source/Cmlx/include/mlx/c/device.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,18 @@ int mlx_device_set(mlx_device* dev, const mlx_device src);
4848
* Get device description.
4949
*/
5050
int mlx_device_tostring(mlx_string* str, mlx_device dev);
51+
/**
52+
* Check if devices are the same.
53+
*/
54+
bool mlx_device_equal(mlx_device lhs, mlx_device rhs);
55+
/**
56+
* Returns the index of the device.
57+
*/
58+
int mlx_device_get_index(int* index, mlx_device dev);
5159
/**
5260
* Returns the type of the device.
5361
*/
54-
mlx_device_type mlx_device_get_type(mlx_device dev);
62+
int mlx_device_get_type(mlx_device_type* type, mlx_device dev);
5563
/**
5664
* Returns the default MLX device.
5765
*/

Source/Cmlx/include/mlx/c/linalg.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ int mlx_linalg_eigvalsh(
5353
const char* UPLO,
5454
const mlx_stream s);
5555
int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s);
56+
int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s);
57+
int mlx_linalg_lu_factor(
58+
mlx_array* res_0,
59+
mlx_array* res_1,
60+
const mlx_array a,
61+
const mlx_stream s);
5662
int mlx_linalg_norm_p(
5763
mlx_array* res,
5864
const mlx_array a,
@@ -82,6 +88,17 @@ int mlx_linalg_qr(
8288
mlx_array* res_1,
8389
const mlx_array a,
8490
const mlx_stream s);
91+
int mlx_linalg_solve(
92+
mlx_array* res,
93+
const mlx_array a,
94+
const mlx_array b,
95+
const mlx_stream s);
96+
int mlx_linalg_solve_triangular(
97+
mlx_array* res,
98+
const mlx_array a,
99+
const mlx_array b,
100+
bool upper,
101+
const mlx_stream s);
85102
int mlx_linalg_svd(
86103
mlx_vector_array* res,
87104
const mlx_array a,

Source/Cmlx/include/mlx/c/ops.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ int mlx_as_strided(
145145
const mlx_array a,
146146
const int* shape,
147147
size_t shape_num,
148-
const size_t* strides,
148+
const int64_t* strides,
149149
size_t strides_num,
150150
size_t offset,
151151
const mlx_stream s);
@@ -162,6 +162,7 @@ int mlx_bitwise_and(
162162
const mlx_array a,
163163
const mlx_array b,
164164
const mlx_stream s);
165+
int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s);
165166
int mlx_bitwise_or(
166167
mlx_array* res,
167168
const mlx_array a,
@@ -473,6 +474,11 @@ int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s);
473474
int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s);
474475
int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s);
475476
int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s);
477+
int mlx_kron(
478+
mlx_array* res,
479+
const mlx_array a,
480+
const mlx_array b,
481+
const mlx_stream s);
476482
int mlx_left_shift(
477483
mlx_array* res,
478484
const mlx_array a,
@@ -757,6 +763,13 @@ int mlx_scatter_add(
757763
const int* axes,
758764
size_t axes_num,
759765
const mlx_stream s);
766+
int mlx_scatter_add_axis(
767+
mlx_array* res,
768+
const mlx_array a,
769+
const mlx_array indices,
770+
const mlx_array values,
771+
int axis,
772+
const mlx_stream s);
760773
int mlx_scatter_max(
761774
mlx_array* res,
762775
const mlx_array a,
@@ -960,6 +973,13 @@ int mlx_tri(
960973
const mlx_stream s);
961974
int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s);
962975
int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s);
976+
int mlx_unflatten(
977+
mlx_array* res,
978+
const mlx_array a,
979+
int axis,
980+
const int* shape,
981+
size_t shape_num,
982+
const mlx_stream s);
963983
int mlx_var(
964984
mlx_array* res,
965985
const mlx_array a,

Source/Cmlx/include/mlx/c/random.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ int mlx_random_permutation(
101101
int axis,
102102
const mlx_array key /* may be null */,
103103
const mlx_stream s);
104-
int mlx_random_permutation_all(
104+
int mlx_random_permutation_arange(
105105
mlx_array* res,
106106
int x,
107107
const mlx_array key /* may be null */,
@@ -116,7 +116,7 @@ int mlx_random_randint(
116116
const mlx_array key /* may be null */,
117117
const mlx_stream s);
118118
int mlx_random_seed(uint64_t seed);
119-
int mlx_random_split_equal_parts(
119+
int mlx_random_split_num(
120120
mlx_array* res,
121121
const mlx_array key,
122122
int num,

Source/Cmlx/include/mlx/c/stream.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs);
5353
* Return the device of the stream.
5454
*/
5555
int mlx_stream_get_device(mlx_device* dev, mlx_stream stream);
56+
/**
57+
* Return the index of the stream.
58+
*/
59+
int mlx_stream_get_index(int* index, mlx_stream stream);
5660
/**
5761
* Synchronize with the provided stream.
5862
*/

Source/Cmlx/mlx

Submodule mlx updated 352 files

0 commit comments

Comments
 (0)