Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ jobs:
command: |
pip install pre-commit
brew install swift-format
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
pre-commit run --all || (echo "Style checks failed, please install pre-commit and run pre-commit run --all and push the change"; echo ""; git --no-pager diff; exit 1)
- run:
name: Run Tests (Xcode, macOS)
command: |
Expand Down
11 changes: 8 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
repos:
- repo: https://github.com/slessans/pre-commit-swift-format
rev: ""

- repo: local
hooks:
- id: swift-format
args: ["--configuration", ".swift-format"]
name: swift-format
language: system
entry: swift-format format --in-place --configuration .swift-format --recursive .
require_serial: true
types: [swift]

- repo: https://github.com/cheshirekow/cmake-format-precommit
rev: v0.6.10
hooks:
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG "v0.2.0")
GIT_TAG "v0.3.0")
FetchContent_MakeAvailable(mlx-c)

# swift-numerics
Expand Down
34 changes: 33 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,38 @@ let package = Package(
"mlx/tests",

// opt-out of these backends (using metal)
"mlx/mlx/backend/no_metal",
"mlx/mlx/backend/no_gpu",
"mlx/mlx/backend/no_cpu",
"mlx/mlx/backend/metal/no_metal.cpp",

// special handling for cuda -- we need to keep one file:
// mlx/mlx/backend/cuda/no_cuda.cpp
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a little more complicated than I wish, but we can't exclude the directory + include one file, so I need to just list them.


"mlx/mlx/backend/cuda/allocator.cpp",
"mlx/mlx/backend/cuda/compiled.cpp",
"mlx/mlx/backend/cuda/conv.cpp",
"mlx/mlx/backend/cuda/cuda.cpp",
"mlx/mlx/backend/cuda/cudnn_utils.cpp",
"mlx/mlx/backend/cuda/custom_kernel.cpp",
"mlx/mlx/backend/cuda/device.cpp",
"mlx/mlx/backend/cuda/eval.cpp",
"mlx/mlx/backend/cuda/fence.cpp",
"mlx/mlx/backend/cuda/indexing.cpp",
"mlx/mlx/backend/cuda/jit_module.cpp",
"mlx/mlx/backend/cuda/matmul.cpp",
"mlx/mlx/backend/cuda/primitives.cpp",
"mlx/mlx/backend/cuda/slicing.cpp",
"mlx/mlx/backend/cuda/utils.cpp",
"mlx/mlx/backend/cuda/worker.cpp",
"mlx/mlx/backend/cuda/unary",
"mlx/mlx/backend/cuda/gemms",
"mlx/mlx/backend/cuda/steel",
"mlx/mlx/backend/cuda/reduce",
"mlx/mlx/backend/cuda/quantized",
"mlx/mlx/backend/cuda/conv",
"mlx/mlx/backend/cuda/copy",
"mlx/mlx/backend/cuda/device",
"mlx/mlx/backend/cuda/binary",

// build variants (we are opting _out_ of these)
"mlx/mlx/io/no_safetensors.cpp",
Expand All @@ -89,6 +119,8 @@ let package = Package(
// do not build distributed support (yet)
"mlx/mlx/distributed/mpi/mpi.cpp",
"mlx/mlx/distributed/ring/ring.cpp",
"mlx/mlx/distributed/nccl/nccl.cpp",
"mlx/mlx/distributed/nccl/nccl_stub",

// bnns instead of simd (accelerate)
"mlx/mlx/backend/cpu/gemms/simd_fp16.cpp",
Expand Down
2 changes: 1 addition & 1 deletion Source/Cmlx/include/mlx/c/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ int mlx_array_tostring(mlx_string* str, const mlx_array arr);
/**
* New empty array.
*/
mlx_array mlx_array_new();
mlx_array mlx_array_new(void);

/**
* Free an array.
Expand Down
12 changes: 6 additions & 6 deletions Source/Cmlx/include/mlx/c/closure.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ extern "C" {
typedef struct mlx_closure_ {
void* ctx;
} mlx_closure;
mlx_closure mlx_closure_new();
mlx_closure mlx_closure_new(void);
int mlx_closure_free(mlx_closure cls);
mlx_closure mlx_closure_new_func(
int (*fun)(mlx_vector_array*, const mlx_vector_array));
Expand All @@ -44,7 +44,7 @@ mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array));
typedef struct mlx_closure_kwargs_ {
void* ctx;
} mlx_closure_kwargs;
mlx_closure_kwargs mlx_closure_kwargs_new();
mlx_closure_kwargs mlx_closure_kwargs_new(void);
int mlx_closure_kwargs_free(mlx_closure_kwargs cls);
mlx_closure_kwargs mlx_closure_kwargs_new_func(int (*fun)(
mlx_vector_array*,
Expand All @@ -70,7 +70,7 @@ int mlx_closure_kwargs_apply(
typedef struct mlx_closure_value_and_grad_ {
void* ctx;
} mlx_closure_value_and_grad;
mlx_closure_value_and_grad mlx_closure_value_and_grad_new();
mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void);
int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls);
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(
int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array));
Expand All @@ -94,7 +94,7 @@ int mlx_closure_value_and_grad_apply(
typedef struct mlx_closure_custom_ {
void* ctx;
} mlx_closure_custom;
mlx_closure_custom mlx_closure_custom_new();
mlx_closure_custom mlx_closure_custom_new(void);
int mlx_closure_custom_free(mlx_closure_custom cls);
mlx_closure_custom mlx_closure_custom_new_func(int (*fun)(
mlx_vector_array*,
Expand Down Expand Up @@ -123,7 +123,7 @@ int mlx_closure_custom_apply(
typedef struct mlx_closure_custom_jvp_ {
void* ctx;
} mlx_closure_custom_jvp;
mlx_closure_custom_jvp mlx_closure_custom_jvp_new();
mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void);
int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls);
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(int (*fun)(
mlx_vector_array*,
Expand Down Expand Up @@ -155,7 +155,7 @@ int mlx_closure_custom_jvp_apply(
typedef struct mlx_closure_custom_vmap_ {
void* ctx;
} mlx_closure_custom_vmap;
mlx_closure_custom_vmap mlx_closure_custom_vmap_new();
mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void);
int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls);
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(int (*fun)(
mlx_vector_array*,
Expand Down
6 changes: 3 additions & 3 deletions Source/Cmlx/include/mlx/c/compile.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ int mlx_detail_compile(
bool shapeless,
const uint64_t* constants,
size_t constants_num);
int mlx_detail_compile_clear_cache();
int mlx_detail_compile_clear_cache(void);
int mlx_detail_compile_erase(uintptr_t fun_id);
int mlx_disable_compile();
int mlx_enable_compile();
int mlx_disable_compile(void);
int mlx_enable_compile(void);
int mlx_set_compile_mode(mlx_compile_mode mode);
/**@}*/

Expand Down
2 changes: 1 addition & 1 deletion Source/Cmlx/include/mlx/c/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ typedef enum mlx_device_type_ { MLX_CPU, MLX_GPU } mlx_device_type;
/**
* Returns a new empty device.
*/
mlx_device mlx_device_new();
mlx_device mlx_device_new(void);

/**
* Returns a new device of specified `type`, with specified `index`.
Expand Down
2 changes: 1 addition & 1 deletion Source/Cmlx/include/mlx/c/distributed_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ mlx_distributed_group_split(mlx_distributed_group group, int color, int key);
/**
* Check if distributed is available.
*/
bool mlx_distributed_is_available();
bool mlx_distributed_is_available(void);

/**
* Initialize distributed.
Expand Down
84 changes: 67 additions & 17 deletions Source/Cmlx/include/mlx/c/fast.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,69 @@ extern "C" {
* \defgroup fast Fast custom operations
*/
/**@{*/
int mlx_fast_affine_dequantize(
mlx_array* res,
const mlx_array w,
const mlx_array scales,
const mlx_array biases,
int group_size,
int bits,
const mlx_stream s);
int mlx_fast_affine_quantize(
mlx_array* res_0,
mlx_array* res_1,
mlx_array* res_2,
const mlx_array w,
int group_size,
int bits,
const mlx_stream s);

typedef struct mlx_fast_cuda_kernel_config_ {
void* ctx;
} mlx_fast_cuda_kernel_config;
mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void);
void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls);

int mlx_fast_cuda_kernel_config_add_output_arg(
mlx_fast_cuda_kernel_config cls,
const int* shape,
size_t size,
mlx_dtype dtype);
int mlx_fast_cuda_kernel_config_set_grid(
mlx_fast_cuda_kernel_config cls,
int grid1,
int grid2,
int grid3);
int mlx_fast_cuda_kernel_config_set_thread_group(
mlx_fast_cuda_kernel_config cls,
int thread1,
int thread2,
int thread3);
int mlx_fast_cuda_kernel_config_set_init_value(
mlx_fast_cuda_kernel_config cls,
float value);
int mlx_fast_cuda_kernel_config_set_verbose(
mlx_fast_cuda_kernel_config cls,
bool verbose);
int mlx_fast_cuda_kernel_config_add_template_arg_dtype(
mlx_fast_cuda_kernel_config cls,
const char* name,
mlx_dtype dtype);
int mlx_fast_cuda_kernel_config_add_template_arg_int(
mlx_fast_cuda_kernel_config cls,
const char* name,
int value);
int mlx_fast_cuda_kernel_config_add_template_arg_bool(
mlx_fast_cuda_kernel_config cls,
const char* name,
bool value);

typedef struct mlx_fast_cuda_kernel_ {
void* ctx;
} mlx_fast_cuda_kernel;

mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(
const char* name,
const mlx_vector_string input_names,
const mlx_vector_string output_names,
const char* source,
const char* header,
bool ensure_row_contiguous,
int shared_memory);

void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls);

int mlx_fast_cuda_kernel_apply(
mlx_vector_array* outputs,
mlx_fast_cuda_kernel cls,
const mlx_vector_array inputs,
const mlx_fast_cuda_kernel_config config,
const mlx_stream stream);

int mlx_fast_layer_norm(
mlx_array* res,
const mlx_array x,
Expand All @@ -54,7 +101,7 @@ int mlx_fast_layer_norm(
typedef struct mlx_fast_metal_kernel_config_ {
void* ctx;
} mlx_fast_metal_kernel_config;
mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new();
mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void);
void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls);

int mlx_fast_metal_kernel_config_add_output_arg(
Expand Down Expand Up @@ -103,7 +150,9 @@ mlx_fast_metal_kernel mlx_fast_metal_kernel_new(
const char* header,
bool ensure_row_contiguous,
bool atomic_outputs);

void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls);

int mlx_fast_metal_kernel_apply(
mlx_vector_array* outputs,
mlx_fast_metal_kernel cls,
Expand Down Expand Up @@ -135,6 +184,7 @@ int mlx_fast_scaled_dot_product_attention(
float scale,
const char* mask_mode,
const mlx_vector_array mask_arrs,
const mlx_array sinks /* may be null */,
const mlx_stream s);
/**@}*/

Expand Down
12 changes: 12 additions & 0 deletions Source/Cmlx/include/mlx/c/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ int mlx_fft_fftn(
const int* axes,
size_t axes_num,
const mlx_stream s);
int mlx_fft_fftshift(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const mlx_stream s);
int mlx_fft_ifft(
mlx_array* res,
const mlx_array a,
Expand All @@ -71,6 +77,12 @@ int mlx_fft_ifftn(
const int* axes,
size_t axes_num,
const mlx_stream s);
int mlx_fft_ifftshift(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const mlx_stream s);
int mlx_fft_irfft(
mlx_array* res,
const mlx_array a,
Expand Down
6 changes: 6 additions & 0 deletions Source/Cmlx/include/mlx/c/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,18 @@ int mlx_linalg_cross(
const mlx_array b,
int axis,
const mlx_stream s);
int mlx_linalg_eig(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const mlx_stream s);
int mlx_linalg_eigh(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const char* UPLO,
const mlx_stream s);
int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_linalg_eigvalsh(
mlx_array* res,
const mlx_array a,
Expand Down
4 changes: 2 additions & 2 deletions Source/Cmlx/include/mlx/c/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ extern "C" {
* \defgroup memory Memory operations
*/
/**@{*/
int mlx_clear_cache();
int mlx_clear_cache(void);
int mlx_get_active_memory(size_t* res);
int mlx_get_cache_memory(size_t* res);
int mlx_get_memory_limit(size_t* res);
int mlx_get_peak_memory(size_t* res);
int mlx_reset_peak_memory();
int mlx_reset_peak_memory(void);
int mlx_set_cache_limit(size_t* res, size_t limit);
int mlx_set_memory_limit(size_t* res, size_t limit);
int mlx_set_wired_limit(size_t* res, size_t limit);
Expand Down
4 changes: 2 additions & 2 deletions Source/Cmlx/include/mlx/c/metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ typedef struct mlx_metal_device_info_t_ {
size_t max_recommended_working_set_size;
size_t memory_size;
} mlx_metal_device_info_t;
mlx_metal_device_info_t mlx_metal_device_info();
mlx_metal_device_info_t mlx_metal_device_info(void);

int mlx_metal_is_available(bool* res);
int mlx_metal_start_capture(const char* path);
int mlx_metal_stop_capture();
int mlx_metal_stop_capture(void);
/**@}*/

#ifdef __cplusplus
Expand Down
Loading