@@ -46,6 +46,7 @@ typedef enum mlx_dtype_ {
4646 MLX_INT64 ,
4747 MLX_FLOAT16 ,
4848 MLX_FLOAT32 ,
49+ MLX_FLOAT64 ,
4950 MLX_BFLOAT16 ,
5051 MLX_COMPLEX64 ,
5152} mlx_dtype ;
@@ -78,10 +79,24 @@ mlx_array mlx_array_new_bool(bool val);
7879 * New array from a int scalar.
7980 */
8081mlx_array mlx_array_new_int (int val );
82+ /**
83+ * New array from a float32 scalar.
84+ */
85+ mlx_array mlx_array_new_float32 (float val );
8186/**
8287 * New array from a float scalar.
88+ * Same as float32.
8389 */
8490mlx_array mlx_array_new_float (float val );
91+ /**
92+ * New array from a float64 scalar.
93+ */
94+ mlx_array mlx_array_new_float64 (double val );
95+ /**
96+ * New array from a double scalar.
97+ * Same as float64.
98+ */
99+ mlx_array mlx_array_new_double (double val );
85100/**
86101 * New array from a complex scalar.
87102 */
@@ -110,10 +125,22 @@ int mlx_array_set_bool(mlx_array* arr, bool val);
110125 * Set array to a int scalar.
111126 */
112127int mlx_array_set_int (mlx_array * arr , int val );
128+ /**
129+ * Set array to a float32 scalar.
130+ */
131+ int mlx_array_set_float32 (mlx_array * arr , float val );
113132/**
114133 * Set array to a float scalar.
115134 */
116135int mlx_array_set_float (mlx_array * arr , float val );
136+ /**
137+ * Set array to a float64 scalar.
138+ */
139+ int mlx_array_set_float64 (mlx_array * arr , double val );
140+ /**
141+ * Set array to a double scalar.
142+ */
143+ int mlx_array_set_double (mlx_array * arr , double val );
117144/**
118145 * Set array to a complex scalar.
119146 */
@@ -167,6 +194,7 @@ int mlx_array_dim(const mlx_array arr, int dim);
167194 * The array element type.
168195 */
169196mlx_dtype mlx_array_dtype (const mlx_array arr );
197+
170198/**
171199 * Evaluate the array.
172200 */
@@ -212,6 +240,10 @@ int mlx_array_item_int64(int64_t* res, const mlx_array arr);
212240 * Access the value of a scalar array.
213241 */
214242int mlx_array_item_float32 (float * res , const mlx_array arr );
243+ /**
244+ * Access the value of a scalar array.
245+ */
246+ int mlx_array_item_float64 (double * res , const mlx_array arr );
215247/**
216248 * Access the value of a scalar array.
217249 */
@@ -281,6 +313,11 @@ const int64_t* mlx_array_data_int64(const mlx_array arr);
281313 * Array must be evaluated, otherwise returns NULL.
282314 */
283315const float * mlx_array_data_float32 (const mlx_array arr );
316+ /**
317+ * Returns a pointer to the array data, cast to `float64*`.
318+ * Array must be evaluated, otherwise returns NULL.
319+ */
320+ const double * mlx_array_data_float64 (const mlx_array arr );
284321/**
285322 * Returns a pointer to the array data, cast to `_Complex*`.
286323 * Array must be evaluated, otherwise returns NULL.
@@ -302,6 +339,37 @@ const float16_t* mlx_array_data_float16(const mlx_array arr);
302339 */
303340const bfloat16_t * mlx_array_data_bfloat16 (const mlx_array arr );
304341#endif
342+
343+ /**
344+ * Check if the array is available.
345+ * Internal function: use at your own risk.
346+ */
347+ int _mlx_array_is_available (bool * res , const mlx_array arr );
348+
349+ /**
350+ * Wait on the array to be available. After this `_mlx_array_is_available`
351+ * returns `true`. Internal function: use at your own risk.
352+ */
353+ int _mlx_array_wait (const mlx_array arr );
354+
355+ /**
356+ * Whether the array is contiguous in memory.
357+ * Internal function: use at your own risk.
358+ */
359+ int _mlx_array_is_contiguous (bool * res , const mlx_array arr );
360+
361+ /**
362+ * Whether the array's rows are contiguous in memory.
363+ * Internal function: use at your own risk.
364+ */
365+ int _mlx_array_is_row_contiguous (bool * res , const mlx_array arr );
366+
367+ /**
368+ * Whether the array's columns are contiguous in memory.
369+ * Internal function: use at your own risk.
370+ */
371+ int _mlx_array_is_col_contiguous (bool * res , const mlx_array arr );
372+
305373/**@}*/
306374
307375#ifdef __cplusplus
0 commit comments