Skip to content

Commit 7b2a669

Browse files
authored
Merge pull request #608 from makllama/fix_musa_ext
musa: support bf16
2 parents 6f9ea68 + 18b1d18 commit 7b2a669

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
#pragma once
22

33
#include <musa_runtime.h>
4+
#include <musa_bf16.h>
45

56
#define cudaLaunchHostFunc musaLaunchHostFunc
67
#define cudaStream_t musaStream_t
7-
#define cudaHostFn_t musaHostFn_t
8+
#define cudaHostFn_t musaHostFn_t
9+
#define nv_bfloat16 mt_bfloat16

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def build_extension(self, ext) -> None:
350350
"at::cuda": "at::musa",
351351
"#include <ATen/cuda/CUDAContext.h>": "#include \"torch_musa/csrc/aten/musa/MUSAContext.h\"",
352352
"#include <c10/cuda/CUDAGuard.h>": "#include \"torch_musa/csrc/core/MUSAGuard.h\"",
353+
"nv_bfloat16": "mt_bfloat16",
353354
}).run()
354355
ops_module = MUSAExtension('KTransformersOps', [
355356
'ktransformers/ktransformers_ext/cuda_musa/custom_gguf/dequant.mu',

0 commit comments

Comments
 (0)