66#ifndef MLX_FAST_H
77#define MLX_FAST_H
88
9+ #include <stdbool.h>
910#include <stdint.h>
1011#include <stdio.h>
1112
1213#include "mlx/c/array.h"
1314#include "mlx/c/closure.h"
1415#include "mlx/c/distributed_group.h"
16+ #include "mlx/c/io_types.h"
1517#include "mlx/c/map.h"
1618#include "mlx/c/stream.h"
1719#include "mlx/c/string.h"
@@ -49,77 +51,70 @@ int mlx_fast_layer_norm(
4951 float eps ,
5052 const mlx_stream s );
5153
52- typedef struct mlx_fast_metal_kernel_ {
54+ typedef struct mlx_fast_metal_kernel_config_ {
5355 void * ctx ;
54- } mlx_fast_metal_kernel ;
55- mlx_fast_metal_kernel mlx_fast_metal_kernel_new (
56- const char * name ,
57- const char * source ,
58- const char * header );
59- void mlx_fast_metal_kernel_free (mlx_fast_metal_kernel cls );
60- int mlx_fast_metal_kernel_add_input_name (
61- mlx_fast_metal_kernel cls ,
62- const char * name );
63- int mlx_fast_metal_kernel_set_input_names (
64- mlx_fast_metal_kernel cls ,
65- int num ,
66- ...);
67- int mlx_fast_metal_kernel_add_output_name (
68- mlx_fast_metal_kernel cls ,
69- const char * name );
70- int mlx_fast_metal_kernel_set_output_names (
71- mlx_fast_metal_kernel cls ,
72- int num ,
73- ...);
74- int mlx_fast_metal_kernel_set_contiguous_rows (
75- mlx_fast_metal_kernel cls ,
76- bool flag );
77- int mlx_fast_metal_kernel_set_atomic_outputs (
78- mlx_fast_metal_kernel cls ,
79- bool flag );
56+ } mlx_fast_metal_kernel_config ;
57+ mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new ();
58+ void mlx_fast_metal_kernel_config_free (mlx_fast_metal_kernel_config cls );
8059
81- int mlx_fast_metal_kernel_add_output_arg (
82- mlx_fast_metal_kernel cls ,
60+ int mlx_fast_metal_kernel_config_add_output_arg (
61+ mlx_fast_metal_kernel_config cls ,
8362 const int * shape ,
8463 size_t size ,
8564 mlx_dtype dtype );
86- int mlx_fast_metal_kernel_set_grid (
87- mlx_fast_metal_kernel cls ,
65+ int mlx_fast_metal_kernel_config_set_grid (
66+ mlx_fast_metal_kernel_config cls ,
8867 int grid1 ,
8968 int grid2 ,
9069 int grid3 );
91- int mlx_fast_metal_kernel_set_thread_group (
92- mlx_fast_metal_kernel cls ,
70+ int mlx_fast_metal_kernel_config_set_thread_group (
71+ mlx_fast_metal_kernel_config cls ,
9372 int thread1 ,
9473 int thread2 ,
9574 int thread3 );
96- int mlx_fast_metal_kernel_set_init_value (
97- mlx_fast_metal_kernel cls ,
75+ int mlx_fast_metal_kernel_config_set_init_value (
76+ mlx_fast_metal_kernel_config cls ,
9877 float value );
99- int mlx_fast_metal_kernel_set_verbose (mlx_fast_metal_kernel cls , bool verbose );
100- int mlx_fast_metal_kernel_add_template_arg_dtype (
101- mlx_fast_metal_kernel cls ,
78+ int mlx_fast_metal_kernel_config_set_verbose (
79+ mlx_fast_metal_kernel_config cls ,
80+ bool verbose );
81+ int mlx_fast_metal_kernel_config_add_template_arg_dtype (
82+ mlx_fast_metal_kernel_config cls ,
10283 const char * name ,
10384 mlx_dtype dtype );
104- int mlx_fast_metal_kernel_add_template_arg_int (
105- mlx_fast_metal_kernel cls ,
85+ int mlx_fast_metal_kernel_config_add_template_arg_int (
86+ mlx_fast_metal_kernel_config cls ,
10687 const char * name ,
10788 int value );
108- int mlx_fast_metal_kernel_add_template_arg_bool (
109- mlx_fast_metal_kernel cls ,
89+ int mlx_fast_metal_kernel_config_add_template_arg_bool (
90+ mlx_fast_metal_kernel_config cls ,
11091 const char * name ,
11192 bool value );
11293
94+ typedef struct mlx_fast_metal_kernel_ {
95+ void * ctx ;
96+ } mlx_fast_metal_kernel ;
97+
98+ mlx_fast_metal_kernel mlx_fast_metal_kernel_new (
99+ const char * name ,
100+ const mlx_vector_string input_names ,
101+ const mlx_vector_string output_names ,
102+ const char * source ,
103+ const char * header ,
104+ bool ensure_row_contiguous ,
105+ bool atomic_outputs );
106+ void mlx_fast_metal_kernel_free (mlx_fast_metal_kernel cls );
113107int mlx_fast_metal_kernel_apply (
114108 mlx_vector_array * outputs ,
115109 mlx_fast_metal_kernel cls ,
116110 const mlx_vector_array inputs ,
111+ const mlx_fast_metal_kernel_config config ,
117112 const mlx_stream stream );
118113
119114int mlx_fast_rms_norm (
120115 mlx_array * res ,
121116 const mlx_array x ,
122- const mlx_array weight ,
117+ const mlx_array weight /* may be null */ ,
123118 float eps ,
124119 const mlx_stream s );
125120int mlx_fast_rope (
@@ -138,8 +133,8 @@ int mlx_fast_scaled_dot_product_attention(
138133 const mlx_array keys ,
139134 const mlx_array values ,
140135 float scale ,
141- const mlx_array mask /* may be null */ ,
142- mlx_optional_int memory_efficient_threshold ,
136+ const char * mask_mode ,
137+ const mlx_vector_array mask_arrs ,
143138 const mlx_stream s );
144139/**@}*/
145140
0 commit comments