Skip to content

Commit feb81e6

Browse files
authored
update to mlx v0.25.1 (#217)
- adopt mlx v0.25.1 and mlx-c v0.2.0 - add saveToData and load from data (safetensors) -- fix #214 - fixes #211 (moves metalKernel API back to what it was) - add withError and withErrorHandler error handling - add function import/export: https://ml-explore.github.io/mlx/build/html/usage/export.html - use an evalLock for eval, asyncEval and Stream creation
1 parent abe1092 commit feb81e6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

85 files changed

+5602
-1641
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ endif()
1111
FetchContent_Declare(
1212
mlx-c
1313
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
14-
GIT_TAG "v0.1.2")
14+
GIT_TAG "v0.2.0")
1515
FetchContent_MakeAvailable(mlx-c)
1616

1717
# swift-numerics

Package.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,11 @@ let package = Package(
8787

8888
// do not build distributed support (yet)
8989
"mlx/mlx/distributed/mpi/mpi.cpp",
90-
"mlx/mlx/distributed/ops.cpp",
91-
"mlx/mlx/distributed/primitives.cpp",
9290
"mlx/mlx/distributed/ring/ring.cpp",
9391

94-
"mlx/mlx/backend/cpu/gemms/no_bf16.cpp",
95-
"mlx/mlx/backend/cpu/gemms/no_fp16.cpp",
92+
// bnns instead of simd (accelerate)
93+
"mlx/mlx/backend/cpu/gemms/simd_fp16.cpp",
94+
"mlx/mlx/backend/cpu/gemms/simd_bf16.cpp",
9695
],
9796

9897
cSettings: [
@@ -112,6 +111,7 @@ let package = Package(
112111
.define("_METAL_"),
113112
.define("SWIFTPM_BUNDLE", to: "\"mlx-swift_Cmlx\""),
114113
.define("METAL_PATH", to: "\"default.metallib\""),
114+
.define("MLX_VERSION", to: "\"0.24.2\""),
115115
],
116116
linkerSettings: [
117117
.linkedFramework("Foundation"),

Source/Cmlx/include/mlx/c/closure.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#define MLX_CLOSURE_H
88

99
#include "mlx/c/array.h"
10+
#include "mlx/c/map.h"
1011
#include "mlx/c/optional.h"
1112
#include "mlx/c/stream.h"
1213
#include "mlx/c/vector.h"
@@ -40,6 +41,32 @@ int mlx_closure_apply(
4041

4142
mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array));
4243

44+
typedef struct mlx_closure_kwargs_ {
45+
void* ctx;
46+
} mlx_closure_kwargs;
47+
mlx_closure_kwargs mlx_closure_kwargs_new();
48+
int mlx_closure_kwargs_free(mlx_closure_kwargs cls);
49+
mlx_closure_kwargs mlx_closure_kwargs_new_func(int (*fun)(
50+
mlx_vector_array*,
51+
const mlx_vector_array,
52+
const mlx_map_string_to_array));
53+
mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(
54+
int (*fun)(
55+
mlx_vector_array*,
56+
const mlx_vector_array,
57+
const mlx_map_string_to_array,
58+
void*),
59+
void* payload,
60+
void (*dtor)(void*));
61+
int mlx_closure_kwargs_set(
62+
mlx_closure_kwargs* cls,
63+
const mlx_closure_kwargs src);
64+
int mlx_closure_kwargs_apply(
65+
mlx_vector_array* res,
66+
mlx_closure_kwargs cls,
67+
const mlx_vector_array input_0,
68+
const mlx_map_string_to_array input_1);
69+
4370
typedef struct mlx_closure_value_and_grad_ {
4471
void* ctx;
4572
} mlx_closure_value_and_grad;

Source/Cmlx/include/mlx/c/compile.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
#ifndef MLX_COMPILE_H
77
#define MLX_COMPILE_H
88

9+
#include <stdbool.h>
910
#include <stdint.h>
1011
#include <stdio.h>
1112

1213
#include "mlx/c/array.h"
1314
#include "mlx/c/closure.h"
1415
#include "mlx/c/distributed_group.h"
16+
#include "mlx/c/io_types.h"
1517
#include "mlx/c/map.h"
1618
#include "mlx/c/stream.h"
1719
#include "mlx/c/string.h"

Source/Cmlx/include/mlx/c/device.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#ifndef MLX_DEVICE_H
44
#define MLX_DEVICE_H
55

6+
#include <stdbool.h>
7+
68
#include "mlx/c/string.h"
79

810
#ifdef __cplusplus

Source/Cmlx/include/mlx/c/distributed.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
#ifndef MLX_DISTRIBUTED_H
77
#define MLX_DISTRIBUTED_H
88

9+
#include <stdbool.h>
910
#include <stdint.h>
1011
#include <stdio.h>
1112

1213
#include "mlx/c/array.h"
1314
#include "mlx/c/closure.h"
1415
#include "mlx/c/distributed_group.h"
16+
#include "mlx/c/io_types.h"
1517
#include "mlx/c/map.h"
1618
#include "mlx/c/stream.h"
1719
#include "mlx/c/string.h"
@@ -30,6 +32,16 @@ int mlx_distributed_all_gather(
3032
const mlx_array x,
3133
const mlx_distributed_group group /* may be null */,
3234
const mlx_stream S);
35+
int mlx_distributed_all_max(
36+
mlx_array* res,
37+
const mlx_array x,
38+
const mlx_distributed_group group /* may be null */,
39+
const mlx_stream s);
40+
int mlx_distributed_all_min(
41+
mlx_array* res,
42+
const mlx_array x,
43+
const mlx_distributed_group group /* may be null */,
44+
const mlx_stream s);
3345
int mlx_distributed_all_sum(
3446
mlx_array* res,
3547
const mlx_array x,

Source/Cmlx/include/mlx/c/distributed_group.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#ifndef MLX_DISTRIBUTED_GROUP_H
44
#define MLX_DISTRIBUTED_GROUP_H
55

6+
#include <stdbool.h>
7+
68
#include "mlx/c/stream.h"
79

810
#ifdef __cplusplus

Source/Cmlx/include/mlx/c/export.h

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/* Copyright © 2023-2025 Apple Inc. */
2+
3+
#ifndef MLX_EXPORT_H
4+
#define MLX_EXPORT_H
5+
6+
#include <stdbool.h>
7+
#include <stdint.h>
8+
#include <stdio.h>
9+
10+
#include "mlx/c/array.h"
11+
#include "mlx/c/closure.h"
12+
#include "mlx/c/distributed_group.h"
13+
#include "mlx/c/io_types.h"
14+
#include "mlx/c/map.h"
15+
#include "mlx/c/stream.h"
16+
#include "mlx/c/string.h"
17+
#include "mlx/c/vector.h"
18+
19+
#ifdef __cplusplus
20+
extern "C" {
21+
#endif
22+
23+
/**
24+
* \defgroup export Function serialization
25+
*/
26+
/**@{*/
27+
int mlx_export_function(
28+
const char* file,
29+
const mlx_closure fun,
30+
const mlx_vector_array args,
31+
bool shapeless);
32+
int mlx_export_function_kwargs(
33+
const char* file,
34+
const mlx_closure_kwargs fun,
35+
const mlx_vector_array args,
36+
const mlx_map_string_to_array kwargs,
37+
bool shapeless);
38+
39+
typedef struct mlx_function_exporter_ {
40+
void* ctx;
41+
} mlx_function_exporter;
42+
mlx_function_exporter mlx_function_exporter_new(
43+
const char* file,
44+
const mlx_closure fun,
45+
bool shapeless);
46+
int mlx_function_exporter_free(mlx_function_exporter xfunc);
47+
int mlx_function_exporter_apply(
48+
const mlx_function_exporter xfunc,
49+
const mlx_vector_array args);
50+
int mlx_function_exporter_apply_kwargs(
51+
const mlx_function_exporter xfunc,
52+
const mlx_vector_array args,
53+
const mlx_map_string_to_array kwargs);
54+
55+
typedef struct mlx_imported_function_ {
56+
void* ctx;
57+
} mlx_imported_function;
58+
mlx_imported_function mlx_imported_function_new(const char* file);
59+
int mlx_imported_function_free(mlx_imported_function xfunc);
60+
int mlx_imported_function_apply(
61+
mlx_vector_array* res,
62+
const mlx_imported_function xfunc,
63+
const mlx_vector_array args);
64+
int mlx_imported_function_apply_kwargs(
65+
mlx_vector_array* res,
66+
const mlx_imported_function xfunc,
67+
const mlx_vector_array args,
68+
const mlx_map_string_to_array kwargs);
69+
/**@}*/
70+
71+
#ifdef __cplusplus
72+
}
73+
#endif
74+
75+
#endif

Source/Cmlx/include/mlx/c/fast.h

Lines changed: 40 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
#ifndef MLX_FAST_H
77
#define MLX_FAST_H
88

9+
#include <stdbool.h>
910
#include <stdint.h>
1011
#include <stdio.h>
1112

1213
#include "mlx/c/array.h"
1314
#include "mlx/c/closure.h"
1415
#include "mlx/c/distributed_group.h"
16+
#include "mlx/c/io_types.h"
1517
#include "mlx/c/map.h"
1618
#include "mlx/c/stream.h"
1719
#include "mlx/c/string.h"
@@ -49,77 +51,70 @@ int mlx_fast_layer_norm(
4951
float eps,
5052
const mlx_stream s);
5153

52-
typedef struct mlx_fast_metal_kernel_ {
54+
typedef struct mlx_fast_metal_kernel_config_ {
5355
void* ctx;
54-
} mlx_fast_metal_kernel;
55-
mlx_fast_metal_kernel mlx_fast_metal_kernel_new(
56-
const char* name,
57-
const char* source,
58-
const char* header);
59-
void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls);
60-
int mlx_fast_metal_kernel_add_input_name(
61-
mlx_fast_metal_kernel cls,
62-
const char* name);
63-
int mlx_fast_metal_kernel_set_input_names(
64-
mlx_fast_metal_kernel cls,
65-
int num,
66-
...);
67-
int mlx_fast_metal_kernel_add_output_name(
68-
mlx_fast_metal_kernel cls,
69-
const char* name);
70-
int mlx_fast_metal_kernel_set_output_names(
71-
mlx_fast_metal_kernel cls,
72-
int num,
73-
...);
74-
int mlx_fast_metal_kernel_set_contiguous_rows(
75-
mlx_fast_metal_kernel cls,
76-
bool flag);
77-
int mlx_fast_metal_kernel_set_atomic_outputs(
78-
mlx_fast_metal_kernel cls,
79-
bool flag);
56+
} mlx_fast_metal_kernel_config;
57+
mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new();
58+
void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls);
8059

81-
int mlx_fast_metal_kernel_add_output_arg(
82-
mlx_fast_metal_kernel cls,
60+
int mlx_fast_metal_kernel_config_add_output_arg(
61+
mlx_fast_metal_kernel_config cls,
8362
const int* shape,
8463
size_t size,
8564
mlx_dtype dtype);
86-
int mlx_fast_metal_kernel_set_grid(
87-
mlx_fast_metal_kernel cls,
65+
int mlx_fast_metal_kernel_config_set_grid(
66+
mlx_fast_metal_kernel_config cls,
8867
int grid1,
8968
int grid2,
9069
int grid3);
91-
int mlx_fast_metal_kernel_set_thread_group(
92-
mlx_fast_metal_kernel cls,
70+
int mlx_fast_metal_kernel_config_set_thread_group(
71+
mlx_fast_metal_kernel_config cls,
9372
int thread1,
9473
int thread2,
9574
int thread3);
96-
int mlx_fast_metal_kernel_set_init_value(
97-
mlx_fast_metal_kernel cls,
75+
int mlx_fast_metal_kernel_config_set_init_value(
76+
mlx_fast_metal_kernel_config cls,
9877
float value);
99-
int mlx_fast_metal_kernel_set_verbose(mlx_fast_metal_kernel cls, bool verbose);
100-
int mlx_fast_metal_kernel_add_template_arg_dtype(
101-
mlx_fast_metal_kernel cls,
78+
int mlx_fast_metal_kernel_config_set_verbose(
79+
mlx_fast_metal_kernel_config cls,
80+
bool verbose);
81+
int mlx_fast_metal_kernel_config_add_template_arg_dtype(
82+
mlx_fast_metal_kernel_config cls,
10283
const char* name,
10384
mlx_dtype dtype);
104-
int mlx_fast_metal_kernel_add_template_arg_int(
105-
mlx_fast_metal_kernel cls,
85+
int mlx_fast_metal_kernel_config_add_template_arg_int(
86+
mlx_fast_metal_kernel_config cls,
10687
const char* name,
10788
int value);
108-
int mlx_fast_metal_kernel_add_template_arg_bool(
109-
mlx_fast_metal_kernel cls,
89+
int mlx_fast_metal_kernel_config_add_template_arg_bool(
90+
mlx_fast_metal_kernel_config cls,
11091
const char* name,
11192
bool value);
11293

94+
typedef struct mlx_fast_metal_kernel_ {
95+
void* ctx;
96+
} mlx_fast_metal_kernel;
97+
98+
mlx_fast_metal_kernel mlx_fast_metal_kernel_new(
99+
const char* name,
100+
const mlx_vector_string input_names,
101+
const mlx_vector_string output_names,
102+
const char* source,
103+
const char* header,
104+
bool ensure_row_contiguous,
105+
bool atomic_outputs);
106+
void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls);
113107
int mlx_fast_metal_kernel_apply(
114108
mlx_vector_array* outputs,
115109
mlx_fast_metal_kernel cls,
116110
const mlx_vector_array inputs,
111+
const mlx_fast_metal_kernel_config config,
117112
const mlx_stream stream);
118113

119114
int mlx_fast_rms_norm(
120115
mlx_array* res,
121116
const mlx_array x,
122-
const mlx_array weight,
117+
const mlx_array weight /* may be null */,
123118
float eps,
124119
const mlx_stream s);
125120
int mlx_fast_rope(
@@ -138,8 +133,8 @@ int mlx_fast_scaled_dot_product_attention(
138133
const mlx_array keys,
139134
const mlx_array values,
140135
float scale,
141-
const mlx_array mask /* may be null */,
142-
mlx_optional_int memory_efficient_threshold,
136+
const char* mask_mode,
137+
const mlx_vector_array mask_arrs,
143138
const mlx_stream s);
144139
/**@}*/
145140

Source/Cmlx/include/mlx/c/fft.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
#ifndef MLX_FFT_H
77
#define MLX_FFT_H
88

9+
#include <stdbool.h>
910
#include <stdint.h>
1011
#include <stdio.h>
1112

1213
#include "mlx/c/array.h"
1314
#include "mlx/c/closure.h"
1415
#include "mlx/c/distributed_group.h"
16+
#include "mlx/c/io_types.h"
1517
#include "mlx/c/map.h"
1618
#include "mlx/c/stream.h"
1719
#include "mlx/c/string.h"

0 commit comments

Comments
 (0)